# In [None]:
## load data and packages
import cv2
import math, re, os
import tensorflow as tf
import numpy as np
import pandas as pd
from kaggle_datasets import KaggleDatasets
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

## detect TPU
# Try to attach to a TPU cluster; on any failure fall back to TensorFlow's
# default distribution strategy so the script still runs without a TPU.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate instead of being silently swallowed.
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

## Set variables
AUTOTUNE = tf.data.experimental.AUTOTUNE
GCS_PATH = KaggleDatasets().get_gcs_path()
# Global batch size scales with the number of replicas reported by the strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
IMAGE_SIZE = [512, 512]
CLASSES = np.array(['0', '1', '2', '3', '4'])
OUTPUT_SIZE = CLASSES.size
# tf.one_hot needs integer indices, and the `np.float` alias no longer exists
# in current NumPy releases, so convert the class strings to int32.
ONE_HOT_CLASSES = tf.one_hot(CLASSES.astype(np.int32), OUTPUT_SIZE)
EPOCHS = 25




## Locations of the data set files (relative to the working directory)
# Folder holding the test images; keeps its trailing slash.
sPath = "../input/cassava-leaf-disease-classification/test_images/"
# Folder holding the training images.
TRAIN_IMAGE_FOLDER = "../input/cassava-leaf-disease-classification/train_images"
# CSV table describing the training samples.
TRAIN_CSV = '../input/cassava-leaf-disease-classification/train.csv'


class TrainDataset(Dataset):
    """Map-style dataset that loads training images listed in a DataFrame.

    Every item is a dict holding the image under ``"x"`` and, when
    ``train_mode`` is true, the row's label column(s) under ``"y"``.
    """

    def __init__(self, train, train_mode=True, transforms=None):
        """Store the sample table and the optional transform pipeline.

        Args:
            train: DataFrame with an ``image_id`` column; in train mode the
                columns whose name starts with ``label`` are used as target.
            train_mode: when false, items carry no ``"y"`` entry.
            transforms: optional callable invoked as ``transforms(image=...)``
                whose result is indexed with ``"image"``.
        """
        self.train = train
        self.transforms = transforms
        self.train_mode = train_mode

    def __len__(self):
        """Return the number of rows in the sample table."""
        return self.train.shape[0]

    def __getitem__(self, index):
        """Load, colour-convert and optionally transform the sample at ``index``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = os.path.join(TRAIN_IMAGE_FOLDER, self.train.iloc[index].image_id)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with the
            # offending path instead of with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.transforms(image=image)["image"]
        if not self.train_mode:
            return {"x": image}
        return {
            "x": image,
            # One float64 entry per column whose name starts with 'label'.
            "y": torch.tensor(self.train.iloc[index, self.train.columns.str.startswith('label')], dtype=torch.float64)
        }


class TestDataset(Dataset):
    """Map-style dataset that loads unlabeled test images listed in a DataFrame.

    Every item is a dict holding only the image under ``"x"``.
    """

    def __init__(self, test_df, transforms=None):
        """Store the sample table and the optional transform pipeline.

        Args:
            test_df: DataFrame with an ``image_id`` column of file names.
            transforms: optional callable invoked as ``transforms(image=...)``
                whose result is indexed with ``"image"``.
        """
        self.test_df = test_df
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the sample table."""
        return self.test_df.shape[0]

    def __getitem__(self, index):
        """Load, colour-convert and optionally transform the image at ``index``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = os.path.join(sPath, self.test_df.iloc[index].image_id)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with the
            # offending path instead of with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transforms:
            image = self.transforms(image=image)["image"]
        return {
            "x": image
        }


## Adding in augmentations
def get_augmentations():
    """Build the albumentations pipelines for training and validation.

    Returns:
        A ``(train_augmentations, valid_augmentations)`` tuple. The training
        pipeline applies a random resized crop to ``IMAGE_SIZE`` plus random
        transpose/flips; the validation pipeline only resizes. Both normalise
        with the same per-channel mean/std and convert to a torch tensor.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_augmentations = albu.Compose([
        albu.RandomResizedCrop(*IMAGE_SIZE, p=1.0),
        albu.Transpose(p=0.5),
        albu.HorizontalFlip(p=0.5),
        # Pass the probability by keyword: positionally, 0.5 can bind to
        # `always_apply` (depending on the albumentations version) and turn
        # the flip into an unconditional one.
        albu.VerticalFlip(p=0.5),
        albu.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2(p=1.0)
    ], p=1.0)

    valid_augmentations = albu.Compose([
        albu.Resize(*IMAGE_SIZE),
        albu.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2(p=1.0)
    ], p=1.0)

    return train_augmentations, valid_augmentations


## Instantiate the augmentation pipelines used when loading data
(train_augs, val_augs) = get_augmentations()


## Building our model
# Step size handed to the optimizer.
learning_rate = 1e-3

