In [1]:
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import Accuracy, Precision, Recall, F1Score
import os
import json

In [None]:
# Load the run configuration (model output folder, hyper-parameters) from
# config.json in the current working directory. JSON text is UTF-8 by spec,
# so decode it explicitly instead of relying on the platform locale encoding.
with open("config.json", "r", encoding="utf-8") as config_fh:
    config_file = json.load(config_fh)

print(config_file["model_folder_path"])

In [3]:
# Single preprocessing step: convert each dataset image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# The train and validation splits share every argument except `train`.
# NOTE(review): root is an absolute path ("/data"); consider moving it into
# config.json so the notebook runs on machines without a writable /data.
_emnist_kwargs = dict(root="/data", split="byclass", download=True, transform=transform)

train_set = torchvision.datasets.EMNIST(train=True, **_emnist_kwargs)
val_set = torchvision.datasets.EMNIST(train=False, **_emnist_kwargs)

In [None]:
# Show one training and one validation sample side by side as a sanity check.
train_image, train_label = train_set[800]
test_image, test_label = val_set[20]

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (axes[0], "Train", train_image, train_label, train_set.classes),
    (axes[1], "Test", test_image, test_label, val_set.classes),
)
for ax, split_name, image, label, class_names in panels:
    # torchvision's EMNIST images are stored transposed relative to the
    # upright glyph, so transpose the 28x28 array for display only. Training
    # is unaffected because every split shares the same orientation.
    ax.imshow(image.squeeze().T, cmap='gray')
    # Show the integer target together with the character it stands for.
    ax.set_title(f'{split_name} Label: {label} ({class_names[label]})')
    ax.axis('off')

fig.tight_layout()
plt.show()

In [5]:
# Both loaders take their batch size from config.json.
batch_size = config_file["hyper_parameters"]["batch_size"]

# Shuffle training batches every epoch.
train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# The validation set does not need shuffling; a fixed order keeps evaluation
# reproducible from run to run.
val_data_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
# Debug aid: print image and label tensor shapes for the first 41 batches of
# the training loader, then the validation loader.
for loader in (train_data_loader, val_data_loader):
    counter = 0
    for batch in loader:
        image, label = batch
        counter += 1
        print(image.shape)
        print(label.shape)
        print("\n")
        # Stop once 41 batches have been printed (counter runs 1..41).
        if counter > 40:
            break

In [None]:
# Debug aid: peek at the first validation batch only. An empty loader prints
# nothing.
first_batch = next(enumerate(val_data_loader), None)
if first_batch is not None:
    batch_idx, (images, labels) = first_batch
    print(f"Batch {batch_idx + 1}")
    print(f"Images shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

In [8]:
class EMNISTClassifier(nn.Module):
    """Small CNN classifier for 28x28 EMNIST images.

    Layers: conv(in->32) -> ReLU -> maxpool -> conv(32->64) -> ReLU -> maxpool
    -> conv(64->128) -> ReLU -> flatten -> fc(128*7*7 -> 256) -> ReLU
    -> dropout(0.5) -> fc(256 -> num_classes).

    forward() returns raw, unnormalized logits, as nn.CrossEntropyLoss expects.

    Args:
        in_channels: number of input image channels (1 for grayscale EMNIST).
        out_channels: accepted for signature compatibility; currently unused
            (the conv widths are fixed at 32/64/128).
        kernel_size: accepted for signature compatibility; currently unused
            (all convs use 3x3 kernels with padding 1).
        num_classes: number of output logits (62 for the EMNIST "byclass" split).
    """

    def __init__(self, in_channels=1, out_channels=64, kernel_size=3, num_classes=62):
        super().__init__()

        # 3x3 kernels with padding=1 keep the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # Each 2x2 / stride-2 pool halves height and width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

        # Two pools take 28x28 -> 14x14 -> 7x7; conv3 keeps 7x7.
        # Assumes 28x28 inputs.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)

        self.dropout = nn.Dropout(0.5)

        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, 28, 28) to logits (N, num_classes)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        # flatten (unlike view) also works on non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        # No activation on the output layer. CrossEntropyLoss applies
        # log-softmax itself, and a ReLU here would clamp every negative
        # logit to 0, hiding how strongly the model rejects a class.
        return self.fc2(x)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module.to moves the parameters and returns the same module object.
