In [None]:
!pip install timm

### Import Libraries

In [None]:
import os
import cv2
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
import timm

In [None]:
import random
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything()

### Hyperparameters

In [None]:
# Number of cross-validation folds. Must equal the number of distinct `kfold`
# labels assigned in the KFold cell (0-4), because the training loop iterates
# over range(FOLDS); a larger value would train on folds that contain no rows.
FOLDS = 5
BATCH_SIZE = 16
LR = 0.0001
EPOCHS = 4
SMOOTHING = 0.1

# Loss applied separately to the weight and count outputs (both are regressions).
LOSS_FUNCTION = F.mse_loss

# Square side length images are resized/cropped to (previously tried: 240).
IMG_SIZE = 400

EARLY_STOPPING = True

# timm backbone name. Previously tried: 'resnet50', 'tf_efficientnet_b1_ns',
# 'efficientnet_b3'.
MODEL_ARCH = 'tf_efficientnet_b4_ns'

IMAGE_FOLDER = '../input/banana-count-and-weight-in-a-bunch/Images/Images'

### Load Data

In [None]:
# Load the label table. Columns used later in this notebook:
# 'File Name', 'Bunch ID', 'Banana Count' and 'Weight'.
banana_df = pd.read_csv('../input/banana-count-and-weight-in-a-bunch/Estu.csv')
# Show (rows, columns) as the cell output.
banana_df.shape

### Simple EDA

In [None]:
# Peek at nine randomly chosen rows of the label table.
banana_df.sample(9)

### Sample images
Note that some images are in landscape orientation; you may have to rotate them when building your dataset.

In [None]:
# Show a 3x3 grid of randomly sampled images, titled with their labels.
samples = banana_df.sample(9)
fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(16, 16))
# ax.flat walks the grid row by row, so panel `idx` shows sample `idx`.
for idx, axis in enumerate(ax.flat):
    record = samples.iloc[idx]
    # Use the configured image directory instead of repeating the path literal.
    axis.imshow(plt.imread(f"{IMAGE_FOLDER}/{record['File Name']}"))
    axis.set_title(f'Count:{record["Banana Count"]}, Weight:{record["Weight"]}kg')
plt.show()

### What is a banana bunch?  
As you can see in the images above, each bunch is the complete fruit set of a single banana tree.

In [None]:
# Summary statistics for the label table.
banana_df.describe()

There are 22 unique bunches in the dataset, but 713 images to train on. When creating the cross-validation (CV) split we have to ensure that images of the same bunch never appear in both the training and the validation set, so as to prevent validation leakage.

### KFold
Every fold will have 4 bunches and the last fold will have 6 bunches.

In [None]:
# Split by bunch, not by image: shuffle the bunch ids, then give every image
# the fold of the slice its bunch falls into.
unique_branch_ids = banana_df['Bunch ID'].unique()
np.random.shuffle(unique_branch_ids)

# (start, stop) positions into the shuffled id array, one pair per fold.
fold_bounds = [(0, 4), (4, 8), (8, 12), (12, 16), (16, 22)]
for fold_id, (start, stop) in enumerate(fold_bounds):
    in_fold = banana_df['Bunch ID'].isin(unique_branch_ids[start:stop])
    banana_df.loc[in_fold, 'kfold'] = fold_id

banana_df.to_csv("train_folds.csv", index=False)

### Augmentations

