In [None]:
%%capture
# Install DeepSpeed (the imports cell takes FusedAdam from deepspeed.ops.adam) and upgrade
# the Weights & Biases client. The %%capture cell magic above silences the installer output.
# NOTE(review): neither package is version-pinned, so a re-run may pick up different
# releases — consider pinning once known-good versions are confirmed.
!pip install deepspeed
!pip install --upgrade wandb

In [None]:
import json
import multiprocessing as mp
from pathlib import Path
from typing import Any, Callable, List, Tuple

from deepspeed.ops.adam import FusedAdam
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import io, models, transforms
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm

# Wandb login:
from kaggle_secrets import UserSecretsClient
import wandb
user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("wandb_api_key")
wandb.login(key=secret_value)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

%matplotlib inline
print(torch.__version__, pl.__version__, wandb.__version__)

In [None]:
# Notebook-wide configuration.
# Competition input directory; train.csv, the label map JSON, train_images/ and
# test_images/ are all resolved relative to it.
ROOT_PATH = Path("/kaggle/input/cassava-leaf-disease-classification/")
# Size passed to transforms.Resize for both the training and the validation/test pipelines.
IMAGE_SIZE = (128, 128)
# Training batch size; the validation and test loaders use twice this value.
BATCH_SIZE = 64
# Learning rate handed to the optimizer.
LR = 1e-3
# Passed to the Trainer as max_epochs.
EPOCHS = 3

## Data

In [None]:
# Load the training labels and the label-number -> disease-name mapping.
df = pd.read_csv(ROOT_PATH / "train.csv")
with open(ROOT_PATH / "label_num_to_disease_map.json", "r") as f:
    label_map = json.load(f)
# JSON object keys are always strings; cast them to int so the map can be applied
# to the values of the "label" column.
label_map = {int(k): v for k, v in label_map.items()}

# Stratified hold-out split. The fixed random_state makes the split identical on every
# Restart & Run All, so validation metrics are comparable between runs.
train_df, valid_df = train_test_split(
    df, stratify=df["label"].values, random_state=42
)

# Class distribution, both as a chart and as a table.
class_counts = df["label"].map(label_map).value_counts()
fig, ax = plt.subplots(figsize=(12, 5))
class_counts.plot.barh(ax=ax)
ax.set(title="Training images per class", xlabel="Number of images", ylabel="Class")
print(class_counts)
df.sample(5)

In [None]:
class Data(Dataset):
    """Labelled image dataset read from ``ROOT_PATH / "train_images"``.

    Args:
        df: Frame with an ``image_id`` column (file names inside the
            ``train_images`` folder) and a ``label`` column.
        transforms: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, df: pd.DataFrame, transforms=None):
        self.files = [ROOT_PATH / "train_images" / file for file in df["image_id"].values]
        self.y = df["label"].values.tolist()
        self.transforms = transforms

    def __len__(self):
        """Number of samples (one per row of the frame passed to ``__init__``)."""
        return len(self.y)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``.

        The image is a 3-channel RGB PIL image, or whatever ``transforms``
        returns for it when a transform was supplied.
        """
        # Image.open is lazy and keeps the file handle open. convert() forces the pixel
        # data to be read and returns an independent image, so the handle can be closed
        # right away. Forcing "RGB" also guarantees three channels for later per-channel
        # transforms, even if a file happens to be grayscale or has an alpha channel.
        with Image.open(self.files[i]) as raw:
            img = raw.convert("RGB")
        label = self.y[i]
        if self.transforms is not None:
            img = self.transforms(img)

        return img, label

In [None]:
# Per-channel mean / std handed to Normalize in both pipelines.
# NOTE(review): presumably computed on this dataset — confirm the origin of these values.
norm_mean = [0.4766, 0.4527, 0.3926]
norm_std = [0.2275, 0.2224, 0.2210]

# Training pipeline: resize, random horizontal flip, random rotation (degrees=10),
# then tensor conversion and normalisation.
train_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
train_tfms = transforms.Compose(train_steps)

# Validation pipeline: the same resize and normalisation, without random augmentation.
valid_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
valid_tfms = transforms.Compose(valid_steps)

train_ds = Data(train_df, train_tfms)
valid_ds = Data(valid_df, valid_tfms)

# Options shared by both loaders.
loader_opts = dict(num_workers=4, pin_memory=True)

# Training batches are shuffled, and the last incomplete batch is dropped.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    **loader_opts,
)

# Validation keeps every sample in order and uses a batch twice as large.
valid_dl = DataLoader(
    valid_ds,
    batch_size=BATCH_SIZE * 2,
    shuffle=False,
    drop_last=False,
    **loader_opts,
)

In [None]:
# Sanity check: pull one batch from the training loader and display the shapes of the
# image batch and the label batch.
first_batch = next(iter(train_dl))
x, y = first_batch
(x.shape, y.shape)

## Model

