# Library Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
from torch.nn.utils import weight_norm as WN
import torch.nn.functional as F

import gc
from time import time

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# Ask cuDNN for deterministic algorithms and switch off its benchmark-based
# algorithm selection (presumably to make runs repeatable).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Single seed value reused further down for the train/validation split, the
# training DataLoader's shuffle generator and torch.manual_seed.
seed = 42

# Helper Functions

In [None]:
def breaker():
    """Print a 50-character dashed rule with an empty line above and below it."""
    rule = "-" * 50
    print("\n{}\n".format(rule))

def head(x, no_of_ele=5):
    """Print the first ``no_of_ele`` items of the sliceable object ``x``.

    Args:
        x: Any object supporting slice indexing (list, string, array, ...).
        no_of_ele: Number of leading items to show; defaults to 5.
    """
    leading = x[:no_of_ele]
    print(leading)

# Data Handling

**Loading Image Data**

In [None]:
# Read the full image and label arrays from disk.
images = np.load("../input/rccl-1x144x144/images_1x144x144.npy")
labels = np.load("../input/rccl-1x144x144/labels_1x144x144.npy")

# Shuffled 80/20 train/validation split, made repeatable through `seed`.
tr_images, va_images, tr_labels, va_labels = train_test_split(
    images,
    labels,
    test_size=0.2,
    shuffle=True,
    random_state=seed,
)

# The unsplit arrays are no longer needed once the split copies exist.
del images
del labels

breaker()
n_collected = gc.collect()
print("Garbage Collected : {}".format(n_collected))
breaker()

**Dataset Template**

In [None]:
class Dataset(Dataset):
    """Map-style dataset over in-memory arrays of inputs and (optionally) targets.

    The class keeps the name ``Dataset`` (shadowing the imported base class
    once it is defined) because the rest of the file instantiates it by that
    name.

    Args:
        X: Indexable collection of samples exposing ``.shape[0]`` (e.g. a
            NumPy array); ``X[idx]`` is converted to a float32 tensor.
        y: Targets aligned with ``X``; only kept when ``mode == "train"``.
        mode: ``"train"`` makes ``__getitem__`` return ``(input, target)``;
            any other value makes it return the input tensor alone.
    """

    def __init__(this, X=None, y=None, mode="train"):
        this.mode = mode
        this.X = X
        # Always define the attribute so that inspecting ``y`` on an
        # inference-mode dataset does not raise AttributeError.
        this.y = y if mode == "train" else None

    @staticmethod
    def _to_float_tensor(value):
        """Copy ``value`` (array or scalar) into a new float32 tensor."""
        # torch.FloatTensor(value) is the legacy constructor: it interprets an
        # integer scalar as a tensor *size* and rejects other scalars, which
        # breaks 1-D label arrays. torch.tensor treats the value as data.
        return torch.tensor(np.asarray(value), dtype=torch.float32)

    def __len__(this):
        """Return the number of samples (first dimension of ``X``)."""
        return this.X.shape[0]

    def __getitem__(this, idx):
        """Return ``(input, target)`` in train mode, otherwise just ``input``."""
        sample = this._to_float_tensor(this.X[idx])
        if this.mode == "train":
            return sample, this._to_float_tensor(this.y[idx])
        return sample

# CNN Configuration and Setup

**Config**

In [None]:
class CFG():
    """Container for the hyperparameters shared by the data loaders and the model.

    Class attributes hold fixed settings; the constructor takes the values
    that vary between experiments.

    Args:
        filter_sizes: Output-channel counts for the convolutional layers.
            Defaults to ``[64, 128, 256, 512]``.
        HL: Widths of the two hidden fully connected layers. Defaults to
            ``[4096, 4096]``.
        epochs: Number of training epochs.
        n_folds: Number of cross-validation folds.
    """

    tr_batch_size = 128  # Also used as the validation batch size
    ts_batch_size = 128

    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    in_channels = 1
    OL = 11  # Size of the model's output layer

    def __init__(this, filter_sizes=None, HL=None, epochs=50, n_folds=5):
        # The defaults are materialised per instance: list literals in the
        # signature would be shared (and mutable) across every CFG object.
        # Caller-supplied sequences are copied for the same reason.
        this.filter_sizes = [64, 128, 256, 512] if filter_sizes is None else list(filter_sizes)
        this.HL = [4096, 4096] if HL is None else list(HL)
        this.epochs = epochs
        this.n_folds = n_folds

