## Library Imports

In [None]:
import os
import re
import torch
import imgaug
import numpy as np
import matplotlib.pyplot as plt

from time import time
from torch import nn, optim
from imgaug import augmenters
from torchvision import models, transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from sklearn.model_selection import KFold

## Constants and Utilities

In [None]:
def breaker(num: int = 50, char: str = "*") -> None:
    """Print a separator made of ``num`` copies of ``char``.

    The separator is preceded and followed by a newline, so it appears on
    its own line surrounded by blank lines in the output.
    """
    rule = char * num
    print("".join(("\n", rule, "\n")))


def show_graphs(L: list) -> None:
    """Plot training and validation loss curves on one figure.

    Args:
        L: sequence of per-epoch records, each indexable by the keys
            ``"train"`` and ``"valid"``.
    """
    # Collect both values of every record in order, then split into curves.
    pairs = [(record["train"], record["valid"]) for record in L]
    train_curve = [pair[0] for pair in pairs]
    valid_curve = [pair[1] for pair in pairs]

    # Epochs are numbered from 1 on the x axis.
    epoch_axis = np.arange(1, len(train_curve) + 1)

    plt.figure(figsize=(8, 6))
    for curve, colour, label in (
        (train_curve, "r", "Train"),
        (valid_curve, "b", "Valid"),
    ):
        plt.plot(epoch_axis, curve, colour, label=label)
    plt.legend()
    plt.grid()
    plt.title("Loss Graph")
    plt.show()

In [None]:
# Seed reused for the NumPy shuffles, the KFold split and torch's RNG.
SEED = 42
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two normalisation presets (per-channel mean, per-channel std):
# * TRANSFORM_FINAL uses the mean/std that torchvision documents for its
#   ImageNet-pretrained models.
# * TRANSFORM uses different statistics -- presumably computed on this
#   dataset; TODO confirm their origin.
TRANSFORM_FINAL = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.42433, 0.42265, 0.42161], [0.22863, 0.22852, 0.22842])])

# Directory that receives the per-fold checkpoints.
SAVE_PATH = "saves"
# exist_ok=True avoids the check-then-create race of a separate
# os.path.exists() test and is a no-op when the directory is already there.
os.makedirs(SAVE_PATH, exist_ok=True)

## Dataset Template

In [None]:
class DS(Dataset):
    """Dataset over an in-memory image array with optional targets.

    In ``"train"`` and ``"valid"`` mode each item is an
    ``(image, target)`` pair; in ``"test"`` mode each item is the image
    alone. Mode names are matched case-insensitively.

    Args:
        images: array indexed along its first axis, one entry per sample.
        targets: per-sample targets; only stored in train/valid mode.
        transform: optional callable applied to each image on access.
            When ``None`` the image is returned untransformed.
        mode: one of ``"train"``, ``"valid"`` or ``"test"``.

    Raises:
        AssertionError: if ``mode`` is not one of the supported names.
    """

    def __init__(self, images: np.ndarray, targets: np.ndarray = None, transform=None, mode: str = "train"):
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept for callers.
        if not re.match(r"^(train|valid|test)$", mode, re.IGNORECASE):
            raise AssertionError("Invalid Mode")

        self.mode = mode
        self.transform = transform
        self.images = images
        # Always define the attribute so test-mode instances do not raise
        # AttributeError when ``targets`` is inspected.
        self.targets = targets if self._has_targets else None

    @property
    def _has_targets(self) -> bool:
        """Whether the current mode yields ``(image, target)`` pairs."""
        return bool(re.match(r"^(train|valid)$", self.mode, re.IGNORECASE))

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        # ``transform`` defaults to None; calling it unconditionally would
        # fail with "'NoneType' object is not callable".
        if self.transform is not None:
            image = self.transform(image)

        if self._has_targets:
            return image, torch.FloatTensor(self.targets[idx])
        return image

## Model