In [None]:
def get_augmentations():
    """Build the albumentations pipelines for training and validation.

    Both pipelines output an ``IMG_SIZE`` x ``IMG_SIZE`` normalised tensor.
    The training pipeline additionally applies a random resized crop, random
    transpose, random horizontal flip and coarse dropout.

    Returns:
        tuple: ``(train_augmentations, valid_augmentations)``.
    """
    # Per-channel normalisation statistics, passed explicitly to Normalize
    # below (these match the transform's documented defaults).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_augmentations = albu.Compose([
        albu.RandomResizedCrop(IMG_SIZE, IMG_SIZE, p=1.0),
        albu.Transpose(p=0.5),
        albu.HorizontalFlip(p=0.5),
        albu.CoarseDropout(p=0.5),
        albu.Normalize(mean=mean, std=std, always_apply=True),
        ToTensorV2(p=1.0)
    ], p=1.0)

    # Validation is deterministic: resize, normalise, convert to tensor.
    valid_augmentations = albu.Compose([
        albu.Resize(IMG_SIZE, IMG_SIZE),
        albu.Normalize(mean=mean, std=std, always_apply=True),
        ToTensorV2(p=1.0)
    ], p=1.0)

    return train_augmentations, valid_augmentations

train_augs, val_augs = get_augmentations()

### Dataset

In [None]:
class BananaDataset(Dataset):
    """Dataset yielding banana-bunch images and, when available, their labels.

    Args:
        data: DataFrame with a 'File Name' column and, unless ``is_testing``
            is set, 'Banana Count' and 'Weight' columns.
        is_testing: When True, items carry no regression targets.
        transforms: Optional callable invoked as ``transforms(image=...)``
            whose result exposes the transformed image under ``'image'``.
    """

    def __init__(self, data, is_testing=False, transforms=None):
        self.data = data
        self.is_testing = is_testing
        self.transforms = transforms

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict: ``image_name`` and ``image``; plus float32 ``weight`` and
            ``count`` tensors when ``is_testing`` is False.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        record = self.data.iloc[index]
        image_path = f"{IMAGE_FOLDER}/{record['File Name']}"

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            image = self.transforms(image=image)['image']

        item = {
            "image_name": record['File Name'],
            "image": image,
        }
        if not self.is_testing:
            item["weight"] = torch.tensor(record['Weight'], dtype=torch.float32)
            item["count"] = torch.tensor(record['Banana Count'], dtype=torch.float32)
        return item

### NN Model - Multi output regressor

