# Plant 2021 with PyTorch Lightning
This notebook runs inference with the model trained in the following notebook:
[Training notebook](https://www.kaggle.com/pegasos/plant2021-pytorch-lightning-starter-training)

In [None]:
import sys

# Extra import locations: a copy of pytorch-image-models stored under ../input.
# They are added to sys.path ahead of the `import timm` that follows.
package_paths = [
    '../input/pytorch-image-library/pytorch-image-models-master/pytorch-image-models-master',
]
sys.path.extend(package_paths)

import timm

# Import

In [None]:
import pandas as pd
import numpy as np
import cv2
import torch
import torch.nn as nn
import albumentations as A
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from albumentations.core.composition import Compose, OneOf
from albumentations.augmentations.transforms import CLAHE, GaussNoise, ISONoise
from albumentations.pytorch import ToTensorV2

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Callback
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.model_selection import StratifiedKFold

# Config

In [None]:
class CFG:
    """Static settings read by the other cells of this notebook."""

    # Seed handed to seed_everything().
    seed = 42

    # Architecture name and pretrained flag forwarded to timm.create_model().
    model_name = 'resnet50'
    pretrained = False

    # Square side length used by the resize / crop transforms.
    img_size = 512
    # Output width of the classification head.
    num_classes = 12
    # Batch size of the test DataLoader.
    batch_size = 32

    # CUDA when it is available, otherwise CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Data paths
Images pre-resized by AnkurSingh to speed up training are described at https://www.kaggle.com/c/plant-pathology-2021-fgvc8/discussion/227032. Note that the paths below point to the original competition data, not to the pre-resized copies.

In [None]:
# Root of the competition data and the folder holding the test images.
# Both keep their trailing slash because file names are appended directly.
PATH = "../input/plant-pathology-2021-fgvc8/"
TEST_DIR = f"{PATH}test_images/"

In [None]:
# Seed the random number generators through pytorch_lightning's helper.
seed_everything(seed=CFG.seed)

In [None]:
# Build the label -> class-index mapping from the training CSV.
# value_counts() sorts by descending frequency, so the most common label
# string gets index 0, the next one index 1, and so on.
# NOTE(review): this presumably has to reproduce the mapping used by the
# training notebook - confirm against it.
df_all = pd.read_csv(PATH + "train.csv")
labels = list(df_all['labels'].value_counts().keys())
# Use the configured class count instead of a literal so the mapping cannot
# drift from the width of the model head, which is also CFG.num_classes.
labels_dict = dict(zip(labels, range(CFG.num_classes)))

In [None]:
# The sample submission lists the test images; its 'image' and 'labels'
# columns are what PlantDataset reads, and 'labels' is overwritten with the
# predictions at the end of the notebook.
sub = pd.read_csv(f"{PATH}sample_submission.csv")
sub.head()

# Define Dataset

In [None]:
class PlantDataset(Dataset):
    """Dataset that loads the images listed in a DataFrame from ``TEST_DIR``.

    Args:
        df: DataFrame with an ``image`` column (file names that are appended
            to ``TEST_DIR``) and a ``labels`` column.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``'image'`` entry. When ``None``
            the RGB array is returned untransformed.
    """

    def __init__(self, df, transform=None):
        self.image_id = df['image'].values
        self.labels = df['labels'].values
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the source DataFrame."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'target': ...}`` for row ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or unreadable.
        """
        image_id = self.image_id[idx]
        label = self.labels[idx]

        image_path = TEST_DIR + image_id
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # transform defaults to None, so only call it when one was supplied.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return {'image': image, 'target': label}

In [None]:
def get_transform(phase: str):
    """Build the albumentations pipeline for the given phase.

    ``'train'`` yields a random-resized crop followed by flip, shift/scale/
    rotate and brightness/contrast augmentations; any other value yields a
    plain resize. Both variants end with ``Normalize`` and ``ToTensorV2``.
    """
    side = CFG.img_size

    if phase != 'train':
        steps = [A.Resize(height=side, width=side)]
    else:
        steps = [
            A.RandomResizedCrop(height=side, width=side),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
        ]

    # Shared tail: normalise, then convert to a tensor.
    steps += [A.Normalize(), ToTensorV2()]
    return Compose(steps)

In [None]:
# Test-time data: rows of the sample submission, run through the
# non-training transform and served in file order (no shuffling).
test_dataset = PlantDataset(sub, transform=get_transform('valid'))
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=2,
)

# Define Model

In [None]:
class CustomResNet(nn.Module):
    """timm model whose ``fc`` layer is replaced by a fresh linear head.

    The head has ``CFG.num_classes`` outputs. The wrapped network is stored
    as ``self.model``.
    """

    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Size the new head from the classifier the backbone was built with.
        head_inputs = backbone.get_classifier().in_features
        backbone.fc = nn.Linear(head_inputs, CFG.num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

In [None]:
from collections import OrderedDict

def fix_model_state_dict(state_dict):
    """Return a copy of ``state_dict`` with one leading ``'model.'`` removed.

    Keys that start with ``'model.'`` lose that prefix once; all other keys
    are kept unchanged. Values and ordering are preserved.
    NOTE(review): the prefix presumably comes from the wrapper the network
    was trained in - confirm against the training notebook.
    """
    prefix = 'model.'
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

In [None]:
# Build the architecture only; CFG.pretrained is False here, and the trained
# parameters are loaded from a checkpoint further down.
model = CustomResNet(CFG.model_name, CFG.pretrained)

In [None]:
checkpoint = "../input/plat2021-resnet50/last.ckpt"

# map_location='cpu' lets the checkpoint be read on a machine without CUDA
# even if its tensors were saved from a GPU; the model is still on the CPU at
# this point, so the parameters are copied in place with no device mismatch.
weight = torch.load(checkpoint, map_location='cpu')['state_dict']
# The saved keys carry an extra 'model.' prefix that is stripped before loading.
model.load_state_dict(fix_model_state_dict(weight))

# Inference

In [None]:
# Run on CFG.device (CUDA when available, otherwise CPU) instead of
# hard-coding .cuda(), so the notebook also works on a CPU-only machine.
model.to(CFG.device)
model.eval()

# One array of predicted class indices per batch, in loader order.
predictions = []
# No gradients are needed for inference; disable tracking for the whole loop.
with torch.no_grad():
    for batch in test_loader:
        image = batch['image'].to(CFG.device)
        outputs = model(image)
        # argmax over dim 1 picks the highest-scoring class for each sample.
        preds = outputs.argmax(1).cpu().numpy()
        predictions.append(preds)

In [None]:
# Reverse lookup: class index -> label string.
inv_labels_dict = dict(zip(labels_dict.values(), labels_dict.keys()))
inv_labels_dict

In [None]:
# Flatten the per-batch index arrays into one vector aligned with sub's rows.
predicted_idx = np.concatenate(predictions)
sub['labels'] = predicted_idx
# Translate class indices back to label strings before writing the file.
sub = sub.replace({"labels": inv_labels_dict})
sub.to_csv('submission.csv', index=False)
sub.head()