# Starter for Beginners: PyTorch
### This notebook is written by a beginner for beginners. Classes implemented in this notebook:
#### * RetrievalData ---> Class for loading dataset.
#### * CheckPoint    ---> Class for EarlyStopping and saving models/results.
#### * Hashtag       ---> Class for training and validation.

#### * Train-Test split using sklearn train_test_split.
#### * Computes Loss, Accuracy and ROC-AUC for both training and validation.



#### Suggestions and feedback would be highly appreciated. Thanks!

# Haasha Bin Atif

In [None]:
#torch & torchvision
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset,DataLoader

#sklearn for splitting dataset and scoring functions
from sklearn.metrics import accuracy_score,roc_auc_score
from sklearn.model_selection import train_test_split


#Other
import matplotlib.pyplot as plt
from time import perf_counter
from scipy import stats
from PIL import Image
import pandas as pd
import numpy as np
import random
import pickle
import json
import os

In [None]:
# Hyperparameters, compute device, and dataset locations shared by the cells below.
BATCH_SIZE = 64      # samples per mini-batch handed to the DataLoaders
NUM_WORKERS = 4      # DataLoader worker count
TRAIN_DIR = '../input/cassava-leaf-disease-classification/train_images/'
TEST_DIR = '../input/cassava-leaf-disease-classification/test_images/'
GPU = torch.cuda.is_available()
# Use the first CUDA device when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if GPU else "cpu")

In [None]:
#helper functions
def imshow(img):
    """Show a channel-first image with matplotlib.

    ``img`` must provide ``.numpy()`` returning a 3-D array; its first axis is
    moved to the end before plotting. Pixel values are plotted as given (no
    un-normalisation is applied).
    """
    channels_last = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def display_images(images, title=None):
    """Plot up to 25 training images in a 5x5 grid, titled with their class names.

    Args:
        images: iterable of ``(image_id, label)`` pairs. ``image_id`` is a file
            name inside ``TRAIN_DIR``; ``str(label)`` must be a key of the
            module-level ``labels`` mapping.
        title: optional figure title.

    Pairs beyond the 25th are ignored instead of raising ``IndexError``, and
    grid cells left without an image are blanked.
    """
    f, ax = plt.subplots(5, 5, figsize=(18, 22))
    if title:
        f.suptitle(title, fontsize=30)

    # ax.flat walks the grid row by row, matching the i//5, i%5 layout.
    cells = list(ax.flat)
    shown = 0
    # zip() stops at the shorter input, which caps the plot at the grid size.
    for cell, (image_id, label) in zip(cells, images):
        image_path = os.path.join(TRAIN_DIR, image_id)
        # The context manager closes the file even if plotting raises.
        with Image.open(image_path) as image:
            cell.imshow(image)
        cell.axis('off')
        cell.set_title(labels[str(label)], fontsize=10)
        shown += 1

    # Blank the unused cells so no empty axes with ticks are drawn.
    for cell in cells[shown:]:
        cell.axis('off')

    plt.show()
