In [1]:
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"
print("Device: ", device)

Device:  cuda


## Making Classification Data using SKLEARN


In [3]:
x, y = load_digits(return_X_y=True)

In [4]:
train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=0.8)

In [5]:
class ClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, x, y ):
        super().__init__()
        self.x = x
        self.y = y
    
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, idx):
        return {
            "x": torch.tensor(self.x[idx], dtype=torch.float32, device=device),
            "y": torch.tensor(self.y[idx], dtype=torch.long, device=device)
        }
    

In [6]:
train_dataset = ClassificationDataset(train_x, train_y)
print("Train Dataset Size: ", len(train_dataset))
test_dataset = ClassificationDataset(test_x, test_y)
print("Test Dataset Size: ", len(test_dataset))


Train Dataset Size:  1437
Test Dataset Size:  360


In [10]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = 32, shuffle=True)

## Creating Model using Module Container

In [10]:
class ClassificationModel(nn.Module):
    
    def __init__(self, inp_dim, out_dim):
        super(ClassificationModel, self).__init__()
        self.fc1 = nn.Linear(inp_dim, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 2048)
        self.out = nn.Linear(2048, out_dim)
        
    
    def forward(self, x):
        
        x = self.fc1(x)
        x = nn.functional.relu(x) # Activation
        x = self.fc2(x) 
        x = nn.functional.relu(x)
        x = self.fc3(x)
        x = nn.functional.relu(x)
        x = self.out(x)
        return nn.functional.softmax(x)
        

In [11]:
model = ClassificationModel(train_x.shape[1], 10)
model = model.to(device)

In [12]:
class Trainer(object):
    def __init__(self, optimizer, criteria, epochs=10, scheduler=None):
        """
        This class will train the model based on the 
        - optimizer: Optimizer algorithm (object)
        - criteria: It is the loss function which will be used like CrossEntropyLoss, MSELoss etc.
        """
        self.optimizer = optimizer
        self.epochs = epochs
        self.criteria = criteria
        self.scheduler = scheduler
        
    
    def train_one_step(self, x, y):
        """
        Training on Single Step
        - Predict the output
        - Optimize the parameters
        """
        self.optimizer.zero_grad() # Initialization of Gredients to 0
        y_hat = self.model(x)
        loss = self.criteria(y_hat, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    def train_one_epoch(self, data_loader):
        
        """
        This function will enable the epoch training and return the loss for the epoch
        """
        self.model.train() # Setting model in Training Mode
        total_loss = 0
        for idx, data in tqdm(enumerate(data_loader), total=len(data_loader)):
            loss = self.train_one_step(**data)
            total_loss += loss
        return total_loss / (idx + 1)
    
    def eval_one_epoch(self, data_loader):
        """
        This function will enable the epoch training and return the loss for the epoch
        """
        self.model.eval() # Setting model in Evaluation Mode
        total_loss = 0
        for idx, data in enumerate(data_loader):
            x, y = data["x"], data["y"]
            y_hat = self.model(x)
            loss = self.criteria(y_hat, y)
            total_loss += loss.item()
        return total_loss / (idx + 1)
    
    def fit(self, model, train_loader, valid_loader=None, scheduler=None, **kwargs):
        """
        This function will start the model training. 
        """
        self.model = model
        valid_loss = None
        for epoch in range(self.epochs):
            loss = self.train_one_epoch(train_loader)
            if valid_loader:
                valid_loss = self.eval_one_epoch(valid_loader)
            if hasattr(self, "sechduler") and self.sechduler != None:
                self.scheduler.step()
            tqdm.write(f"Epoch: {epoch}, Training Loss: {loss}, Validation Loss: {valid_loss}")
        return self.model
    

In [15]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=1e-5)
criteria = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
model_trainer = Trainer(optimizer = optimizer, criteria = criteria, scheduler=scheduler, epochs=100)

In [16]:
model_trainer.fit(model = model, train_loader=train_loader, valid_loader=test_loader)

  0%|          | 0/45 [00:00<?, ?it/s]

  return nn.functional.softmax(x)


