In [1]:
import os
import pickle
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from san_18_04_2 import Recognition_Module

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release unoccupied memory held by PyTorch's CUDA caching allocator

# NOTE(review): presumably the class counts for the province / letter / alphanumeric
# character positions of a plate -- none of the cells visible here use them; confirm.
prov_Num, alpha_Num, ad_Num = 38, 25, 35

In [3]:
def save_file(history, path):
    """Serialise ``history`` to ``path`` with pickle.

    The file is opened in binary write mode, so any existing file at ``path``
    is overwritten. The parent directory must already exist.
    """
    with open(path, mode='wb') as out_file:
        pickle.dump(history, out_file)

def load_file(path):
    """Load and return the object pickled at ``path`` (counterpart of save_file).

    Security: unpickling can execute arbitrary code, so only call this on files
    this notebook wrote itself, never on untrusted input.
    """
    with open(path, mode='rb') as in_file:
        return pickle.load(in_file)

In [4]:
class CustomDataset(Dataset):
    """Image/label dataset backed by a CSV index file.

    Each CSV row holds the image path in column ``image_col`` (default 0) and the
    license-plate character class indices in columns ``label_start_col:``
    (default 5 onward). Columns in between are not read by this class.
    """

    def __init__(self, csv_file, transform=None, image_col=0, label_start_col=5):
        """
        Args:
            csv_file: path to the CSV index (read with ``pd.read_csv``, so the
                first row is treated as the header).
            transform: optional callable applied to the RGB PIL image.
            image_col: positional index of the image-path column.
            label_start_col: positional index of the first label column; every
                column from here to the end is treated as a label.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform
        self.image_col = image_col
        self.label_start_col = label_start_col

    def __len__(self):
        """Number of rows (samples) in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; ``label`` is a 1-D long tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_path = self.data_frame.iloc[idx, self.image_col]
        # License plate characters' class indices.
        label = self.data_frame.iloc[idx, self.label_start_col:].tolist()

        # The context manager closes the underlying file deterministically (important
        # with many DataLoader workers); convert() returns a converted copy that
        # stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [5]:
# Candidate augmentations for training. RandomChoice (below) applies exactly one of
# them per sample, and it runs after ToTensor, i.e. on the tensor, not the PIL image.
transform_list = [
    transforms.RandomRotation(30),
    # One ColorJitter per property, so a single property is perturbed at a time.
    *(transforms.ColorJitter(**{prop: 0.2}) for prop in ('brightness', 'contrast', 'saturation', 'hue')),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
    transforms.RandomAdjustSharpness(sharpness_factor=2),
    transforms.RandomInvert(p=0.5),
]


def _preprocess_steps():
    """Fresh deterministic preprocessing steps (resize to 480x480, then to tensor)."""
    return [transforms.Resize((480, 480)), transforms.ToTensor()]


# Validation pipeline: preprocessing only.
transform0 = transforms.Compose(_preprocess_steps())

# Training pipeline: preprocessing followed by one randomly chosen augmentation.
transform1 = transforms.Compose(_preprocess_steps() + [transforms.RandomChoice(transform_list)])

In [6]:
# Each batch is BATCH_SIZE x 3 x 480 x 480 float32 (RGB images, Resize + ToTensor above),
# i.e. ~276 MB at BATCH_SIZE=100. A DataLoader keeps up to num_workers * prefetch_factor
# batches in flight, so the previous 50 workers x prefetch_factor 10 allowed ~500 train
# batches (~138 GB) to sit in shared memory. Use PyTorch's default prefetch factor and
# never spawn more workers than there are CPUs.
BATCH_SIZE = 100
PREFETCH_FACTOR = 2
_cpu_count = os.cpu_count() or 1
TRAIN_WORKERS = min(50, _cpu_count)
VALID_WORKERS = min(10, _cpu_count)

train_dataset = CustomDataset('../datasets/train.csv', transform=transform1)
valid_dataset = CustomDataset('../datasets/validate.csv', transform=transform0)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=TRAIN_WORKERS, persistent_workers=True,
                          prefetch_factor=PREFETCH_FACTOR)
validation_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                               num_workers=VALID_WORKERS, persistent_workers=True,
                               prefetch_factor=PREFETCH_FACTOR)

