In [None]:
!pip install byol-pytorch

In [37]:
import torch
from byol_pytorch import BYOL
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import tqdm
from PIL import Image

In [18]:
# Class name -> integer label; labels follow the order in which names are listed.
cat_to_idx = {
    class_name: class_idx
    for class_idx, class_name in enumerate((
        'Cassava bacterial blight (cbb)',
        'Cassava brown streak disease (cbsd)',
        'Cassava green mottle (cgm)',
        'Cassava mosaic disease (cmd)',
        'Healthy',
    ))
}

In [44]:
# Colab-only: mount Google Drive at /gdrive. force_remount=True requests a
# fresh mount even if the drive is already mounted.
# NOTE(review): the CSVs loaded further down are read from /content, not
# /gdrive — confirm this mount is still required.
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

Mounted at /gdrive


In [20]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a pandas DataFrame.

    Each row supplies an image location in the ``file_name`` column and a
    class name in the ``classes`` column.

    Args:
        data_path: DataFrame with ``file_name`` and ``classes`` columns. Rows
            are looked up with ``.loc[idx, ...]``, so the index labels must be
            the integers the sampler produces (0..len-1 for the default
            index created by ``pd.read_csv``).
        transforms: Optional callable applied to the loaded PIL image; a
            falsy value returns the image untouched.
        mapping: Dict from class name to integer label. When ``None`` the
            module-level ``cat_to_idx`` is used.
    """

    def __init__(self, data_path, transforms, mapping=None):
        self.data_path = data_path
        self.transforms = transforms
        # Honour the caller-supplied mapping; only fall back to the
        # notebook-wide cat_to_idx when none is given.
        self.mapping = cat_to_idx if mapping is None else mapping

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.data_path)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Raises:
            KeyError: if ``idx`` is not in the frame's index or the row's
                class name is missing from ``mapping``.
        """
        # Resolve the label first so an unknown class fails before any file is opened.
        label = int(self.mapping[self.data_path.loc[idx, 'classes']])
        im = Image.open(self.data_path.loc[idx, 'file_name'])

        if self.transforms:
            im = self.transforms(im)
        return im, label

In [21]:
# Per-channel mean / std used by Normalize (the values commonly paired with
# ImageNet-pretrained torchvision models).
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

# Training pipeline: random rotation, random 256px crop and horizontal flip,
# followed by tensor conversion and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(256),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(means, stds),
    ]
)

# Validation pipeline: deterministic resize + centre crop to 256px, then the
# same tensor conversion and normalisation as training.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(means, stds),
    ]
)

# Sample tables for the two splits, consumed by CustomDataset.
train_data = pd.read_csv('/content/remo_train.csv')
validation_data = pd.read_csv('/content/remo_valid.csv')

In [22]:
# Training loader: shuffle=True reshuffles the samples every epoch so batches
# are not drawn in the fixed row order of the CSV.
train_dl = DataLoader(
    CustomDataset(data_path=train_data, transforms=train_transforms),
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Validation loader: order does not matter, so samples are read sequentially.
val_dl = DataLoader(
    CustomDataset(data_path=validation_data, transforms=val_transforms),
    batch_size=50,
    num_workers=4,
    pin_memory=True,
)

In [46]:
# Backbone network: torchvision ResNet-50 loaded with pretrained weights.
resnet = models.resnet50(pretrained=True)

# Wrap the backbone in a BYOL learner for 256px inputs, naming 'avgpool' as the
# hidden layer, and move the whole learner to the GPU.
learner = BYOL(resnet, image_size=256, hidden_layer='avgpool').cuda()

# Adam over every parameter the learner exposes.
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

In [47]:
# Number of full passes over train_dl performed by the training loop.
num_epochs = 2

In [None]:
# Self-supervised BYOL training with a validation pass after every epoch.
# The learner's forward pass returns the scalar BYOL loss rather than class
# logits, so validation reports the mean loss; the labels are not used.
for _ in range(num_epochs):
    # Validation below switches to eval mode, so training mode is restored
    # once at the start of each epoch.
    learner.train()
    train_loss_sum, train_batches = 0.0, 0

    for im, _label in tqdm.tqdm_notebook(train_dl):
        loss = learner(im.cuda())

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Refresh the moving-average target encoder after each optimiser step.
        learner.update_moving_average()

        train_loss_sum += loss.item()
        train_batches += 1

    # Mean over the epoch; max(..., 1) keeps an empty loader from dividing by zero.
    print('Training Loss : {:.5f}'.format(train_loss_sum / max(train_batches, 1)))

    learner.eval()
    val_loss_sum, val_batches = 0.0, 0

    with torch.no_grad():
        for im, _labels in tqdm.tqdm_notebook(val_dl):
            val_loss_sum += learner(im.cuda()).item()
            val_batches += 1

    print('Validation Loss : {:.5f}'.format(val_loss_sum / max(val_batches, 1)))

In [None]:
# Run the online predictor head over the validation set and keep its outputs.
# The loader already yields batched image tensors, so no extra batch axis is
# added. online_predictor operates on projections, so each batch is first
# passed through the learner with return_embedding=True to obtain them.
# NOTE(review): the original intent of this cell is not evident — confirm that
# predictor outputs (rather than embeddings) are what is wanted.
learner.eval()
val_predictions = []

with torch.no_grad():
    for im, labels in tqdm.tqdm_notebook(val_dl):
        projection, _embedding = learner(im.cuda(), return_embedding=True)
        # Move results to the CPU so accumulated outputs do not hold GPU memory.
        val_predictions.append(learner.online_predictor(projection).cpu())