In [None]:
%pip install -q byol-pytorch==0.5.2
# This notebook imports `pytorch_lightning.metrics` (removed in Lightning 1.5) and passes
# `gpus=` / `flush_logs_every_n_steps=` to the Trainer (dropped in later releases), so an
# unpinned install breaks it. The spec is quoted so the shell does not treat `<` as a redirect.
%pip install -q "pytorch-lightning<1.5"

In [None]:
# Mount Google Drive at /gdrive so the files under /gdrive/MyDrive/cass used later
# (model checkpoint, training images) are reachable; force_remount refreshes a stale mount.
from google.colab import drive

drive.mount(mountpoint='/gdrive', force_remount=True)

Mounted at /gdrive


In [None]:
import torch
from byol_pytorch import BYOL
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import tqdm
from PIL import Image
import multiprocessing
import pytorch_lightning as pl

In [None]:
# Python imports
import os

from PIL import Image

# Misc Python imports
import pandas as pd
import numpy as np

# PyTorch imports
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.functional import cross_entropy


# Remo
#import remo
#remo.set_viewer('jupyter')

In [None]:
# Class names in label order: position in this list is the integer label.
_CASSAVA_CLASSES = [
    'Cassava bacterial blight (cbb)',
    'Cassava brown streak disease (cbsd)',
    'Cassava green mottle (cgm)',
    'Cassava mosaic disease (cmd)',
    'Healthy',
]

# Lookup from class name to integer label, consumed by the dataset.
cat_to_idx = {class_name: label for label, class_name in enumerate(_CASSAVA_CLASSES)}

In [None]:
class CustomDataset(Dataset):
    """Image-classification dataset backed by a pandas DataFrame.

    Args:
        data_path: DataFrame (despite the name, not a filesystem path) with a
            ``file_name`` column holding image paths and a ``classes`` column
            holding class names.
        transforms: optional callable applied to the loaded PIL image; skipped
            when falsy.
        mapping: dict from class name to integer label (defaults to ``cat_to_idx``).

    Items are ``(image, label)`` where ``label`` is ``int(mapping[classes])``.
    """

    def __init__(self, data_path, transforms, mapping = cat_to_idx):
        self.data_path = data_path
        self.transforms = transforms
        # Use the mapping the caller passed in, not the global default.
        self.mapping = mapping

    def __len__(self):
        return len(self.data_path)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th row, counted positionally."""
        # Positional lookup matches __len__, so this works even when the DataFrame
        # index is not 0..n-1 (e.g. after filtering or sampling rows).
        row = self.data_path.iloc[idx]

        # The context manager closes the file handle. convert() reads the pixel data
        # and always yields a 3-channel image (grayscale/RGBA inputs included).
        with Image.open(row['file_name']) as img:
            im = img.convert('RGB')
        label = int(self.mapping[row['classes']])

        if self.transforms:
            im = self.transforms(im)
        return im, label

In [None]:
# Per-channel mean/std (torchvision's ImageNet statistics); used only by the augmented pipelines.
means =  [0.485, 0.456, 0.406]
stds  =  [0.229, 0.224, 0.225]

# Set True to train with random rotation/crop/flip plus Normalize(means, stds).
# NOTE(review): the default pipelines apply no normalization. Confirm this matches
# how the checkpoint at model_path was trained.
USE_AUGMENTATION = False

if USE_AUGMENTATION:
    train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                           transforms.RandomResizedCrop(256),
                                           transforms.RandomHorizontalFlip(p=0.5),
                                           transforms.ToTensor(),
                                           transforms.Normalize(means, stds)])
    val_transforms = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(256),
                                         transforms.ToTensor(),
                                         transforms.Normalize(means, stds)])
else:
    # Resize(int) scales the shorter side to 256 and keeps the aspect ratio. Batch
    # collation therefore requires all images to share one aspect ratio.
    train_transforms = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
    val_transforms = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])

# Directory holding the train/validation split CSVs.
DATA_DIR = '/content'
train_data = pd.read_csv(os.path.join(DATA_DIR, 'remo_train.csv'))
validation_data = pd.read_csv(os.path.join(DATA_DIR, 'remo_valid.csv'))

In [None]:
class BYOL_Supervised(pl.LightningModule):
    """LightningModule that fine-tunes a classifier exposing an ``fc`` head.

    The wrapped network's ``fc`` layer is replaced with a freshly initialised
    ``nn.Linear(fc.in_features, num_classes)``. Training and validation use
    cross-entropy and log loss and accuracy.

    Args:
        model: network with an ``fc`` linear layer (e.g. ``torchvision.models.resnet50()``).
        model_path: path to a ``state_dict`` for ``model``; read only when ``pretrained`` is True.
        num_classes: output size of the new ``fc`` head.
        pretrained: if True, load ``model_path`` into ``model`` before replacing the head.
        lr: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, model, model_path, num_classes = 5, pretrained=False, lr=2e-05):
        super(BYOL_Supervised, self).__init__()

        self.model = model
        self.model_path = model_path
        self.lr = lr

        if pretrained:
            # map_location='cpu' lets a checkpoint saved on GPU load on any runtime;
            # the Trainer moves the module to the target device later.
            self.model.load_state_dict(torch.load(self.model_path, map_location='cpu'))

        # Replace the head only after loading, so the checkpoint is loaded into the
        # original architecture (presumably including its original fc shape).
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy on one ``(x, y)`` batch; log loss and accuracy."""
        x, y = batch

        out = self.model(x)

        loss = cross_entropy(out, y)
        train_acc = accuracy(out, y)

        self.log('Training Loss', loss)
        self.log('Training Accuracy', train_acc)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log cross-entropy and accuracy for one ``(x, y)`` validation batch."""
        x, y = batch

        preds = self.model(x)

        val_loss = cross_entropy(preds, y)
        val_acc = accuracy(preds, y)

        self.log('Validation Loss', val_loss)
        self.log('Validation Accuracy', val_acc)

    def forward(self, x):
        """Inference pass through the wrapped model; no autograd graph is recorded."""
        with torch.no_grad():
            out = self.model(x)

        return out

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's own parameters."""
        # self.parameters(), not a notebook-global name, so the module is self-contained.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Randomly initialised ResNet-50 backbone, wrapped for 5-class supervised fine-tuning.