In [7]:
def validate_model(model, validation_loader, loss_fn, target_device=None):
    """Evaluate ``model`` on ``validation_loader`` without updating weights.

    Args:
        model: module whose forward returns ``(attn, pred_label)``, where
            ``pred_label`` is a sequence of logits tensors, one per character
            position. Assumed shape of each is (batch, num_classes), given the
            ``argmax(dim=1)`` below -- confirm against the model.
        validation_loader: iterable of ``(image, label)`` batches; ``label`` is
            (batch, num_chars) class indices (it is transposed so row i pairs
            with ``pred_label[i]``).
        loss_fn: per-position criterion, e.g. ``nn.CrossEntropyLoss``.
        target_device: device batches are moved to; when None, the
            notebook-level ``device`` is used.

    Returns:
        ``(validation_loss, char_accuracy_pct, plate_accuracy_pct)``.
        ``validation_loss`` is the sum over batches of the per-position losses
        (not a mean). A plate counts as correct only if every character matches.

    Raises:
        ValueError: if the loader yields no samples.
    """
    run_device = device if target_device is None else target_device
    validation_loss = 0.0
    ch_correct = ch_total = li_correct = total = 0
    model.eval()
    with torch.no_grad():
        for image, label in validation_loader:
            image = image.to(run_device)
            label = label.to(run_device)  # (batch, num_chars)

            _, pred_label = model(image)  # attention output is not needed for scoring
            # One loss term per character position, summed; the number of positions
            # comes from the data instead of a hard-coded 7.
            loss = sum(loss_fn(logits, target) for logits, target in zip(pred_label, label.T))
            validation_loss += loss.item()

            pred_license = torch.stack([logits.argmax(dim=1) for logits in pred_label], dim=1)
            ch_equal = label == pred_license
            ch_correct += ch_equal.sum().item()
            ch_total += ch_equal.numel()
            li_correct += ch_equal.all(dim=1).sum().item()
            total += image.shape[0]

    if total == 0:
        raise ValueError("validation_loader yielded no samples")
    return validation_loss, 100 * ch_correct / ch_total, 100 * li_correct / total


In [8]:
def _print_epoch_summary(epoch, train_loss, val_loss, train_ch_acc, train_li_acc, val_ch_acc, val_li_acc):
    """Print one epoch's metrics in the notebook's log format; ``epoch`` is 0-based."""
    print(f"Epoch {epoch+1} \tTraining Loss : {train_loss} \tValidation Loss : {val_loss}")
    print(f"Train Ch Accuracy : {train_ch_acc} \tTrain Li Accuracy : {train_li_acc}")
    print(f"Validation Ch Accuracy : {val_ch_acc} \tValidation Li Accuracy : {val_li_acc}")
    print()