In [None]:
class Model(nn.Module):
    """ResNet-34 feature extractor followed by a two-layer classification head.

    The backbone's own ``fc`` layer is replaced with ``nn.Identity`` so that
    ``self.base`` outputs the pooled features, and every backbone parameter
    starts out frozen (``requires_grad = False``).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.base = models.resnet34(pretrained=True)

        # Width of the backbone features, read before its fc layer is swapped out.
        feat_dim = self.base.fc.in_features
        hidden_dim = feat_dim // 2

        # Head layers; forward() applies them as norm1 -> dropout1 -> linear1 -> norm2 -> dropout2 -> linear2.
        self.linear1 = nn.Linear(feat_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, num_classes)
        self.norm1 = nn.BatchNorm1d(feat_dim)
        self.norm2 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.5)

        # Drop the backbone's original classifier; only its features are used.
        self.base.fc = nn.Identity()

        # Freeze the backbone so that initially only the head receives gradients.
        for param in self.base.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        features = self.base(x)
        features = self.norm1(F.leaky_relu(features))
        features = self.dropout1(features)

        hidden = self.linear1(features)
        hidden = self.norm2(F.leaky_relu(hidden))
        hidden = self.dropout2(hidden)

        return self.linear2(hidden)

In [None]:
class LightningModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with ``loss_fn`` and logs loss / accuracy.

    Args:
        model: Network mapping an image batch to class logits. It must expose a
            ``base`` sub-module, whose parameters are unfrozen after the first
            training epoch.
        loss_fn: Callable ``(logits, targets) -> scalar loss``.
        lr: Learning rate passed to the optimizer.
    """

    def __init__(self, model: nn.Module, loss_fn: Callable, lr: float):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.lr = lr

    def common_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``.

        Accuracy is the fraction of samples in the batch whose arg-max logit
        equals the target.
        """
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        accuracy = (logits.argmax(-1) == y).float().mean()

        return loss, accuracy

    def training_step(self, batch: Tuple[torch.FloatTensor, torch.LongTensor], *args: List[Any]):
        """Compute and log the training loss / accuracy; return the loss."""
        loss, accuracy = self.common_step(batch)
        self.log("training_loss", loss, on_step=True, on_epoch=True)
        self.log("training_accuracy", accuracy, on_step=True, on_epoch=True)

        return loss

    def on_train_epoch_end(self, *args):
        """Unfreeze the backbone once the first training epoch has finished.

        This deliberately uses the training-specific hook rather than the
        generic ``on_epoch_end``. Depending on the Lightning release, the
        generic hook may also fire after validation epochs, and newer releases
        no longer provide it. ``*args`` absorbs the extra positional argument
        that some releases pass to this hook.
        """
        if self.current_epoch == 0:
            for p in self.model.base.parameters():
                p.requires_grad = True

    def validation_step(self, batch: Tuple[torch.FloatTensor, torch.LongTensor], *args: List[Any]):
        """Compute the validation loss / accuracy and log them once per epoch."""
        loss, accuracy = self.common_step(batch)
        self.log("validation_loss", loss, on_step=False, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """Return a FusedAdam optimizer over all model parameters.

        The frozen backbone parameters are included on purpose, so that they
        are already registered with the optimizer when they are unfrozen.
        """
        return FusedAdam(self.model.parameters(), lr=self.lr)

## Training

In [None]:
# Class weights relative to the most frequent class: that class gets weight 1 and every
# other class gets (largest count / its own count). Counts are sorted by label index so
# that position i of the weight array belongs to the i-th label in sorted order.
label_counts = train_df["label"].value_counts().sort_index()
largest_count = label_counts.max()
class_weights = largest_count / label_counts.to_numpy()
(label_counts, class_weights)

In [None]:
# Directory for the W&B logger. Path.mkdir with exist_ok=True is idempotent, so
# re-running this cell does not fail when the directory already exists.
LOG_DIR = "/kaggle/working/logs/"
Path(LOG_DIR).mkdir(parents=True, exist_ok=True)

# One output logit per distinct label; the loss is weighted by the class weights
# computed in the previous cell.
model = Model(df["label"].nunique())
loss_fn = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights))
lightning_model = LightningModel(model, loss_fn, LR)

logger = WandbLogger(name="cassava-1", save_dir=LOG_DIR, project="Kaggle-Cassava")
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=torch.cuda.device_count(),  # use every visible GPU
    gradient_clip_val=1.0,
    logger=logger,
    precision=16,  # 16-bit (mixed) precision training
)
trainer.fit(lightning_model, train_dl, valid_dl)

In [None]:
class TestData(Dataset):
    """Unlabelled image dataset over every ``*.jpg`` in ``ROOT_PATH / "test_images"``.

    Args:
        transforms: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, transforms=None):
        # glob() yields paths in filesystem order; sorting makes the sample order,
        # and therefore anything built from ``self.files``, deterministic.
        self.files = sorted((ROOT_PATH / "test_images").glob("*.jpg"))
        self.transforms = transforms

    def __len__(self):
        """Number of test images found."""
        return len(self.files)

    def __getitem__(self, i):
        """Return image ``i`` as RGB, after ``transforms`` when one was supplied."""
        # Image.open is lazy and keeps the file handle open. convert() forces the pixel
        # data to be read and returns an independent image, so the handle can be closed
        # right away. Forcing "RGB" also guarantees three channels for later per-channel
        # transforms.
        with Image.open(self.files[i]) as raw:
            img = raw.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)

        return img
    
# The test images go through the same deterministic preprocessing as validation.
test_ds = TestData(valid_tfms)
test_dl = DataLoader(
    test_ds,
    BATCH_SIZE*2,
    shuffle=False,  # keep loader order aligned with test_ds.files
    drop_last=False,
    num_workers=4,
    pin_memory=True,
)

# Inference: eval mode and no gradient tracking.
model = model.eval().to(device)
y_preds = []
with torch.no_grad():
    for x in tqdm(test_dl):
        # Predicted class = arg-max logit; move to the CPU so batches can be concatenated.
        y_preds.append(model(x.to(device)).argmax(dim=-1).cpu())

y_preds = torch.cat(y_preds).numpy()
file_names = [path.name for path in test_ds.files]

# index=False keeps the file to exactly the two columns "image_id" and "label";
# without it pandas would write the row index as an extra unnamed first column.
submission = pd.DataFrame({"image_id": file_names, "label": y_preds})
submission.to_csv("./submission.csv", index=False)

In [None]:
# Displays `y`, the labels of the training batch fetched earlier with
# `next(iter(train_dl))`.
# NOTE(review): looks like a leftover debugging cell — confirm whether it is still needed.
y