### **Library Imports**

In [1]:
import os
import re
import cv2
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from time import time
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torchvision import models, transforms

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

### **Utilities and Constants**

In [2]:
# Global random seed (also handed to CFG below) and the square side length, in pixels,
# that get_image resizes every image to.
SEED: int = 42
SIZE: int = 512
# Shared label encoder between class names and integer class ids.
le = LabelEncoder()


def breaker(num: int=50, char: str="*") -> None:
    """Print a horizontal rule made of ``num`` copies of ``char``, padded by blank lines."""
    rule = char * num
    print(f"\n{rule}\n")


def get_image(path: str, size: int) -> np.ndarray:
    """Read an image from disk in RGB channel order and resize it to ``size`` x ``size``.

    Args:
        path: image file path.
        size: side length in pixels of the square output image.

    Returns:
        ``np.ndarray`` of shape ``(size, size, 3)`` in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing or undecodable file).
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread signals failure by returning None rather than raising; without this
    # check the failure only surfaces later as an opaque cvtColor assertion error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; convert so downstream normalization sees RGB.
    image = cv2.cvtColor(src=image, code=cv2.COLOR_BGR2RGB)
    # INTER_AREA is generally preferred by OpenCV for downscaling.
    return cv2.resize(src=image, dsize=(size, size), interpolation=cv2.INTER_AREA)


def show_graphs(L: list, A: list) -> None:
    TL, VL, TA, VA = [], [], [], []
    for i in range(len(L)):
        TL.append(L[i]["train"])
        VL.append(L[i]["valid"])
        TA.append(A[i]["train"])
        VA.append(A[i]["valid"])
    x_Axis = np.arange(1, len(TL) + 1)
    plt.figure(figsize=(8, 6))
    plt.subplot(1, 2, 1)
    plt.plot(x_Axis, TL, "r", label="Train")
    plt.plot(x_Axis, VL, "b", label="Valid")
    plt.legend()
    plt.grid()
    plt.title("Loss Graph")
    plt.subplot(1, 2, 2)
    plt.plot(x_Axis, TA, "r", label="Train")
    plt.plot(x_Axis, VA, "b", label="Valid")
    plt.legend()
    plt.grid()
    plt.title("Accuracy Graph")
    plt.show()

### **Configuration**

In [3]:
class CFG(object):
    """Hyper-parameters, data paths, device and image transforms for a training run.

    Constructing an instance also creates ``save_path`` on disk if it does not exist.
    """

    # Per-channel normalization statistics, in the RGB order produced by get_image,
    # applied after ToTensor in both pipelines.
    NORM_MEAN = (0.36878, 0.38273, 0.29333)
    NORM_STD = (0.16007, 0.16414, 0.12774)

    def __init__(self, 
                 seed: int = 42,
                 n_splits: int = 5,
                 batch_size: int = 16,
                 epochs: int = 25,
                 early_stopping: int = 5,
                 lr: float = 1e-4,
                 wd: float = 0.0,
                 max_lr: float = 1e-3,
                 pct_start: float = 0.2,
                 steps_per_epoch: int = 100,
                 div_factor: float = 1e3, 
                 final_div_factor: float = 1e3,
                 ):
        """
        Args:
            seed: random seed.
            n_splits: number of cross-validation folds.
            batch_size: DataLoader batch size.
            epochs: maximum number of training epochs.
            early_stopping: early-stopping patience, in epochs.
            lr: optimizer base learning rate.
            wd: optimizer weight decay.
            max_lr, pct_start, steps_per_epoch, div_factor, final_div_factor:
                one-cycle learning-rate schedule settings.
        """
        self.seed = seed
        self.n_splits = n_splits
        self.batch_size = batch_size
        self.epochs = epochs
        self.early_stopping = early_stopping
        self.lr = lr
        self.wd = wd
        self.max_lr = max_lr
        self.pct_start = pct_start
        self.steps_per_epoch = steps_per_epoch
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.train_base_path = "../input/sorghum-id-fgvc-9/train_images"
        self.test_base_path = "../input/sorghum-id-fgvc-9/test"
        ### MODIFY
        # The two pipelines are currently identical but are kept as separate objects so the
        # training one can diverge (e.g. with augmentation) without affecting validation.
        self.train_transform = transforms.Compose([transforms.ToTensor(), 
                                                   transforms.Normalize(self.NORM_MEAN, self.NORM_STD),])
        self.valid_transform = transforms.Compose([transforms.ToTensor(), 
                                                   transforms.Normalize(self.NORM_MEAN, self.NORM_STD),])
        self.save_path = "saves"
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(self.save_path, exist_ok=True)

    
cfg = CFG(seed=SEED)

### **Dataset Template**

In [4]:
class DS(Dataset):
    """Image dataset over files in ``base_path``; yields ``(image, label)`` or just ``image``.

    Args:
        base_path: directory containing the image files.
        filenames: image filenames relative to ``base_path`` (array or list).
        transform: callable applied to the RGB ``np.ndarray`` returned by ``get_image``.
        labels: optional integer class labels, shape ``(N,)`` or ``(N, 1)``. When omitted,
            items are images only.
    """
    def __init__(self, base_path: str, filenames: np.ndarray, transform, labels: np.ndarray = None):
        self.base_path = base_path
        self.filenames = filenames
        self.transform = transform
        self.labels = labels
    
    def __len__(self):
        return len(self.filenames)

    def _label_tensor(self, idx) -> torch.Tensor:
        """Return label ``idx`` as a 1-element ``torch.long`` tensor for (N,) or (N, 1) labels."""
        # torch.LongTensor(<scalar>) does not reliably build a 1-element tensor (a Python int
        # is taken as a *size*), so convert the value and flatten instead.
        return torch.as_tensor(self.labels[idx], dtype=torch.long).reshape(-1)
    
    def __getitem__(self, idx):
        image = self.transform(get_image(os.path.join(self.base_path, self.filenames[idx]), SIZE))
        if self.labels is None:
            return image
        return image, self._label_tensor(idx)

### **Model**

In [5]:
class MyResnet(nn.Module):
    """ResNet-50 (not pretrained) with a replaced classification head; outputs log-probabilities.

    Args:
        num_classes: output size of the new final linear layer (default 100, the value
            previously hard-coded).
    """
    def __init__(self, num_classes: int = 100):
        super(MyResnet, self).__init__()

        # NOTE: `pretrained` is deprecated in newer torchvision in favour of `weights=None`;
        # kept here for compatibility with older torchvision releases.
        self.model = models.resnet50(pretrained=False, progress=False)
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=num_classes)

    def forward(self, x):
        # Log-probabilities over classes (dim 1), as consumed by NLLLoss in `fit`.
        return nn.functional.log_softmax(self.model(x), dim=1)

### **Fit and Predict Helper**

In [6]:
def _run_phase(model, optimizer, scheduler, dataloader, device, train: bool) -> tuple:
    """Run one pass over ``dataloader`` and return its (mean loss, accuracy).

    Both metrics are averaged per *sample* rather than per batch, so a smaller final
    batch is not over-weighted. Returns ``(nan, nan)`` for an empty loader.
    """
    model.train(train)  # train(False) is equivalent to model.eval()
    criterion = torch.nn.NLLLoss()
    total_loss, total_correct, total_seen = 0.0, 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device).view(-1)
        if train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            output = model(X)
            loss = criterion(output, y)
            if train:
                loss.backward()
                optimizer.step()
                # Stepped once per batch (per-step schedules such as OneCycleLR expect this).
                if scheduler is not None:
                    scheduler.step()
        batch_n = y.size(0)
        total_loss += loss.item() * batch_n  # NLLLoss defaults to a batch mean
        total_correct += (torch.argmax(output, dim=1) == y).sum().item()
        total_seen += batch_n

    if total_seen == 0:
        return float("nan"), float("nan")
    return total_loss / total_seen, total_correct / total_seen


def _save_checkpoint(model, optimizer, scheduler, path: str) -> None:
    """Write model/optimizer/scheduler state dicts to ``path`` (scheduler entry None if absent)."""
    torch.save({"model_state_dict"     : model.state_dict(),
                "optim_state_dict"     : optimizer.state_dict(),
                "scheduler_state_dict" : scheduler.state_dict() if scheduler is not None else None},
               path)


def fit(model=None, 
        optimizer=None, 
        scheduler=None, 
        epochs=None, 
        early_stopping_patience=None, 
        dataloaders=None, 
        fold=None, 
        save_path=None,
        device=None,
        verbose=False) -> tuple:
    """Train ``model`` with a train/valid loop, checkpointing on the best validation loss.

    Args:
        model: network returning log-probabilities (consumed by ``NLLLoss``).
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per training batch.
        epochs: maximum number of epochs.
        early_stopping_patience: stop once the validation loss has failed to improve for
            more than this many consecutive epochs; falsy (None/0) disables early stopping.
        dataloaders: dict with "train" and "valid" loaders yielding ``(X, y)`` batches.
        fold: optional fold id used in the checkpoint name (0 is a valid fold).
        save_path: directory the checkpoint is written to.
        device: device batches are moved to.
        verbose: print per-epoch metrics and a final summary.

    Returns:
        ``(Losses, Accuracies, LRs, BLE, BAE, name)``:
        - ``Losses`` and ``Accuracies``: per-epoch ``{"train", "valid"}`` dicts.
        - ``LRs``: per-epoch learning rates, one value per param group.
        - ``BLE`` and ``BAE``: 1-based epochs of the best validation loss and accuracy,
          or None if never improved.
        - ``name``: the checkpoint filename.
    """
    name = f"state_fold_{fold}.pt" if fold is not None else "state.pt"

    if verbose:
        breaker()
        print(f"Training Fold {fold}..." if fold is not None else "Training ...")
        breaker()

    bestLoss, bestAccs = {"train" : np.inf, "valid" : np.inf}, {"train" : 0.0, "valid" : 0.0}
    Losses, Accuracies, LRs = [], [], []
    BLE, BAE = None, None
    early_stopping_step = 0

    start_time = time()
    for e in range(epochs):
        e_st = time()
        epochLoss, epochAccs = {}, {}
        for phase in ("train", "valid"):
            epochLoss[phase], epochAccs[phase] = _run_phase(model, optimizer, scheduler,
                                                            dataloaders[phase], device,
                                                            train=(phase == "train"))

        # Without a scheduler, record the optimizer's learning rates in the same
        # one-per-param-group list format that get_last_lr() returns.
        if scheduler is not None:
            LRs.append(scheduler.get_last_lr())
        else:
            LRs.append([group["lr"] for group in optimizer.param_groups])
        Losses.append(epochLoss)
        Accuracies.append(epochAccs)

        if epochLoss["valid"] < bestLoss["valid"]:
            bestLoss = epochLoss
            BLE = e + 1
            early_stopping_step = 0
            _save_checkpoint(model, optimizer, scheduler, os.path.join(save_path, name))
        else:
            early_stopping_step += 1

        if epochAccs["valid"] > bestAccs["valid"]:
            bestAccs = epochAccs
            BAE = e + 1

        if verbose:
            print("Epoch: {} | Train Loss: {:.5f} | Valid Loss: {:.5f} | "
                  "Train Accs: {:.5f} | Valid Accs: {:.5f} | Time: {:.2f} seconds".format(
                      e + 1,
                      epochLoss["train"], epochLoss["valid"],
                      epochAccs["train"], epochAccs["valid"],
                      time() - e_st))

        if early_stopping_patience and early_stopping_step > early_stopping_patience:
            print("\nEarly Stopping at Epoch {}".format(e + 1))
            break

    if verbose:                                           
        breaker()
        print(f"Best Validation Loss at Epoch {BLE}")
        breaker()
        print(f"Best Validation Accs at Epoch {BAE}")
        breaker()
        print("Time Taken [{} Epochs] : {:.2f} minutes".format(len(Losses), (time() - start_time) / 60))
    
    return Losses, Accuracies, LRs, BLE, BAE, name


# def predict(model=None, dataloader=None, path=None, device=None) -> np.ndarray:
#     model.load_state_dict(torch.load(path, map_location=device)["model_state_dict"])
#     model.to(device)    
#     model.eval()
    
#     y_pred = torch.zeros(1, 1).to(device)
    
#     for X in dataloader:
#         X = X.to(device)
#         with torch.no_grad():
#             output = torch.argmax(torch.exp(model(X)), dim=1)
#         y_pred = torch.cat((y_pred, output.view(-1, 1)), dim=0)
    
#     return y_pred[1:].detach().cpu().numpy()

### **Train**

In [7]:
# train_df = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")
# train_df = train_df.drop(index=train_df.index[train_df.image == ".DS_Store"])

# filenames = train_df.image.copy().values
# labels = train_df.cultivar.copy().values 
# labels = le.fit_transform(labels)

In [8]:
# for tr_idx, va_idx in StratifiedKFold(n_splits=cfg.n_splits, random_state=cfg.seed, shuffle=True).split(filenames, labels):
#     break

# tr_filenames, va_filenames, tr_labels, va_labels = filenames[tr_idx], filenames[va_idx], labels[tr_idx], labels[va_idx]

# tr_data_setup = DS(cfg.train_base_path, tr_filenames, cfg.train_transform, tr_labels.reshape(-1, 1))
# va_data_setup = DS(cfg.train_base_path, va_filenames, cfg.valid_transform, va_labels.reshape(-1, 1))

# dataloaders = {
#     "train" : DL(tr_data_setup, batch_size=cfg.batch_size, shuffle=True, generator=torch.manual_seed(cfg.seed)),
#     "valid" : DL(va_data_setup, batch_size=cfg.batch_size, shuffle=False),
# }

# cfg = CFG(epochs=15, batch_size=20, steps_per_epoch=len(dataloaders["train"]))

# torch.manual_seed(cfg.seed)
# model = MyResnet().to(cfg.device)
# optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=cfg.lr, weight_decay=cfg.wd)
# scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer, 
#                                           max_lr=cfg.max_lr, 
#                                           epochs=cfg.epochs, 
#                                           steps_per_epoch=cfg.steps_per_epoch,
#                                           pct_start=cfg.pct_start, 
#                                           div_factor=cfg.div_factor, 
#                                           final_div_factor=cfg.final_div_factor)

# L, A, LRs, BLE, BAE, name = fit(model=model, 
#                                 optimizer=optimizer, 
#                                 scheduler=scheduler, 
#                                 epochs=cfg.epochs, 
#                                 early_stopping_patience=cfg.early_stopping, 
#                                 dataloaders=dataloaders,  
#                                 save_path=cfg.save_path,
#                                 device=cfg.device,
#                                 verbose=True)

# breaker()
# show_graphs(L, A)

### **Submission**

In [9]:
# ss_df = pd.read_csv("/.data/sample_submission.csv")

# ts_data_setup = DS(cfg.test_base_path, ss_df.filename.copy().values, TRANSFORM)
# ts_data = DL(ts_data_setup, batch_size=cfg.batch_size, shuffle=False)

# y_pred = predict(model=MyResnet().to(cfg.device), dataloader=ts_data, path=MODEL_PATH)
# y_pred = le.inverse_transform(y_pred.astype("uint8"))

# ss_df["cultivar"] = y_pred
# ss_df.to_csv("/content/submission.csv", index=False)