In [1]:
import warnings

warnings.filterwarnings('ignore')

import numpy as np
import cv2
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset

import os
from torchvision.datasets import MNIST
from torchvision import transforms as tfs

from datetime import datetime
print(torch.cuda.is_available())

In [2]:
# Select the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. The model and every batch are moved to this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [3]:
# Load the labelled training table and the unlabelled test table.
# Paths are built with os.path.join so the notebook also runs on Linux/macOS;
# the original 'data\\train.csv' literal only resolves on Windows.
Train = pd.read_csv(os.path.join('data', 'train.csv'))
Test = pd.read_csv(os.path.join('data', 'test.csv'))

In [4]:
# Features: every column except 'label', as a float array scaled by 1/255.
X = Train.drop(columns='label').to_numpy(dtype='float') / 255.0

# Targets: the 'label' column, kept as a pandas Series.
y = Train['label']

In [5]:
# Hold out 20% of the labelled rows for validation; the fixed random_state
# makes the split reproducible.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.20,
    random_state=42,
)

# Image tensors: float32, reshaped to (N, 1, 28, 28) for the network.
X_train_t = torch.tensor(X_train, dtype=torch.float32).view(-1, 1, 28, 28)
X_val_t = torch.tensor(X_val, dtype=torch.float32).view(-1, 1, 28, 28)

# Label tensors: int64 class indices taken from the pandas Series.
y_train_t = torch.tensor(y_train.to_numpy(), dtype=torch.long)
y_val_t = torch.tensor(y_val.to_numpy(), dtype=torch.long)

In [6]:
class MyDataset(Dataset):
    """Map-style dataset pairing samples with labels, with an optional transform.

    Args:
        data: Indexable collection of samples; its length is the dataset size.
        labels: Indexable collection of targets, read with the same index as
            ``data``.
        transform: Optional callable applied to the sample (never the label)
            each time an item is fetched. Skipped when falsy, e.g. ``None``.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``.

        The sample is passed through ``transform`` when one was supplied.
        """
        sample = self.data[index]
        target = self.labels[index]

        if not self.transform:
            return sample, target

        return self.transform(sample), target

In [7]:
# Augmentation pipeline applied per sample by MyDataset when a transform is
# passed in. Every stage is random, so the same index yields a different
# image on each access.
transform = tfs.Compose([
    # NOTE(review): mirroring with probability 0.8 — for digit images a
    # horizontal flip usually produces a shape that is not a valid instance of
    # the same class; confirm this is intended.
    tfs.RandomHorizontalFlip(p=0.8),
    # Rotation by a random angle within +/-20 degrees.
    tfs.RandomRotation(20),
    # No extra rotation (degrees=0); random shear up to 10 degrees and random
    # scaling between 0.8x and 1.2x.
    tfs.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    # p=1.0 means the wrapped list is always applied: a perspective warp that
    # itself fires with probability 0.5, followed by a 3x3 Gaussian blur with
    # sigma drawn from [0.1, 2].
    tfs.RandomApply([
        tfs.RandomPerspective(distortion_scale=0.5, p=0.5),
        tfs.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2))
    ], p=1.0) 
])

In [8]:
# Two views of each split: the raw tensors (_1) and a randomly augmented
# copy (_2) that runs every sample through `transform` on access.
train_dataset_1 = MyDataset(X_train_t, y_train_t)
train_dataset_2 = MyDataset(X_train_t, y_train_t, transform=transform)
val_dataset_1 = MyDataset(X_val_t, y_val_t)
val_dataset_2 = MyDataset(X_val_t, y_val_t, transform=transform)

# Concatenate raw + augmented views, doubling the size of each split.
# NOTE(review): the validation split also contains randomly augmented
# samples, so validation metrics vary between passes — confirm this is
# intended.
train_dataset = torch.utils.data.ConcatDataset([train_dataset_1, train_dataset_2])
val_dataset = torch.utils.data.ConcatDataset([val_dataset_1, val_dataset_2])