model = EMNISTClassifier()
model = model.to(device)

# Layer-by-layer summary for a single-channel 28x28 input.
summary(model, input_size=(1, 28, 28))

In [10]:
# Multiclass metrics over the 62 "byclass" labels, on the same device as the
# model outputs. Precision, recall and F1 are macro-averaged; accuracy uses
# the library's default averaging.
_metric_kwargs = dict(task="multiclass", num_classes=62)
accuracy = Accuracy(**_metric_kwargs).to(device)
precision = Precision(average='macro', **_metric_kwargs).to(device)
recall = Recall(average='macro', **_metric_kwargs).to(device)
f1 = F1Score(average='macro', **_metric_kwargs).to(device)

# TensorBoard event writer used by the training cell.
writer = SummaryWriter()


In [None]:
loss_fn = nn.CrossEntropyLoss()
# NOTE: re-running this cell builds a fresh optimizer but keeps training the
# existing `model` weights; re-run the model cell first for a clean start.
optimizer = optim.Adam(model.parameters(), lr=config_file["hyper_parameters"]["lr"])
num_epochs = config_file["hyper_parameters"]["num_epochs"]

# TensorBoard tag prefix -> metric object. The train and validation passes
# share these objects, so each pass resets them before use.
epoch_metrics = {
    "Accuracy": accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1_Score": f1,
}


def _run_epoch(data_loader, desc, train):
    """Run one full pass over `data_loader` and return epoch-level statistics.

    If `train` is True, the model is put in training mode and updated with
    `optimizer` after every batch. Otherwise it runs in eval mode with
    gradients disabled.

    Metrics are accumulated with `update()` and read once with `compute()`.
    Macro precision/recall/F1 therefore cover the whole epoch, instead of
    being an average of per-batch macro scores (which over-weights the short
    final batch and classes that happen to be rare within a batch).

    Returns:
        dict mapping "Loss" and each key of `epoch_metrics` to a Python float.
        "Loss" is the per-sample mean of `loss_fn` over the pass.
    """
    model.train(train)
    for metric in epoch_metrics.values():
        metric.reset()

    loss_sum, sample_count = 0.0, 0
    progress_bar = tqdm(data_loader, desc=desc, leave=True)
    with torch.set_grad_enabled(train):
        for images, labels in progress_bar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            loss = loss_fn(logits, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            preds = torch.argmax(logits, dim=1)
            for metric in epoch_metrics.values():
                metric.update(preds, labels)

            # loss_fn averages over the batch (default reduction='mean').
            # Re-weight by batch size so the epoch mean stays per-sample
            # even when the last batch is short.
            batch_loss = loss.item()
            loss_sum += batch_loss * labels.size(0)
            sample_count += labels.size(0)
            progress_bar.set_postfix({'Batch Loss': batch_loss})

    stats = {name: metric.compute().item() for name, metric in epoch_metrics.items()}
    stats["Loss"] = loss_sum / max(sample_count, 1)
    return stats


for epoch in range(num_epochs):
    train_stats = _run_epoch(train_data_loader, f'Epoch [{epoch+1}/{num_epochs}]', train=True)
    val_stats = _run_epoch(val_data_loader, f'Validation [{epoch+1}/{num_epochs}]', train=False)

    # Tags follow the "<Metric>/train" and "<Metric>/val" layout.
    for name in ("Loss", *epoch_metrics):
        writer.add_scalar(f'{name}/train', train_stats[name], epoch)
        writer.add_scalar(f'{name}/val', val_stats[name], epoch)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_stats["Loss"]:.4f}, Val Loss: {val_stats["Loss"]:.4f}')

writer.close()

In [12]:
model_folder = config_file["model_folder_path"]
# makedirs also creates missing parent directories (os.mkdir fails on nested
# paths such as "out/models"). exist_ok avoids a check-then-create race.
os.makedirs(model_folder, exist_ok=True)

In [13]:
# Persist the trained model inside the configured folder.
# NOTE(review): this pickles the whole nn.Module, so loading it requires the
# EMNISTClassifier class to be importable. Saving model.state_dict() is more
# portable, but it changes the file format that downstream loaders expect.
model_path = f"{model_folder}/EMNIST.pt"
torch.save(model, model_path)