def getModel(Name, OutFeatures):
    """Build a torchvision classifier whose last layer has ``OutFeatures`` outputs.

    Args:
        Name: one of ``"AlexNet"``, ``"VGG16"`` or ``"resnet152"``.
        OutFeatures: number of output units for the replaced final layer.

    Returns:
        The model, constructed with torchvision's default arguments and with
        its final linear layer swapped for a new ``nn.Linear``.

    Raises:
        ValueError: if ``Name`` is not one of the supported architectures
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if Name == "AlexNet":
        model = torchvision.models.alexnet()
        # Read the input width from the layer being replaced instead of hard-coding it.
        in_features = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_features, OutFeatures, bias=True)
    elif Name == "VGG16":
        model = torchvision.models.vgg16()
        in_features = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_features, OutFeatures, bias=True)
    elif Name == "resnet152":
        model = torchvision.models.resnet152()
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, OutFeatures, bias=True)
    else:
        raise ValueError(
            "Unknown model name %r; expected 'AlexNet', 'VGG16' or 'resnet152'." % (Name,)
        )
    return model

def display(Type, epochNum, totalEpochs, Results):
    """Print a one-line summary of an epoch's metrics.

    The line starts with ``Type``, then the 1-based epoch number and the total
    epoch count (both zero-padded to three digits), followed by every
    ``key``/value pair of ``Results`` with the value shown to six decimals.
    """
    current = str(epochNum + 1).zfill(3)
    total = str(totalEpochs).zfill(3)
    header = f"{Type} Epoch#{current}/{total}"
    metrics = "".join(f" {name}{value:.6f}" for name, value in Results.items())
    print(header + metrics)


In [None]:
# Load the training table, the sample submission, and the label-id -> disease-name map.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
print("Train File:\n",train.head())
print("\nSubmission File:\n",submission.head())
# Open the JSON file in a context manager so the handle is closed right after parsing.
with open("../input/cassava-leaf-disease-classification/label_num_to_disease_map.json", encoding="utf-8") as label_file:
    labels = json.load(label_file)

In [None]:
# Images_Per_Class is the (unique class ids, image count per class) pair from np.unique.
Images_Per_Class = np.unique(train['label'].tolist(),return_counts=True)
class_counts = Images_Per_Class[1]
print("Description of Dataset")
print("Number of Classes:",len(train.label.unique()))
print("#OfImages Per Cassava Disease Class")
print("     Min:",min(class_counts))
print("     Max:",max(class_counts))
print("     Mean:",np.mean(class_counts))
print("     Median:",np.median(class_counts))
# stats.mode returns a length-1 array in older SciPy releases and a scalar in
# newer ones; np.atleast_1d makes the [0] lookup work for both.
print("     Mode:",np.atleast_1d(stats.mode(class_counts)[0])[0])

print("\n")
print(train.head())

print("\nCLASS ---->       LABEL")
for class_id, disease in labels.items():
    print(class_id,'    ---->',disease)

# Bar chart of the number of images per class.
fig = plt.figure(figsize = (10, 5))
plt.bar(labels.keys(), class_counts, color ='blue', width = 0.8)
plt.show()

In [None]:
# Draw 25 random rows from the training table and plot them with their class names.
print("Displaying some Images randomly...")
samples = train.sample(25)
sample_ids = samples["image_id"].values
sample_labels = samples["label"].values
display_images(zip(sample_ids, sample_labels))

In [None]:
class RetrievalData(Dataset):
    """Image dataset that reads files from a directory on demand.

    The second element of each item depends on the directory path: when the
    path contains ``"train"`` it is the entry of ``CorrectLabels`` at that
    index; otherwise, when it contains ``"test"``, it is the file name.
    """

    def __init__(self, Directory, FileNames, CorrectLabels,Transform, labels):
        """Store the dataset description.

        Args:
            Directory: folder containing the image files.
            FileNames: indexable collection of file names inside ``Directory``.
            CorrectLabels: indexable collection of targets aligned with
                ``FileNames``; only read when ``Directory`` contains "train".
            Transform: callable applied to each PIL image, or ``None``.
            labels: stored on the instance as ``self.labels``; not used by
                this class.
        """
        self.directory = Directory
        self.filenames = FileNames
        self.transform = Transform
        self.correctlabels = CorrectLabels
        self.labels = labels

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.filenames)

    def __getitem__(self,index):
        """Return ``(image, target)`` for the file at ``index``.

        Raises:
            ValueError: if the directory path contains neither "train" nor
                "test". Previously this case silently returned ``None``.
        """
        # Resolve the target first so a misconfigured directory fails fast,
        # before any file is opened.
        if "train" in self.directory:
            target = self.correctlabels[index]
        elif "test" in self.directory:
            target = self.filenames[index]
        else:
            raise ValueError(
                "Directory %r must contain 'train' or 'test' to select the item format."
                % (self.directory,)
            )

        image_path = os.path.join(self.directory, self.filenames[index])
        # convert() loads the pixel data and returns a new image, so the file
        # handle can be closed straight away; it also guarantees three channels.
        with Image.open(image_path) as handle:
            x = handle.convert("RGB")

        if self.transform is not None:
            x = self.transform(x)
        return x, target

In [None]:
class CheckPoint():
    """Checkpoint writer with optional early stopping.

    ``check`` reads ``self.Model``, ``self.Optimizer`` and ``self.Results``,
    which this class does not set itself; they must be assigned on the
    instance (for example by a subclass) before ``check`` is called.
    """

    def __init__(self,Parameters):
        """Read the checkpoint settings.

        Args:
            Parameters: mapping with the keys ``"Patience"`` (number of
                consecutive non-improving epochs tolerated), ``"SavePath"``
                (output directory) and ``"earlyStopping"`` (bool switch).
        """
        self.Count = 0
        self.BestLoss = float('inf')
        self.BestEpoch = -1

        self.Patience = Parameters["Patience"]
        self.Path = Parameters["SavePath"]
        self.earlyStopping=Parameters["earlyStopping"]
        # Make sure the output directory exists before the first save.
        os.makedirs(self.Path, exist_ok=True)

    def _save(self, filename):
        """Write the model and optimizer state dicts to ``filename`` under the save path."""
        state = {"Model":self.Model.state_dict(),"Optimizer":self.Optimizer.state_dict()}
        torch.save(state, os.path.join(self.Path, filename))

    def check(self,epoch,loss):
        """Save checkpoints for ``epoch`` and report whether training should go on.

        Always writes ``Model.pth`` and ``Results.txt`` (a pickle of
        ``self.Results``); additionally writes ``BestModel.pth`` when ``loss``
        is not worse than the best loss seen so far.

        Returns:
            ``False`` when early stopping is enabled and the loss has failed
            to improve for at least ``Patience`` consecutive calls, else ``True``.
        """
        self._save("Model.pth")
        # `<=` keeps ties counting as an improvement and, unlike `not (loss > best)`,
        # is False for a NaN loss, so NaN can no longer overwrite the best model.
        if loss <= self.BestLoss:
            self.Count=0
            self.BestLoss = loss
            self.BestEpoch = epoch
            self._save("BestModel.pth")
        else:
            self.Count+=1
        with open(os.path.join(self.Path, "Results.txt"), 'wb') as file:
            pickle.dump(self.Results,file)
        # `>=` keeps signalling a stop even if a caller ignored an earlier False.
        if self.earlyStopping and self.Count>=self.Patience:
            print("\nEarly Stopping!")
            print("Model didn't improved for",self.Patience,"epochs.")
            return False
        return True

In [None]:
class Hashtag(CheckPoint):
    """Training/validation driver built on top of ``CheckPoint``.

    Each epoch appends a two-element list to ``self.Results``: the training
    metrics followed by the validation metrics. Both are dicts with the keys
    ``"Loss:"``, ``"Accuracy:"`` and ``"ROC-AUC:"``.
    """

    def __init__(self, Parameters):
        """Store the training components and move model and loss to the device.

        Args:
            Parameters: mapping with the keys ``"CheckPoint"`` (settings passed
                to ``CheckPoint.__init__``), ``"Model"``, ``"Criterion"``,
                ``"Optimizer"``, ``"TrainLoader"``, ``"ValidateLoader"``,
                ``"Labels"``, ``"Device"`` and ``"Results"`` (a list that
                collects the per-epoch metrics).
        """
        super().__init__(Parameters["CheckPoint"])
        self.Model = Parameters["Model"]
        self.Criterion = Parameters["Criterion"]
        self.Optimizer = Parameters["Optimizer"]
        self.TrainLoader = Parameters["TrainLoader"]
        self.ValidateLoader = Parameters["ValidateLoader"]
        self.Labels = Parameters["Labels"]
        self.Device = Parameters["Device"]
        self.Results = Parameters["Results"]
        self.Softmax = nn.Softmax(dim=1)
        self.ontoDevice()

    def ontoDevice(self):
        """Move the model and the criterion to the configured device."""
        # Use the device passed in by the caller rather than a module-level global.
        self.Model.to(self.Device)
        self.Criterion.to(self.Device)

    def _run_epoch(self, loader, training):
        """Run one pass over ``loader`` and return its metrics dict.

        When ``training`` is true the model is put in train mode and the
        optimizer is stepped on every batch; otherwise the model is put in
        eval mode and gradients are disabled.
        """
        self.Model.train(training)
        running_loss = 0.0
        Predictions = []
        TrueLabels = []
        PredictionProb = []
        with torch.set_grad_enabled(training):
            for images, labels in loader:
                # Store plain Python ints, not 0-d tensors, for the sklearn metrics.
                # Assumes the loader yields labels as a tensor -- TODO confirm for custom collate functions.
                TrueLabels.extend(labels.tolist())
                images = images.to(self.Device)
                labels = labels.to(self.Device)
                if training:
                    self.Optimizer.zero_grad()
                pred = self.Model(images)
                loss = self.Criterion(pred, labels)
                running_loss += loss.item()
                if training:
                    loss.backward()
                    self.Optimizer.step()
                # Softmax over dim 1 turns the model outputs into per-class probabilities.
                TempX = self.Softmax(pred).detach().cpu()
                PredictionProb.extend(TempX.numpy().tolist())
                Predictions.extend(torch.argmax(TempX, dim=1).numpy().tolist())
        return {
            # Mean of the per-batch losses.
            "Loss:": running_loss / len(loader),
            "Accuracy:": accuracy_score(TrueLabels, Predictions),
            "ROC-AUC:": roc_auc_score(TrueLabels, PredictionProb, multi_class="ovr"),
        }

    def train(self):
        """Train for one epoch and append the metrics to the current epoch's results."""
        trainResults = self._run_epoch(self.TrainLoader, training=True)
        self.Results[-1].append(trainResults)

    def validate(self):
        """Evaluate on the validation loader and append the metrics to the current epoch's results."""
        validateResults = self._run_epoch(self.ValidateLoader, training=False)
        self.Results[-1].append(validateResults)

    def fit(self,epochs):
        """Train and validate for up to ``epochs`` epochs.

        After each epoch the validation loss is passed to ``check``; training
        stops early as soon as ``check`` returns ``False``.
        """
        for epoch in range(epochs):
            self.Results.append([])
            self.train()
            display("Train",epoch,epochs,self.Results[-1][-1])
            self.validate()
            display("Test ",epoch,epochs,self.Results[-1][-1])

            # Honour the early-stopping signal instead of discarding it.
            if not self.check(epoch,self.Results[-1][-1]["Loss:"]):
                break

