Training notebook: https://www.kaggle.com/krisho007/simple-gpu-pytorch-lightning-training

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')

In [None]:
import os
import cv2
import pandas as pd
import numpy as np
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import StratifiedKFold
import timm

In [None]:
import random
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random generators.

    Also exports ``PYTHONHASHSEED`` and sets the cuDNN flags
    (``deterministic`` on, ``benchmark`` off).

    Args:
        seed: Integer seed applied to every generator; defaults to 42.
    """
    # Environment values must be strings, hence the conversion.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators all take the same seed.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: the current device first, then every device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything()

In [None]:
# Paths into the competition dataset (train.csv and the train_images folder).
# NOTE(review): neither path is referenced in the visible cells — confirm they are still needed.
TRAIN_CSV = "../input/cassava-leaf-disease-classification/train.csv"
TRAIN_IMAGE_FOLDER = '../input/cassava-leaf-disease-classification/train_images'
# Number of output classes; used as the width of the classifier head.
CLASSES = 5

### Hyper parameters

In [None]:
# Batch size for the test DataLoader, and the learning rate that is passed
# along as hparams when a checkpoint is loaded.
BATCH_SIZE = 32
LR = 0.0001

# Side length handed to RandomResizedCrop.
# Earlier values in this cell (128, 240) were immediately overwritten; only
# the last assignment ever took effect.
IMG_SIZE = 512

# Architecture name handed to timm.create_model.
# Earlier values in this cell were immediately overwritten:
#   'resnet50', 'tf_efficientnet_b1_ns', 'efficientnet_b3'
MODEL_ARCH = 'tf_efficientnet_b4_ns'

# Folder the test dataset reads its images from.
TEST_IMAGES_PATH = '../input/cassava-leaf-disease-classification/test_images/'
# TEST_IMAGES_PATH = '../input/cassava-leaf-disease-classification/train_images/'

### Dataset

In [None]:
class CassavaTestDataset(Dataset):
    """Dataset that loads the unlabeled images listed in a DataFrame.

    Args:
        test_df: DataFrame with an ``image_id`` column holding file names.
        transforms: Optional callable invoked as ``transforms(image=...)``;
            its result is indexed with ``"image"``.
        image_dir: Optional folder containing the images. When ``None`` the
            module-level ``TEST_IMAGES_PATH`` is looked up on every access.
    """

    def __init__(self, test_df, transforms=None, image_dir=None):
        self.test_df = test_df
        self.transforms = transforms
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of rows in ``test_df``."""
        return len(self.test_df)

    def __getitem__(self, index):
        """Return ``{"x": image}`` for the row at position ``index``.

        The image is read with OpenCV, converted from BGR to RGB and, when
        ``transforms`` is truthy, passed through it.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        # Resolved here rather than in __init__ so that a later change to the
        # module-level default is still honoured.
        image_dir = TEST_IMAGES_PATH if self.image_dir is None else self.image_dir
        image_path = os.path.join(image_dir, self.test_df.iloc[index].image_id)

        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising,
            # which would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transforms:
            image = self.transforms(image=image)["image"]

        return {"x": image}

### Transforms

In [None]:
def get_augmentations():
    """Build the albumentations pipeline applied to each test image.

    The pipeline runs, in order: a random resized crop to
    ``IMG_SIZE`` x ``IMG_SIZE``, a random transpose, a random horizontal flip,
    random hue/saturation/value and brightness/contrast shifts,
    normalisation, and conversion to a tensor.

    Returns:
        The composed ``albu.Compose`` transform.
    """
    # Per-channel statistics handed to Normalize.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    steps = [
        albu.RandomResizedCrop(IMG_SIZE, IMG_SIZE, p=1.0),
        albu.Transpose(p=0.5),
        albu.HorizontalFlip(p=0.5),
        albu.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        albu.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
        albu.Normalize(channel_mean, channel_std, max_pixel_value=255.0, always_apply=True),
        ToTensorV2(p=1.0),
    ]
    return albu.Compose(steps, p=1.0)


test_augmentations = get_augmentations()

### NN Model

In [None]:
class Model(nn.Module):
    """timm backbone whose classification head is resized to ``CLASSES`` outputs.

    The backbone named by ``MODEL_ARCH`` is created with ``pretrained=False``.
    Its final linear layer is then replaced by a fresh ``nn.Linear`` with the
    same input width and ``CLASSES`` outputs.

    Raises:
        AttributeError: If the backbone exposes none of the known head
            attributes.
    """

    # Attribute names under which the final linear layer is looked up:
    # ``classifier`` (used for EfficientNets) and ``fc`` (used for ResNets).
    _HEAD_ATTRIBUTES = ("classifier", "fc")

    def __init__(self):
        super().__init__()
        self.model = timm.create_model(MODEL_ARCH, pretrained=False)

        # Replace whichever head attribute this architecture provides, so that
        # switching MODEL_ARCH between the two families needs no code edits.
        for head_name in self._HEAD_ATTRIBUTES:
            head = getattr(self.model, head_name, None)
            if hasattr(head, "in_features"):
                setattr(self.model, head_name, nn.Linear(head.in_features, CLASSES))
                break
        else:
            raise AttributeError(
                f"Cannot locate the classification head of '{MODEL_ARCH}'; "
                f"looked for attributes {self._HEAD_ATTRIBUTES}"
            )

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.model(x)

### PL Module

In [None]:
# For inference, only the __init__ and forward methods of your LightningModule are needed.
class CassavaPLModule(pl.LightningModule):
    """Thin LightningModule wrapper around an existing network.

    Args:
        hparams: Accepted by the constructor but neither stored nor read
            by this class.
        model: The network that ``forward`` delegates to.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        predictions = self.model(x)
        return predictions