In [None]:
class Model(nn.Module):
    """A torchvision backbone whose final layer is replaced by a 4-output head.

    Args:
        mode: training regime, matched case-insensitively:
            ``"full"`` builds the backbone without pretrained weights and
            leaves every parameter trainable; ``"semi"`` loads pretrained
            weights, freezes everything and re-enables the parameters whose
            names match the backbone's patterns in ``_SEMI_TRAINABLE``;
            ``"final"`` loads pretrained weights and freezes everything.
            In all modes the replacement head is created after freezing,
            so it is trainable.
        model_name: ``"vgg"`` (vgg16_bn), ``"resnet"`` (resnet50),
            ``"densenet"`` (densenet169) or ``"mobilenet"``
            (mobilenet_v3_small), matched case-insensitively.

    Raises:
        ValueError: if ``mode`` or ``model_name`` is not recognised.
    """

    # Parameter-name patterns that are unfrozen again in "semi" mode.
    _SEMI_TRAINABLE = {
        "vgg": (r".*features.3[4-9].*", r".*features.4[0-9].*", r".*classifier.*"),
        "resnet": (r".*layer4.*",),
        "densenet": (r".*denseblock4.*", r".*norm5.*"),
        "mobilenet": (r".*features.9.*", r".*features.1[0-2].*", r".*classifier.*"),
    }

    def __init__(self, mode: str, model_name: str):
        super().__init__()

        self.mode = mode
        self.model_name = model_name

        # Fail fast: previously an unknown value left ``self.model`` unset
        # and the error only surfaced later, in ``forward``.
        if not self._is_any(mode, "full", "semi", "final"):
            raise ValueError(f"Unknown mode: {mode!r}")
        if not self._is_any(model_name, *self._SEMI_TRAINABLE):
            raise ValueError(f"Unknown model_name: {model_name!r}")

        # "semi" and "final" fine-tune on top of pretrained weights for every
        # backbone; freezing a randomly initialised network would leave the
        # frozen layers untrained.
        pretrained = not self._is_any(mode, "full")

        if self._is_any(model_name, "vgg"):
            self.model = models.vgg16_bn(pretrained=pretrained, progress=True)
        elif self._is_any(model_name, "resnet"):
            self.model = models.resnet50(pretrained=pretrained, progress=True)
        elif self._is_any(model_name, "densenet"):
            self.model = models.densenet169(pretrained=pretrained, progress=True)
        else:
            self.model = models.mobilenet_v3_small(pretrained=pretrained, progress=True)

        # Freeze before swapping the head so the new head keeps the default
        # requires_grad=True.
        if pretrained:
            self.freeze()
        self._replace_head()

    @staticmethod
    def _is_any(value: str, *names: str) -> bool:
        """Return True if ``value`` matches one of ``names`` case-insensitively."""
        return any(re.match(rf"^{name}$", value, re.IGNORECASE) for name in names)

    def _replace_head(self, out_features: int = 4) -> None:
        """Swap the backbone's final linear layer for a fresh ``out_features`` head."""
        if self._is_any(self.model_name, "resnet"):
            self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=out_features)
        elif self._is_any(self.model_name, "densenet"):
            self.model.classifier = nn.Linear(in_features=self.model.classifier.in_features, out_features=out_features)
        else:
            # vgg and mobilenet: the last entry of ``classifier`` is the output layer.
            self.model.classifier[-1] = nn.Linear(in_features=self.model.classifier[-1].in_features, out_features=out_features)

    def freeze(self):
        """Freeze all parameters; in "semi" mode unfreeze the matching layers.

        Layers are selected by matching parameter names against the patterns
        registered for the current backbone in ``_SEMI_TRAINABLE``.
        """
        for params in self.parameters():
            params.requires_grad = False

        if not self._is_any(self.mode, "semi"):
            return

        for backbone, patterns in self._SEMI_TRAINABLE.items():
            if not self._is_any(self.model_name, backbone):
                continue
            for names, params in self.named_parameters():
                if any(re.match(pattern, names, re.IGNORECASE) for pattern in patterns):
                    params.requires_grad = True
            break

    def get_optimizer(self, lr: float = 1e-3, wd: float = 0.0):
        """Return an Adam optimizer over the parameters that require gradients."""
        params = [p for p in self.parameters() if p.requires_grad]
        return optim.Adam(params, lr=lr, weight_decay=wd)

    def get_plateau_scheduler(self, optimizer=None, patience: int = 5, eps: float = 1e-8):
        """Return a ``ReduceLROnPlateau`` scheduler wrapping ``optimizer``."""
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=patience, eps=eps, verbose=True)

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)


def get_model(seed: int, mode: str, model_name: str):
    """Build a ``Model`` under a fixed torch seed and move it to ``DEVICE``.

    The seed is set immediately before construction, so repeated calls with
    the same arguments start from the same torch RNG state.
    """
    torch.manual_seed(seed)
    network = Model(mode, model_name)
    return network.to(DEVICE)

## Load Data

In [None]:
start_time = time()

size = 224
images  = np.load(f"../input/fgvc7-images/train_images_{size}.npy")
targets = np.load(f"../input/fgvc7-images/targets_{size}.npy")

# The paired shuffle below only keeps sample i aligned with target i when
# both arrays have the same number of rows; fail loudly instead of silently
# training on mismatched labels.
if images.shape[0] != targets.shape[0]:
    raise ValueError(
        "images and targets must have the same length, got {} and {}".format(images.shape[0], targets.shape[0])
    )

# Re-seeding before each in-place shuffle is intended to replay the same
# random draws, so both arrays are reordered identically along axis 0.
np.random.seed(SEED)
np.random.shuffle(images)

np.random.seed(SEED)
np.random.shuffle(targets)

breaker()
print("Time Taken to Load Data : {:.2f} minutes".format((time() - start_time)/60))
breaker()

## Fit Helper