def _train_one_epoch(model, optimizer, loss_fn, train_loader, train_bar):
    """Run one optimisation pass over ``train_loader``.

    Batches are moved to the notebook-level ``device`` and ``train_bar`` is
    advanced by one step per batch.

    Returns:
        ``(summed training loss, character accuracy %, plate accuracy %)``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    train_loss = 0.0
    ch_correct = ch_total = li_correct = total = 0
    for image, label in train_loader:
        image = image.to(device)
        label = label.to(device)  # (batch, num_chars) class indices

        _, pred_label = model(image)  # attention output is not used for training
        # One classification loss per character position, summed.
        loss = sum(loss_fn(logits, target) for logits, target in zip(pred_label, label.T))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_bar.update(1)

        pred_license = torch.stack([logits.argmax(dim=1) for logits in pred_label], dim=1)
        ch_equal = label == pred_license
        ch_correct += ch_equal.sum().item()
        ch_total += ch_equal.numel()
        li_correct += ch_equal.all(dim=1).sum().item()  # plate correct only if every char matches
        total += image.shape[0]

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return train_loss, 100 * ch_correct / ch_total, 100 * li_correct / total


def train_model(model, opt, sch, loss_fn, train_loader, validation_loader, lr, num_epochs,
                checkpoint_dir='./model', opt_kwargs=None, sch_kwargs=None):
    """Train ``model``, validating and checkpointing after every epoch; resumable.

    If ``<checkpoint_dir>/checkpoint.pth`` exists, the model, optimizer and
    scheduler state plus the metric history are restored. Epochs already
    recorded in the history are skipped, and their logged metrics are
    re-printed. On resume the optimizer's saved learning rate takes precedence
    over ``lr``.

    Args:
        model: recognition network; forward returns ``(attn, pred_label)``.
        opt: optimizer class (e.g. ``optim.SGD``), instantiated as
            ``opt(model.parameters(), lr=lr, **opt_kwargs)``.
        sch: LR-scheduler class, instantiated as ``sch(optimizer, **sch_kwargs)``
            and stepped once per epoch.
        loss_fn: per-character criterion.
        train_loader: iterable of ``(image, label)`` training batches.
        validation_loader: iterable of ``(image, label)`` validation batches.
        lr: initial learning rate.
        num_epochs: total number of epochs, including any already completed.
        checkpoint_dir: directory for the checkpoint, ``history.pkl`` and the
            periodic ``prediction_module_<epoch>.pth`` snapshots; created if missing.
        opt_kwargs: extra optimizer kwargs; defaults to ``{'momentum': 0.9}``.
        sch_kwargs: extra scheduler kwargs; defaults to ``{'step_size': 6, 'gamma': 0.1}``.

    Returns:
        dict mapping ``train_loss``, ``val_loss``, ``train_ch_acc``,
        ``train_li_acc``, ``val_ch_acc`` and ``val_li_acc`` to per-epoch lists.
    """
    opt_kwargs = {'momentum': 0.9} if opt_kwargs is None else opt_kwargs
    sch_kwargs = {'step_size': 6, 'gamma': 0.1} if sch_kwargs is None else sch_kwargs

    # Create the output directory up front so the first save does not fail
    # after a full epoch of training.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth')
    history_path = os.path.join(checkpoint_dir, 'history.pkl')

    history = {'train_loss' : [], 'val_loss' : [], 'train_ch_acc' : [], 'train_li_acc' : [], 'val_ch_acc' : [], 'val_li_acc' : []}

    optimizer = opt(model.parameters(), lr=lr, **opt_kwargs)
    scheduler = sch(optimizer, **sch_kwargs)

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # The history is embedded in the checkpoint so the two can never disagree;
        # fall back to the side-car pickle for checkpoints written without it.
        if 'history' in checkpoint:
            history = checkpoint['history']
        elif os.path.exists(history_path):
            history = load_file(history_path)

    with tqdm(total=num_epochs * len(train_loader)) as train_bar:
        for epoch in range(num_epochs):
            if epoch < len(history['train_loss']):
                # Completed in a previous run: advance the bar and replay the log.
                train_bar.set_description("Resuming")
                train_bar.update(len(train_loader))
                _print_epoch_summary(epoch, history['train_loss'][epoch], history['val_loss'][epoch],
                                     history['train_ch_acc'][epoch], history['train_li_acc'][epoch],
                                     history['val_ch_acc'][epoch], history['val_li_acc'][epoch])
                continue

            train_bar.set_description(f"Training epoch {epoch+1}")
            train_loss, train_ch_acc, train_li_acc = _train_one_epoch(
                model, optimizer, loss_fn, train_loader, train_bar)

            train_bar.set_description(f"Validating epoch {epoch+1}")
            val_loss, val_ch_acc, val_li_acc = validate_model(model, validation_loader, loss_fn)

            _print_epoch_summary(epoch, train_loss, val_loss, train_ch_acc, train_li_acc, val_ch_acc, val_li_acc)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['val_ch_acc'].append(val_ch_acc)
            history['val_li_acc'].append(val_li_acc)
            history['train_ch_acc'].append(train_ch_acc)
            history['train_li_acc'].append(train_li_acc)

            scheduler.step()

            # Write to a temporary file and rename it, so an interrupted save never
            # leaves a truncated checkpoint behind.
            tmp_path = checkpoint_path + '.tmp'
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'history': history,
            }, tmp_path)
            os.replace(tmp_path, checkpoint_path)
            save_file(history, history_path)

            # Periodic weight snapshots at epochs 10, 15, 20, ...
            if epoch % 5 == 4 and epoch >= 9:
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"prediction_module_{epoch+1}.pth"))

    return history

In [9]:
# Build the recognition network and move it to the selected device
# (nn.Module.to moves the parameters in place, so its return value is not needed).
model = Recognition_Module()
model.to(device)

# Report total vs. trainable parameter counts; `num_parameters` ends up holding the trainable count.
params = list(model.parameters())
num_parameters = sum(param.numel() for param in params)
print("Number of parameters =", num_parameters)

num_parameters = sum(param.numel() for param in params if param.requires_grad)
print("Number of trainable parameters =", num_parameters)

Number of parameters = 16326064
Number of trainable parameters = 16326064


In [10]:
# Components handed to train_model: the optimizer and LR scheduler are passed as
# classes (train_model instantiates them); the loss is passed as an instance.
sch = optim.lr_scheduler.StepLR
opt = optim.SGD
loss_fn = nn.CrossEntropyLoss().to(device)

In [11]:
# Train (or resume) for 50 epochs, starting from a learning rate of 0.01.
history = train_model(
    model, opt, sch, loss_fn,
    train_loader, validation_loader,
    lr=0.01,
    num_epochs=50,
)

  0%|          | 0/60000 [00:00<?, ?it/s]

Epoch 1 	Training Loss : 15913.550233840942 	Validation Loss : 1104.1106877326965
Train Ch Accuracy : 43.84359608990225 	Train Li Accuracy : 0.7175179379484488
Validation Ch Accuracy : 74.20571428571428 	Validation Li Accuracy : 6.55

Epoch 2 	Training Loss : 3992.6096519231796 	Validation Loss : 393.530552983284
Train Ch Accuracy : 84.29972654078257 	Train Li Accuracy : 27.964865788311375
Validation Ch Accuracy : 90.72714285714285 	Validation Li Accuracy : 49.965

Epoch 3 	Training Loss : 1895.138988852501 	Validation Loss : 204.0238807797432
Train Ch Accuracy : 92.93315666224989 	Train Li Accuracy : 62.6465661641541
Validation Ch Accuracy : 95.75357142857143 	Validation Li Accuracy : 76.955

Epoch 4 	Training Loss : 1203.9721120893955 	Validation Loss : 152.7436990737915
Train Ch Accuracy : 95.70751173541244 	Train Li Accuracy : 77.2485978816137
Validation Ch Accuracy : 96.87214285714286 	Validation Li Accuracy : 83.29

Epoch 5 	Training Loss : 816.1090349704027 	Validation Loss : 11