### Importing Necessary Base Libraries

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from datetime import datetime

### Importing the Project (User-Defined) Modules

In [2]:
import Data
import Logger
import Trainer
import Model

### Training Environment Variables (users may change these; please read the comment next to each one)

#### General device and storage settings

In [3]:
device = "cpu"  # Device onto which the user wants the model and data loaded

# Directory where the model & optimizer parameters are saved during training.
# Relative to the notebook (like the log paths below) rather than a user-specific absolute path.
PATH = "./checkpoints/"

# One timestamp shared by both logs so the pair from a single run can be matched up.
# Formatted without spaces or colons so the file names are valid on every OS.
_run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
Complete_log = "./logs/Complete" + _run_stamp + ".log"  # complete (per-batch) log of training
Summary_log = "./logs/Summary" + _run_stamp + ".log"    # per-epoch summary log

# Seed the global RNGs before the model is built so weight initialisation
# (and presumably any random shuffling in the data loaders) is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

#### DataLoaders and data transforms

In [4]:
# Height/width every image is resized to before being converted to a tensor.
IMAGE_SIZE = (128, 128)

# Transforms applied to each image: resize, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Loaders built by the project's Data module; main() indexes the result by 'train' / 'validate'.
data_loader = Data.getData(transform=transform)
batch_size = Data.getBatchSize()  # edit the batch size in Data.py

#### Training-time variables and hyperparameters

In [5]:
# Epoch range: training covers epochs start .. num_epochs - 1.
# Leave `start` at 0; when resuming, Logger.continue_checkpoint() supplies it.
start = 0
num_epochs = 5  # how many epochs to run

learning_rate = 1e-4  # SGD step size

model = Model.Model()

# Loss and optimizer -- swap these to experiment.
# NLLLoss expects log-probabilities, so the model presumably ends in log_softmax -- confirm in Model.py.
Criterion = nn.NLLLoss()
Optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# To resume from a checkpoint, uncomment the next line (see the Logger.continue_checkpoint docstring):
#model, Optimizer, start, loss = Logger.continue_checkpoint()

#### Training-state variables (do not edit)

In [6]:
# Per-epoch mean losses for each phase, filled in by main().
loss_history = dict(train=list(), validate=list())
# Validation accuracy (percent) of each epoch.
validation_accuracy = list()
# The three highest validation accuracies seen so far, highest first.
best_accuracy = [0] * 3

### The main Training Loop

In [7]:
def main():
    """Train and validate `model` for epochs `start` .. `num_epochs - 1`.

    Uses names defined in earlier cells: `model`, `Optimizer`, `Criterion`,
    `data_loader` (indexed by 'train' / 'validate'), `device`, `PATH`,
    `Complete_log` and `Summary_log`.

    Side effects:
      * creates (or truncates) both log files, creating their directory if needed;
      * appends per-batch progress to `Complete_log` and a per-epoch summary
        to `Summary_log`;
      * appends the epoch's mean train / validation loss to `loss_history`
        and the validation accuracy (percent) to `validation_accuracy`;
      * writes a checkpoint via `Logger.create_checkpoint` whenever the
        epoch's validation accuracy ties or beats every accuracy seen so far;
      * keeps `best_accuracy` as the three highest validation accuracies seen.
    """
    global best_accuracy  # the only module-level name this function rebinds

    import os  # local import keeps this cell self-contained

    # Create (or truncate) both log files up front, making sure their directory exists.
    for log_path in (Complete_log, Summary_log):
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_path, "w+"):
            pass

    for epoch in range(start, num_epochs):

        loss_per_epoch = {"train": [], "validate": []}
        # Running sum of the per-batch `acc` returned by Trainer.validate.
        # Presumably the count of correct predictions, since it is divided by
        # the dataset length below -- confirm in Trainer.py.
        acc_sum = 0

        for phase in ["train", "validate"]:

            print("------------------------------IN " + phase.capitalize() + "------------------------------")
            phase_size = Data.getLength(phase)  # constant for the whole phase
            seen = 0

            for data in data_loader[phase]:

                if phase == "train":
                    loss = Trainer.train(model, Optimizer, Criterion, data)
                else:
                    loss, acc = Trainer.validate(model, Criterion, data)
                    acc_sum += acc
                loss_per_epoch[phase].append(loss)

                # Assumes each batch is (inputs, targets, ...) so data[0] holds the inputs -- confirm in Data.py.
                seen += len(data[0])
                # Stays 0 during the train phase: only validation batches add to acc_sum.
                running_accuracy = (acc_sum / phase_size) * 100
                Logger.print_train_progress(epoch, seen, phase_size, loss, phase, running_accuracy)
                Logger.write_log(Complete_log, epoch, seen, phase_size, loss, phase, running_accuracy)

        accuracy = (acc_sum / Data.getLength("validate")) * 100
        train_loss = np.mean(loss_per_epoch["train"])
        valid_loss = np.mean(loss_per_epoch["validate"])
        validation_accuracy.append(accuracy)
        loss_history["train"].append(train_loss)
        loss_history["validate"].append(valid_loss)

        # Checkpoint only when this epoch ties or beats every accuracy recorded so far.
        if all(accuracy >= best for best in best_accuracy):
            print("accuracy : ", accuracy, "\t saving the model")
            # Store the epoch's mean validation loss rather than the last batch's loss.
            Logger.create_checkpoint(accuracy=accuracy, epoch=epoch, loss=valid_loss,
                                     optimizer_state=Optimizer.state_dict(),
                                     model_state=model.state_dict(),
                                     device=device, chk_path=PATH)

        # Maintain the true top-3, whether or not a checkpoint was written this epoch.
        best_accuracy = sorted(best_accuracy + [accuracy], reverse=True)[:3]

        summary = "EPOCH : {}\nTRAIN LOSS : {:.3f}\tVALID LOSS : {:.3f}\tACCURACY : {:.4f}".format(
            epoch, train_loss, valid_loss, accuracy)
        print("==================================== EPOCH SUMMARY ====================================")
        print(summary)
        print("=======================================================================================")
        with open(Summary_log, "a") as log:
            log.write(summary + "\n")

In [None]:
if __name__ == "__main__":
    # Entry point: run every configured epoch, then report the best validation accuracies.
    main()
    print("top accuracies : ", best_accuracy)