In [None]:
def fit(model=None, optimizer=None, scheduler=None, epochs=None, early_stopping_patience=None, fold=None, dataloaders=None, verbose=False):
    """Train ``model`` and checkpoint the state with the best validation loss.

    Each epoch runs a "train" phase (with gradient updates) followed by a
    "valid" phase (no gradients), using ``BCEWithLogitsLoss``. The per-phase
    epoch loss is the mean of the per-batch losses.

    Args:
        model: module to train; called as ``model(X)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional scheduler; stepped with the epoch's validation
            loss at the end of each epoch.
        epochs: maximum number of epochs.
        early_stopping_patience: if truthy, training stops once the
            validation loss has failed to improve for more than this many
            consecutive epochs.
        fold: fold identifier, used in log output and in the checkpoint name.
        dataloaders: mapping with ``"train"`` and ``"valid"`` keys, each an
            iterable of ``(X, y)`` batches.
        verbose: print progress banners and per-epoch losses.

    Returns:
        ``(Losses, BLE, name)``: the list of per-epoch
        ``{"train": ..., "valid": ...}`` dicts, the 1-based epoch with the
        best validation loss (``None`` if no epoch ever improved on
        infinity, e.g. zero epochs or NaN losses), and the checkpoint file
        name inside ``SAVE_PATH``.
    """
    if verbose:
        breaker()
        print(f"Training Fold {fold}...")
        breaker()

    # Stateless criterion: build it once rather than once per batch.
    criterion = torch.nn.BCEWithLogitsLoss()

    bestLoss = {"train": np.inf, "valid": np.inf}
    Losses = []
    # Initialised up front so neither is read before assignment when the
    # first validation loss is NaN or when ``epochs`` is 0.
    BLE = None
    early_stopping_step = 0
    name = f"state_fold_{fold}.pt"

    start_time = time()
    for e in range(epochs):
        e_st = time()
        epochLoss = {"train": 0.0, "valid": 0.0}

        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            lossPerPass = []

            for X, y in dataloaders[phase]:
                X, y = X.to(DEVICE), y.to(DEVICE)

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    output = model(X)
                    loss = criterion(output, y)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                lossPerPass.append(loss.item())
            epochLoss[phase] = np.mean(np.array(lossPerPass))
        Losses.append(epochLoss)

        # Checkpoint ONLY on improvement. Saving unconditionally to the same
        # file would overwrite the best state with the most recent one.
        if epochLoss["valid"] < bestLoss["valid"]:
            bestLoss = epochLoss
            BLE = e + 1
            early_stopping_step = 0
            torch.save({"model_state_dict": model.state_dict(),
                        "optim_state_dict": optimizer.state_dict()},
                       os.path.join(SAVE_PATH, name))
        elif early_stopping_patience:
            early_stopping_step += 1
            if early_stopping_step > early_stopping_patience:
                print("\nEarly Stopping at Epoch {}".format(e + 1))
                break

        if scheduler:
            scheduler.step(epochLoss["valid"])

        if verbose:
            print("Epoch: {} | Train Loss: {:.5f} | Valid Loss: {:.5f} | Time: {:.2f} seconds".format(e+1, epochLoss["train"], epochLoss["valid"], time()-e_st))

    if verbose:
        breaker()
        print(f"Best Validation Loss at Epoch {BLE}")
        breaker()
        print("Time Taken [{} Epochs] : {:.2f} minutes".format(len(Losses), (time()-start_time)/60))
        breaker()
        print("Training Completed")
        breaker()

    return Losses, BLE, name

## Params

In [None]:
# Set DEBUG to True for a short smoke run with fewer epochs and folds.
DEBUG = False

# Optimisation settings shared by every fold.
seed = SEED
batch_size = 64
lr = 1e-6
wd = 1e-5

# Fold counter; numbering starts at 1.
fold = 1

# (epochs, early_stopping, n_splits) for the debug run vs. the full run.
epochs, early_stopping, n_splits = (2, 5, 3) if DEBUG else (100, 25, 5)

## Train

In [None]:
# K-fold cross-validation: every fold trains a freshly seeded "semi" resnet
# on its training split and plots that fold's loss curves.
for tr_idx, va_idx in KFold(
    n_splits=n_splits,
    shuffle=True,
    random_state=seed,
).split(images):
    # Index straight into the full arrays; no per-fold temporaries are kept.
    tr_data_setup = DS(images[tr_idx], targets[tr_idx], TRANSFORM)
    va_data_setup = DS(images[va_idx], targets[va_idx], TRANSFORM)

    dataloaders = {
        # torch.manual_seed returns the (re-seeded) default generator, which
        # drives the shuffling order of the training loader.
        "train": DL(
            tr_data_setup,
            batch_size=batch_size,
            shuffle=True,
            generator=torch.manual_seed(seed),
        ),
        "valid": DL(va_data_setup, batch_size=batch_size, shuffle=False),
    }

    # get_model seeds torch and moves the new model to DEVICE.
    model = get_model(seed, "semi", "resnet")
    optimizer = model.get_optimizer(lr=lr, wd=wd)
    L, _, _ = fit(
        model=model,
        optimizer=optimizer,
        epochs=epochs,
        early_stopping_patience=early_stopping,
        dataloaders=dataloaders,
        fold=fold,
        verbose=True,
    )
    show_graphs(L)

    fold += 1

breaker()