In [None]:
class BananaNNModel(nn.Module):
    """timm backbone with two single-output regression heads.

    One linear head predicts the bunch weight and the other the banana count;
    both read the same globally pooled backbone features.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the backbone starts from random weights
        # (pretrained=False) -- confirm this is intentional.
        self.model = timm.create_model(MODEL_ARCH, pretrained=False)
        # Ask the backbone for its feature width instead of hard-coding the
        # value of one architecture, so changing MODEL_ARCH keeps working.
        n_features = self.model.num_features
        self.model.pooling = nn.AdaptiveAvgPool2d(1)
        self.model.weightRegressor = nn.Linear(n_features, 1)
        self.model.countRegressor = nn.Linear(n_features, 1)

    def forward(self, x):
        """Return ``(weight_prediction, count_prediction)`` for an image batch.

        Each prediction has one value per input sample.
        """
        x = self.model.forward_features(x)
        x = self.model.pooling(x)
        x = x.flatten(1)
        pr_weight = self.model.weightRegressor(x)
        pr_count = self.model.countRegressor(x)
        # Squeeze only the trailing unit dimension; a bare squeeze() would
        # also drop the batch dimension when the batch holds a single sample.
        return pr_weight.squeeze(-1), pr_count.squeeze(-1)

### Pytorch Lightning Data Module
Prepares data and prepares dataloaders

In [None]:
class BananaDataModule(pl.LightningDataModule):
    """Lightning data module serving one cross-validation fold.

    Rows of ``banana_df`` whose ``kfold`` equals ``fold`` form the validation
    set; all remaining rows form the training set.

    Args:
        fold: Fold id held out for validation.
    """

    def __init__(self, fold):
        super().__init__()
        self.train_aug, self.val_aug = get_augmentations()
        self.fold = fold

    def setup(self, stage=None):
        """Create the training and validation datasets for ``self.fold``."""
        is_validation = banana_df['kfold'] == self.fold
        # Keep the two splits disjoint: the model must never train on the
        # fold it is validated on.
        train_fold = banana_df.loc[~is_validation]
        val_fold = banana_df.loc[is_validation]
        self.train_ds = BananaDataset(train_fold, transforms=self.train_aug)
        self.val_ds = BananaDataset(val_fold, transforms=self.val_aug)

    def train_dataloader(self):
        return DataLoader(self.train_ds, BATCH_SIZE, num_workers=4, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, BATCH_SIZE, num_workers=4, shuffle=False)

In [None]:
class BananaModule(pl.LightningModule):
    """Lightning module training a two-output (weight, count) regressor.

    Args:
        hparams: Mapping of hyperparameters; must contain ``'lr'``.
        model: Network that returns ``(weight_prediction, count_prediction)``
            when called on an image batch.
    """

    def __init__(self, hparams, model):
        super(BananaModule, self).__init__()
        self.hparams = hparams
        self.loss_function = LOSS_FUNCTION
        self.model = model
        # NOTE(review): this stores the Accuracy class itself (not an
        # instance) and no step in this module uses it; kept so the attribute
        # remains available.
        self.accuracy = pl.metrics.Accuracy

    def forward(self, x):
        return self.model(x)

    def _combined_loss(self, batch):
        """Return weight loss + count loss for one batch."""
        weight_prediction, count_prediction = self(batch['image'])
        weight_loss = self.loss_function(weight_prediction, batch['weight'])
        count_loss = self.loss_function(count_prediction, batch['count'])
        return weight_loss + count_loss

    def training_step(self, batch, batch_index):
        """Compute, log (per epoch) and return the training loss."""
        loss = self._combined_loss(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True,logger=True)
        return loss

    def validation_step(self, batch, batch_index):
        """Compute and log (per epoch) the validation loss."""
        loss = self._combined_loss(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True,logger=True)

    def configure_optimizers(self):
        """Adam with a cosine-annealing warm-restart schedule stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams['lr'])
        scheduler = {
            'scheduler':
                torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer,
                    15,  # T_0: scheduler steps until the first restart
                    verbose=False
                ),
            # Step the scheduler after every optimizer step, not once per epoch.
            'interval': 'step',
            'monitor' : 'train_loss'
        }
        return [optimizer], [scheduler]

### Train

In [None]:
def train(fold):
    """Train and checkpoint one model with ``fold`` held out for validation.

    The checkpoint with the lowest ``val_loss`` is written to ``checkpoints/``.
    The fold id is part of the file name so checkpoints from different folds
    can be told apart.

    Args:
        fold: Fold id passed to ``BananaDataModule`` as the validation fold.
    """
    checkpoint_cb = ModelCheckpoint(
        dirpath='checkpoints/',
        # Doubled braces keep `{val_loss:.2f}` as a placeholder for Lightning.
        filename=f'model_fold{fold}_{{val_loss:.2f}}',
        monitor='val_loss', verbose=True,
        save_last=False, save_top_k=1, save_weights_only=False,
        mode='min', period=1
    )
    callbacks = [checkpoint_cb]

    # Honour the EARLY_STOPPING switch from the hyperparameter cell.
    if EARLY_STOPPING:
        early_stopping_cb = EarlyStopping('val_loss', patience=3, verbose=True, mode='min')
        callbacks.append(early_stopping_cb)

    trainer = pl.Trainer(
        gpus=-1,
        precision=16,
        max_epochs=EPOCHS,
        accumulate_grad_batches=1,
        callbacks=callbacks
    )

    model = BananaNNModel()
    pl_dm = BananaDataModule(fold)
    pl_module = BananaModule(hparams={'lr':LR, 'batch_size':BATCH_SIZE}, model=model)
    trainer.fit(pl_module, pl_dm)

In [None]:
# Train one model per cross-validation fold; `fold` is the id of the fold
# held out for validation in that run.
for fold in range(FOLDS):
    train(fold)