class MyNet(nn.Module):
    """Small CNN classifier preceded by a spatial transformer (STN) stage.

    The STN predicts a 2x3 affine matrix per sample and resamples the input
    with it before the classification layers see the image. Adaptive pooling
    in front of both fully connected heads keeps the flattened feature sizes
    fixed, so the network accepts inputs other than 32x32 (for 32x32 inputs
    the pooling is an identity and the original computation is unchanged).
    """

    def __init__(self, num_classes=None):
        """Create the layers.

        Args:
            num_classes: number of output classes; defaults to the
                module-level ``OUTPUT_SIZE`` when omitted.
        """
        super().__init__()
        if num_classes is None:
            num_classes = OUTPUT_SIZE
        self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(20 * 5 * 5, 50)
        self.fc2 = nn.Linear(50, num_classes)
        # Localization network: extracts the features the affine regressor uses.
        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True))
        # Regressor for the 3 * 2 entries of the affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 4 * 4, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2))
        # Zero weights + identity bias: the STN starts as the identity transform.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Resample ``x`` with the affine transform predicted from ``x`` itself."""
        xs = self.localization(x)
        # Pool to the 4x4 map fc_loc was sized for, whatever the input size.
        xs = F.adaptive_avg_pool2d(xs, (4, 4))
        xs = torch.flatten(xs, 1)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        # align_corners=False is the library default; stated explicitly to
        # avoid the default-change warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        """Return per-class log-probabilities with shape (batch, num_classes)."""
        x = self.stn(x)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Pool to the 5x5 map fc1 was sized for, whatever the input size.
        x = F.adaptive_max_pool2d(x, (5, 5))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# Create the network, its loss function and its optimizer.
# NOTE(review): the enclosing scope is a TensorFlow distribution strategy while
# the objects created inside it are PyTorch ones - confirm it is needed.
with strategy.scope():
    my_net = MyNet()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        params=my_net.parameters(),
        lr=learning_rate,
        weight_decay=1e-5,
    )


## load data
trainDataFrame = pd.read_csv(TRAIN_CSV)
# Two views of the same table: the training view applies the random training
# augmentations, the validation view the deterministic validation pipeline.
trainDataset = TrainDataset(trainDataFrame, transforms=train_augs)
validationDataset = TrainDataset(trainDataFrame, transforms=val_augs)

# 80/20 split over one random permutation of the row indices. The validation
# part is whatever remains after the training part, so the two sizes always
# add up to the dataset length and the subsets never overlap.
numSamples = len(trainDataset)
numTrain = round(numSamples * 0.8)
shuffledIndices = torch.randperm(numSamples).tolist()
trainData = torch.utils.data.Subset(trainDataset, shuffledIndices[:numTrain])
validationData = torch.utils.data.Subset(validationDataset, shuffledIndices[numTrain:])
trainDataLoader = DataLoader(trainData, batch_size=64, num_workers=4, shuffle=True)
# Evaluation does not benefit from shuffling; keep the order stable.
validationDataLoader = DataLoader(validationData, batch_size=64, num_workers=4, shuffle=False)

# The test table is just the file names found in the test image folder
# (`sPath` is the same folder TestDataset joins the names with).
testDataFrame = pd.DataFrame({'image_id': os.listdir(sPath)})
testDataset = TestDataset(testDataFrame, transforms=val_augs)
testDataLoader = DataLoader(testDataset, BATCH_SIZE, num_workers=4, shuffle=False)


# Train for EPOCHS passes over the training loader, recording the loss of
# every optimisation step, then plot the loss curve.
my_net.train()
stepsPerEpoch = len(trainDataLoader)
training_results_my_net = np.zeros(stepsPerEpoch * EPOCHS)
for t in range(EPOCHS):
    for index, data in enumerate(trainDataLoader):
        x = data["x"]
        # The dataset yields float64 label tensors with a trailing dimension;
        # the loss needs a flat vector of int64 class indices.
        # NOTE(review): assumes exactly one 'label' column - confirm.
        y = data["y"].reshape(-1).long()
        y_pred = my_net(x)
        loss = loss_fn(y_pred, y)
        training_results_my_net[t * stepsPerEpoch + index] = loss.item()
        print("my net", loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

plt.plot(np.arange(len(training_results_my_net)), training_results_my_net)
plt.ylabel('Training Loss')
plt.suptitle('My Network Train Loss')
plt.show()

# Evaluate on the validation loader, recording one loss value per batch, then
# plot the values.
validation_results_my_net = np.zeros(len(validationDataLoader))
# Evaluation mode disables dropout; no_grad avoids building autograd graphs
# that are never used for a backward pass.
my_net.eval()
with torch.no_grad():
    for index, data in enumerate(validationDataLoader):
        x = data["x"]
        # The dataset yields float64 label tensors with a trailing dimension;
        # the loss needs a flat vector of int64 class indices.
        # NOTE(review): assumes exactly one 'label' column - confirm.
        y = data["y"].reshape(-1).long()
        y_pred = my_net(x)
        loss = loss_fn(y_pred, y)
        validation_results_my_net[index] = loss.item()

plt.plot(np.arange(len(validation_results_my_net)), validation_results_my_net)
plt.ylabel('Validation Loss')
plt.suptitle('My Network Validation Loss')
plt.show()

# Run inference on the test loader. Test batches carry only "x" (no labels),
# so a loss cannot be computed; collect the predicted class of every test
# image instead and plot how often each class was predicted.
my_net.eval()
testPredictionBatches = []
with torch.no_grad():
    for data in testDataLoader:
        y_pred = my_net(data["x"])
        testPredictionBatches.append(y_pred.argmax(dim=1))

if testPredictionBatches:
    test_results_my_net = torch.cat(testPredictionBatches).cpu().numpy()
else:
    # Empty test folder: keep an empty integer array so the plot still works.
    test_results_my_net = np.zeros(0, dtype=np.int64)

plt.bar(np.arange(OUTPUT_SIZE), np.bincount(test_results_my_net, minlength=OUTPUT_SIZE))
plt.xlabel('Predicted Class')
plt.ylabel('Number of Test Images')
plt.suptitle('My Network Test Predictions')
plt.show()