Epoch: 0, Training Loss: 1.585699462890625, Validation Loss: 1.5118924677371979


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 1, Training Loss: 1.5723485363854302, Validation Loss: 1.60806671778361


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 2, Training Loss: 1.5482478300730387, Validation Loss: 1.6596690714359283


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 3, Training Loss: 1.5630925337473551, Validation Loss: 1.6317171156406403


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 4, Training Loss: 1.5659170389175414, Validation Loss: 1.5739780167738597


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 5, Training Loss: 1.5532565673192342, Validation Loss: 1.649780531724294


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 6, Training Loss: 1.5544614129596286, Validation Loss: 1.7631211578845978


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 7, Training Loss: 1.657957665125529, Validation Loss: 1.5388804574807484


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 8, Training Loss: 1.612093628777398, Validation Loss: 1.6055404841899872


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 9, Training Loss: 1.5876244359546237, Validation Loss: 1.6961935758590698


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 10, Training Loss: 1.6241831408606635, Validation Loss: 1.6606604258219402


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 11, Training Loss: 1.6281365129682752, Validation Loss: 1.6845606068770091


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 12, Training Loss: 1.6792842441134983, Validation Loss: 1.791037658850352


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 13, Training Loss: 1.7798898882336087, Validation Loss: 1.7421776453653972


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 14, Training Loss: 1.8187986585828992, Validation Loss: 1.9381452997525532


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 15, Training Loss: 1.9446799596150717, Validation Loss: 2.321473479270935


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 16, Training Loss: 2.2988716072506374, Validation Loss: 2.0805452366669974


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 17, Training Loss: 2.117645483546787, Validation Loss: 2.0854092140992484


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 18, Training Loss: 1.921753176053365, Validation Loss: 1.7215529481569927


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 19, Training Loss: 1.9689141935772365, Validation Loss: 2.2410224278767905


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 20, Training Loss: 2.233786922030979, Validation Loss: 2.251979927221934


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 21, Training Loss: 2.2692050509982638, Validation Loss: 2.2762538393338523


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 22, Training Loss: 2.231401475270589, Validation Loss: 2.195509115854899


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 23, Training Loss: 2.2030101087358265, Validation Loss: 2.2189628879229226


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 24, Training Loss: 2.0823801358540854, Validation Loss: 2.075828939676285


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 25, Training Loss: 2.1096179379357234, Validation Loss: 2.2606299916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 26, Training Loss: 2.250983142852783, Validation Loss: 2.2605371872584024


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 27, Training Loss: 2.273717260360718, Validation Loss: 2.3075049916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 28, Training Loss: 2.341562763849894, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 29, Training Loss: 2.3414909203847247, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 30, Training Loss: 2.341634602016873, Validation Loss: 2.3075049916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 31, Training Loss: 2.3414190822177465, Validation Loss: 2.2840674916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 32, Training Loss: 2.3414190822177465, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 33, Training Loss: 2.3414190822177465, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 34, Training Loss: 2.3414909203847247, Validation Loss: 2.2996924916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 35, Training Loss: 2.3414190822177465, Validation Loss: 2.2996924916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 36, Training Loss: 2.341634602016873, Validation Loss: 2.2762549916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 37, Training Loss: 2.3417064401838514, Validation Loss: 2.2996924916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 38, Training Loss: 2.3414909203847247, Validation Loss: 2.2996924916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 39, Training Loss: 2.3412754058837892, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 40, Training Loss: 2.3412754058837892, Validation Loss: 2.2918799916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 41, Training Loss: 2.3414190822177465, Validation Loss: 2.2996924916903176


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 42, Training Loss: 2.341347244050768, Validation Loss: 2.2762550016244254


  0%|          | 0/45 [00:00<?, ?it/s]

Epoch: 43, Training Loss: 2.3414909203847247, Validation Loss: 2.3075049916903176


KeyboardInterrupt: 