# Batches of 1024; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1024)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=1024)

In [9]:
from LeNet import LeNet

In [10]:
# Network, moved to the selected device before the optimiser is built.
model = LeNet()
model = model.to(device)

# Cross-entropy over the class logits; Adam with its default hyper-parameters.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimiser = torch.optim.Adam(model.parameters())

# Phase name -> loader. Insertion order matters: the training loop iterates
# this dict and expects "train" to run before "valid".
loaders = dict(train=train_dataloader, valid=val_dataloader)

In [11]:
from tqdm.notebook import tqdm

In [13]:
# Training / validation loop with early stopping.
#
# For every epoch each loader in `loaders` is run once ("train" first, then
# "valid"). Validation loss drives early stopping and best-model selection.
#
# Outputs left in the notebook namespace:
#   accuracy             - per-epoch accuracy for each phase
#   train_losses / valid_losses - per-batch loss history over the whole run
#   epoch_erly_stopping  - epoch index at which the best validation loss occurred
#   best_model           - the model, restored to its best-validation weights
# Files written: "checkpoint.pt" (whole model, on every improvement) and
# "best_mod.pth" (state_dict of the best weights, once at the end).
max_epochs = 25
patience = 20  # non-improving validation epochs tolerated before stopping
accuracy = {'train': [], 'valid': []}
train_losses = []
valid_losses = []
epoch_erly_stopping = 0
flag = False
col_not_best = 0
last_loss = np.inf  # np.Inf alias was removed in NumPy 2.0
best_model = model
# `best_model` is the same object as `model`, so the best weights must be
# snapshotted separately; otherwise later epochs overwrite them.
best_state = None
start_time = datetime.now()

for epoch in tqdm(range(max_epochs)):

    for k, dataloader in loaders.items():
        is_train = k == 'train'
        # Switch train/eval mode once per phase rather than once per batch.
        model.train(is_train)

        epochs_correct = 0
        epochs_all = 0
        epoch_losses = []  # losses of this phase in this epoch only

        for x_batch, y_batch in tqdm(dataloader):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            if is_train:
                optimiser.zero_grad()
                outp = model(x_batch)
                loss = criterion(outp, y_batch)
                loss.backward()
                optimiser.step()
            else:
                with torch.no_grad():
                    outp = model(x_batch)
                    loss = criterion(outp, y_batch)

            pred_class = outp.argmax(dim=1)
            epochs_correct += (pred_class == y_batch).sum().item()
            epochs_all += len(x_batch)
            epoch_losses.append(loss.item())

        if is_train:
            train_losses.extend(epoch_losses)
            train_loss = np.average(epoch_losses)
        else:
            valid_losses.extend(epoch_losses)
            # Mean over this epoch only; averaging the whole history would
            # blur the early-stopping signal with losses from earlier epochs.
            valid_loss = np.average(epoch_losses)
            print(f"[{epoch:>3}/{max_epochs:>3}] loss_train: {train_loss:.5f} | loss_valid: {valid_loss:.5f}")

            if last_loss > valid_loss:
                col_not_best = 0
                epoch_erly_stopping = epoch
                last_loss = valid_loss
                # state_dict() returns references to the live parameters, so
                # clone every tensor to freeze the best weights.
                best_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                torch.save(best_model, "checkpoint.pt")
            elif col_not_best >= patience:
                print("Stop")
                accuracy[k].append(epochs_correct / epochs_all)
                flag = True
                break
            else:
                col_not_best += 1

        print(f"Loader: {k}. Accuracy: {epochs_correct/epochs_all}")
        accuracy[k].append(epochs_correct / epochs_all)

    if flag:
        break

# Restore the best-validation weights so that `best_model` and the saved file
# hold the best epoch rather than the last one.
if best_state is not None:
    best_model.load_state_dict(best_state)
torch.save(best_model.state_dict(), 'best_mod.pth')
print(f'Program execution time: {datetime.now() - start_time}')