**Setup**

In [None]:
class CNN(nn.Module):
    """Six-stage convolutional network followed by a three-layer dense head.

    Every stage applies ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d ->
    MaxPool2d(2) -> ReLU``, so each stage halves the spatial size. ``FC1``
    expects ``filter_sizes[3] * 2 * 2`` features, i.e. a 2x2 feature map after
    the six poolings; inputs must be sized accordingly (144x144 satisfies it).

    Args:
        in_channels: Number of channels in the input images.
        filter_sizes: Sequence of four output-channel counts. Stages 1-4 use
            them in order; stages 5 and 6 reuse ``filter_sizes[3]``. Required.
        HL: Sequence of two hidden-layer widths for ``FC1`` / ``FC2``. Required.
        OL: Number of output units of ``FC3``. Required.
        use_DP: When true, ``DP2`` dropout is applied to the outputs of
            ``FC1`` and ``FC2`` (before the ReLU).
        DP1: Drop probability of the ``DP1`` module. ``DP1`` is created but
            not used by ``forward``.
        DP2: Drop probability of the ``DP2`` module used in the dense head.

    ``forward`` returns the raw ``FC3`` outputs; no final activation is applied.
    """

    def __init__(this, in_channels=1, filter_sizes=None, HL=None, OL=None, use_DP=False, DP1=0.2, DP2=0.5):
        super().__init__()

        this.use_DP = use_DP

        # Honour the constructor arguments; the rates used to be hard-coded
        # to 0.2 / 0.5, silently ignoring DP1 / DP2.
        this.DP1 = nn.Dropout(p=DP1)
        this.DP2 = nn.Dropout(p=DP2)

        # One stateless pooling module shared by all six stages.
        this.MP_ = nn.MaxPool2d(kernel_size=2)

        # NOTE: attribute names and creation order are kept stable so that
        # saved state_dicts still load and seeded initialisation is unchanged.
        this.CN1 = nn.Conv2d(in_channels=in_channels, out_channels=filter_sizes[0], kernel_size=3, stride=1, padding=1)
        this.BN1 = nn.BatchNorm2d(num_features=filter_sizes[0], eps=1e-5)

        this.CN2 = nn.Conv2d(in_channels=filter_sizes[0], out_channels=filter_sizes[1], kernel_size=3, stride=1, padding=1)
        this.BN2 = nn.BatchNorm2d(num_features=filter_sizes[1], eps=1e-5)

        this.CN3 = nn.Conv2d(in_channels=filter_sizes[1], out_channels=filter_sizes[2], kernel_size=3, stride=1, padding=1)
        this.BN3 = nn.BatchNorm2d(num_features=filter_sizes[2], eps=1e-5)

        this.CN4 = nn.Conv2d(in_channels=filter_sizes[2], out_channels=filter_sizes[3], kernel_size=3, stride=1, padding=1)
        this.BN4 = nn.BatchNorm2d(num_features=filter_sizes[3], eps=1e-5)

        this.CN5 = nn.Conv2d(in_channels=filter_sizes[3], out_channels=filter_sizes[3], kernel_size=3, stride=1, padding=1)
        this.BN5 = nn.BatchNorm2d(num_features=filter_sizes[3], eps=1e-5)

        this.CN6 = nn.Conv2d(in_channels=filter_sizes[3], out_channels=filter_sizes[3], kernel_size=3, stride=1, padding=1)
        this.BN6 = nn.BatchNorm2d(num_features=filter_sizes[3], eps=1e-5)

        # 2*2 is the spatial size left after six 2x poolings of the input.
        this.FC1 = nn.Linear(in_features=filter_sizes[3]*2*2, out_features=HL[0])
        this.FC2 = nn.Linear(in_features=HL[0], out_features=HL[1])
        this.FC3 = nn.Linear(in_features=HL[1], out_features=OL)

    def getOptimizer(this, A_S=True, lr=1e-3, wd=0):
        """Return Adam when ``A_S`` is true, otherwise SGD with momentum 0.9."""
        if A_S:
            return optim.Adam(this.parameters(), lr=lr, weight_decay=wd)
        return optim.SGD(this.parameters(), lr=lr, momentum=0.9, weight_decay=wd)

    def getStepLR(this, optimizer=None, step_size=5, gamma=0.1):
        """Return a StepLR scheduler bound to ``optimizer``."""
        return optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=gamma)

    def getMultiStepLR(this, optimizer=None, milestones=None, gamma=0.1):
        """Return a MultiStepLR scheduler bound to ``optimizer``."""
        return optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma)

    def getPlateauLR(this, optimizer=None, patience=5, eps=1e-6):
        """Return a ReduceLROnPlateau scheduler bound to ``optimizer``."""
        # NOTE(review): `verbose` is reported as deprecated by recent PyTorch
        # releases - confirm against the installed version before upgrading.
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, patience=patience, eps=eps, verbose=True)

    def forward(this, x):
        """Run the six conv stages and the dense head; return the FC3 outputs."""
        stages = (
            (this.CN1, this.BN1),
            (this.CN2, this.BN2),
            (this.CN3, this.BN3),
            (this.CN4, this.BN4),
            (this.CN5, this.BN5),
            (this.CN6, this.BN6),
        )
        for conv, norm in stages:
            x = F.relu(this.MP_(norm(conv(x))))

        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)

        # Dropout (when enabled) sits between each hidden linear layer and
        # its ReLU, exactly as in the previous duplicated branch.
        for dense in (this.FC1, this.FC2):
            x = dense(x)
            if this.use_DP:
                x = this.DP2(x)
            x = F.relu(x)

        return this.FC3(x)