# NOTE(review): pretrained=False (same as the default), so the checkpoint at model_path
# is NOT loaded here. Confirm this is an intended from-scratch run.
resnet = models.resnet50()
model = BYOL_Supervised(
    model=resnet,
    model_path='/gdrive/MyDrive/cass/resnet_improved-net.pt',
    pretrained=False,
)

In [None]:
# Load (or reload) the TensorBoard extension and show Lightning's default log directory inline.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Request a GPU only when one exists. A hard-coded gpus=1 makes Lightning raise a
# MisconfigurationException on CPU-only runtimes; gpus=0 means train on CPU.
trainer = pl.Trainer(max_epochs=8,
                     gpus=1 if torch.cuda.is_available() else 0,
                     flush_logs_every_n_steps = 100)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
BATCH_SIZE = 50
NUM_WORKERS = 2

# Shuffle training samples every epoch. Without it each epoch replays the CSV order,
# e.g. single-class batches if rows are grouped by label. Validation order is irrelevant.
train_dl = DataLoader(CustomDataset(data_path=train_data, transforms=train_transforms),
                      batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_dl = DataLoader(CustomDataset(data_path=validation_data, transforms=val_transforms),
                    batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
# Lightning runs a short validation sanity check, then trains for max_epochs, validating on val_dl.
trainer.fit(model, train_dl, val_dl)


  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 23.5 M
---------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# Visual spot-check of one raw training image; the returned PIL image renders as the cell output.
sample_image_path = '/gdrive/MyDrive/cass/train_images/1047550741.jpg'
Image.open(sample_image_path)