### Inference

In [None]:
# Network instance passed to the checkpoint loader during inference.
nnModel = Model()

# Test table: one row per entry found in the test image folder.
test_df = pd.DataFrame()
test_df['image_id'] = os.listdir(TEST_IMAGES_PATH)
# test_df = test_df[:64]

In [None]:
# Load the sample submission file from the competition input folder.
# NOTE(review): `submission` is not referenced again in the visible cells — confirm it is still needed.
submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')

In [None]:
def inference(checkpoint_directory):
    """Predict raw model outputs for every image listed in ``test_df``.

    The checkpoint is loaded once, before the batch loop, instead of being
    re-read from disk for every batch.

    Args:
        checkpoint_directory: Folder containing the checkpoint file(s).

    Returns:
        NumPy array with one row of model outputs per test image, in
        ``test_df`` order (the loader does not shuffle). An empty
        ``(0, CLASSES)`` array when the loader yields no batches.

    Raises:
        FileNotFoundError: If ``checkpoint_directory`` has no entries.
    """
    test_ds = CassavaTestDataset(test_df, transforms=test_augmentations)
    test_loader = DataLoader(test_ds, BATCH_SIZE, num_workers=4, shuffle=False)

    # NOTE(review): the previous loop over checkpoints ended in an
    # unconditional `break`, so only the first entry returned by os.scandir
    # (whose order is arbitrary) was ever used. That behaviour is kept here;
    # confirm whether averaging over all checkpoints was intended.
    with os.scandir(checkpoint_directory) as entries:
        first_entry = next(entries, None)
    if first_entry is None:
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_directory}")

    model = CassavaPLModule.load_from_checkpoint(
        f"{checkpoint_directory}/{first_entry.name}",
        hparams={'lr': LR, 'batch_size': BATCH_SIZE},
        model=nnModel,
    )
    model.eval()
    model.cuda()
    model.freeze()  # Will get a CUDA memory error without this

    preds = []
    # no_grad keeps autograd bookkeeping out of the forward passes.
    with torch.no_grad():
        for batch in test_loader:
            # Each batch holds several images, so this predicts them together.
            images = batch['x'].cuda()
            preds.append(model(images).detach().to('cpu').numpy())

    if not preds:
        # np.concatenate rejects an empty list; return an empty result instead.
        return np.empty((0, CLASSES), dtype=np.float32)
    return np.concatenate(preds)

In [None]:
# Folder holding the trained checkpoints; the commented lines below are alternative folders.
checkpoint_directory = '../input/simple-gpu-pytorch-lightning-training/checkpoints'
# checkpoint_directory = '../input/gpu-pytorch-lightning-training-on-inference/checkpoints'
# checkpoint_directory = '../input/trainedoninferencemodels'
# Model outputs for the test images from one inference pass.
testPredictions = inference(checkpoint_directory)

### 5 times TTA

#### Hardvoting

In [None]:
# hardVoting = None
# for k in range(5):
#     singlePrediction = inference(checkpoint_directory)
#     singlePrediction = (singlePrediction == singlePrediction.max(axis=1)[:,None]).astype(int)
#     try:
#         hardVoting += singlePrediction
#     except:
#         hardVoting = singlePrediction

#### Soft voting

In [None]:
# for k in range(5):
#     try:
#         testPredictions += inference(checkpoint_directory)
#     except NameError:
#         testPredictions = inference(checkpoint_directory)

In [None]:
# Predicted label = index of the largest value along axis 1 of the prediction array.
test_df['label'] = testPredictions.argmax(1)
# Write the table to submission.csv without the index column.
test_df.to_csv('submission.csv', index=False)

In [None]:
# Display the final table (image_id and predicted label) as the cell output.
test_df

- Hard voting  
- Remove training images with low confidence