**Train Function**

In [None]:
def fit_(model=None, optimizer=None, scheduler=None, epochs=None, early_stopping_patience=5,
         trainloader=None, validloader=None, criterion=None, device=None, verbose=False):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch runs a training pass over ``trainloader`` and a gradient-free
    evaluation pass over ``validloader``. After every epoch the model's
    ``state_dict`` is written to ``./Epoch_<n>.pt`` (1-based ``n``).

    Args:
        model: Module to train; expected to already live on ``device``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Optional LR scheduler, stepped once per epoch. A
            ``ReduceLROnPlateau`` instance is stepped with the epoch's
            validation loss; any other scheduler is stepped without arguments.
        epochs: Number of epochs to run.
        early_stopping_patience: Accepted for interface compatibility;
            early stopping is currently not performed.
        trainloader: Iterable yielding ``(X, y)`` training batches.
        validloader: Iterable yielding ``(X, y)`` validation batches.
        criterion: Loss callable invoked as ``criterion(output, y)``.
        device: Device the batches are moved to.
        verbose: When true, print one summary line per epoch.

    Returns:
        ``(Losses, bestEpoch)`` where ``Losses`` is a list with one
        ``{"train": ..., "valid": ...}`` dict of mean per-batch losses per
        epoch, and ``bestEpoch`` is the 1-based epoch with the lowest
        validation loss, or ``None`` if no epoch ever improved on infinity
        (e.g. NaN losses or ``epochs == 0``).
    """
    breaker()
    print("Training ...")
    breaker()

    # model.to(device)
    Losses = []

    DLS = {"train": trainloader, "valid": validloader}
    bestLoss = {"train": np.inf, "valid": np.inf}
    # Initialised up front so the summary and the return statement cannot hit
    # an unbound name when the validation loss never improves.
    bestEpoch = None

    start_time = time()
    for e in range(epochs):
        e_st = time()
        epochLoss = {"train": 0.0, "valid": 0.0}

        for phase in ["train", "valid"]:
            is_train = phase == "train"
            # train(True) enables training behaviour, train(False) == eval().
            model.train(is_train)

            lossPerPass = []

            for X, y in DLS[phase]:
                X, y = X.to(device), y.to(device)

                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(is_train):
                    output = model(X)
                    loss = criterion(output, y)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                lossPerPass.append(loss.item())
            # Unweighted mean over batches (a smaller last batch counts equally).
            epochLoss[phase] = np.mean(np.array(lossPerPass))
        Losses.append(epochLoss)

        # A checkpoint is kept for every epoch, not only for the best one.
        torch.save(model.state_dict(), "./Epoch_{}.pt".format(e+1))

        if epochLoss["valid"] < bestLoss["valid"]:
            bestLoss = epochLoss
            bestEpoch = e+1

        if scheduler is not None:
            # Only ReduceLROnPlateau consumes a metric; passing one to other
            # schedulers would be interpreted as an epoch index.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epochLoss["valid"])
            else:
                scheduler.step()

        if verbose:
            print("Epoch: {} | Train Loss: {:.5f} | Valid Loss: {:.5f} | Time: {:.2f} seconds".format(e+1, epochLoss["train"], epochLoss["valid"], time()-e_st))

    breaker()
    print("-----> Best Validation Loss at Epoch {}".format(bestEpoch))
    breaker()
    print("Time Taken [{} Epochs] : {:.2f} minutes".format(epochs, (time()-start_time)/60))
    breaker()
    print("Training Complete")
    breaker()

    return Losses, bestEpoch

**Training**

In [None]:
# Hyperparameters for this run.
cfg = CFG(filter_sizes=[64, 128, 256, 512], HL=[4096, 4096], epochs=50, n_folds=5)

# Both wrappers use the default mode ("train"), so each item is an (input, label) pair.
tr_data_setup = Dataset(tr_images, tr_labels)
va_data_setup = Dataset(va_images, va_labels)

# The training loader shuffles using the generator returned by torch.manual_seed(seed);
# the validation loader reuses the training batch size and keeps the data order.
tr_data = DL(tr_data_setup, batch_size=cfg.tr_batch_size, shuffle=True, generator=torch.manual_seed(seed))
va_data = DL(va_data_setup, batch_size=cfg.tr_batch_size, shuffle=False)

# Only the names are dropped here; the loaders were constructed with these objects.
del tr_data_setup, va_data_setup

breaker()
print("Garbage Collected : {}".format(gc.collect()))

# Re-seed right before building the model (presumably so weight initialisation is repeatable).
torch.manual_seed(seed)

model = CNN(filter_sizes=cfg.filter_sizes, HL=cfg.HL, OL=cfg.OL, use_DP=True).to(cfg.device)
optimizer = model.getOptimizer(lr=1e-3, wd=1e-5)

# Alternative schedulers, kept for reference. fit_ steps the scheduler with the
# validation loss, which matches the plateau scheduler selected below.
# scheduler = model.getStepLR(optimizer=optimizer, step_size=5, gamma=0.1)
# scheduler = model.getMultiStepLR(optimizer=optimizer, milestones=[10, 20, 30, 40], gamma=0.1)
scheduler = model.getPlateauLR(optimizer=optimizer, patience=5, eps=1e-8)

# The model returns raw FC3 outputs, hence the logits-based BCE loss.
# NOTE(review): assumes the labels are float arrays of width cfg.OL - TODO confirm.
Losses, bestEpoch = fit_(model=model, optimizer=optimizer, scheduler=scheduler, epochs=cfg.epochs, early_stopping_patience=5,
                         trainloader=tr_data, validloader=va_data, criterion=nn.BCEWithLogitsLoss(), device=cfg.device,
                         verbose=True)

**Loss Plot**

In [None]:
# Split the per-epoch loss records into one series per phase.
LT = [record["train"] for record in Losses]
LV = [record["valid"] for record in Losses]

# Epochs are numbered from 1 on the x-axis.
epoch_axis = list(range(1, len(Losses) + 1))

plt.figure(figsize=(8, 6))
plt.plot(epoch_axis, LT, "r", label="Training Loss")
plt.plot(epoch_axis, LV, "b--", label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid()
plt.show()