In [None]:
if __name__ == '__main__':
    # End-to-end run: build the loaders, the model and the trainer, then fit.

    Start = perf_counter()
    print("Started...")
    # Resize every image to 256x256 and convert it to a tensor.
    Transform = transformations = transforms.Compose([
                                        transforms.Resize((256, 256)),
                                        transforms.ToTensor(),
                                     ])

    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs("./Results", exist_ok=True)

    X,Y = train.image_id.values,train.label.values
    # Stratify so both splits keep the class proportions of this imbalanced
    # dataset; the multi-class ROC-AUC needs every class present in each split.
    X_Train,X_Test,Y_Train,Y_Test = train_test_split(X,Y,test_size=0.2,shuffle=True,stratify=Y)
    TrainSet = RetrievalData(TRAIN_DIR, X_Train,Y_Train, Transform, Images_Per_Class)
    TrainLoader=DataLoader(TrainSet, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,pin_memory=True)

    # The held-out 20% is also read from TRAIN_DIR, so its items carry labels.
    ValidationSet = RetrievalData(TRAIN_DIR, X_Test,Y_Test, Transform, Images_Per_Class)
    ValidationLoader=DataLoader(ValidationSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,pin_memory=True)
    print("Minibatches in TrainLoader:",len(TrainLoader))
    print("Minibatches in ValidationLoader:",len(ValidationLoader))

    # Model, loss and optimizer; one output unit per entry of the label map.
    model = getModel("resnet152",len(labels))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.000001)
    Params = {
        "Model":model,
        "Criterion":criterion,
        "Optimizer":optimizer,
        "TrainLoader":TrainLoader,
        "ValidateLoader":ValidationLoader,
        "Labels":None,
        "Device":DEVICE,
        "Results":[],
        "CheckPoint":{
                    "Patience":2,
                    "SavePath":"./Results",
                    "earlyStopping":False
                    }
             }
    hashtag = Hashtag(Params)
    hashtag.fit(50)
    Finish = perf_counter()
    print("Ended.")
    print("Time Taken:",Finish-Start, "seconds")