In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim

from memory_profiler import profile

print('Good imports')

Good imports


In [3]:
df = pd.read_csv('data.csv', index_col=False)


labels_unordered = np.array(df[['784']])
# Need labels used to be numbered correctly to call to_categorical()
labels = np.copy(labels_unordered)
labels = np.where(labels == 49, 10, labels)
labels = np.where(labels == 50, 11, labels)
labels = np.where(labels == 51, 12, labels)
labels = np.where(labels == 52, 13, labels)
labels = np.where(labels == 82, 14, labels)
labels = np.where(labels == 33, 15, labels)
labels = np.where(labels == 34, 16, labels)


# Drop label col from data
df.drop(columns=['784'], inplace=True)

In [3]:
# transform data to 28x28
data2D = []
tmp_data = df.values
for img in tqdm(tmp_data):
    data2D.append(img.reshape(1,28,28))
data2D = np.array(data2D)

# split data into train & validation sets
train_data, val_data, train_label, val_label = train_test_split(data2D, labels, test_size=0.1, random_state=42)

train_data = torch.from_numpy(train_data)
train_label = torch.from_numpy(train_label)
val_data = torch.from_numpy(val_data)
val_label = torch.from_numpy(val_label)

print(f"train data: {train_data.shape}, train label: {train_label.shape} val data: {val_data.shape} val label {val_label.shape}")

100%|██████████| 240520/240520 [00:00<00:00, 710389.87it/s]


train data: torch.Size([216468, 1, 28, 28]), train label: torch.Size([216468, 1]) val data: torch.Size([24052, 1, 28, 28]) val label torch.Size([24052, 1])


In [5]:
# input img shape = (num_imgs, 1, 28, 28)
# after first Conv2d = (num_imgs, 16, 28, 28)
# after first MaxPool2d layer = (num_imgs, 16, 14, 14)
# after second Conv2d = (num_imgs, 32, 14, 14)
# after second MaxPool2d layer = (num_imgs, 32, 7, 7)
# after flatten = (num_imgs, 32*7*7)
# i.e. input of first dense layer should be 32*7*7

class basicModel(nn.Module):
    def __init__(self, num_classes):
        super(basicModel, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 16, 5, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 5, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.2),
            nn.Flatten()
        )

        self.lin_layers = nn.Sequential(
            nn.Linear(32*7*7, 128),
            nn.Linear(128, 64),
            nn.Linear(64, num_classes),
        )

        # self.softmax = nn.LogSoftmax(num_classes)
    @profile
    def forward(self, x):
        x = self.conv_layers(x)
        x = self.lin_layers(x)
        # x = self.softmax(x)

        return x

num_classes = len(np.unique(labels))
model = basicModel(num_classes)
loss_funk = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.01)


# To Fix

train model in batches to avoid running out of memory!!!

In [6]:
num_epoch
train_losses = []
val_losses = []

@profile
def train(epoch):
    model.train()
    tr_loss = 0
    data_train, label_train = Variable(train_data), Variable(train_label)
    data_val, label_val = Variable(val_data), Variable(val_label)
    
    optimizer.zero_grad()
    
    output_train = model(train_data)
    output_val = model(val_data)
    
    loss_train = loss_funk(output_train, train_label)
    loss_val = loss_funk(output_val, val_label)
    train_losses.append(loss_train)
    val_losses.append(loss_val)
    
    loss_train.backward()
    optimizer.step()
    tr_loss = loss_train.item()
    

    # computing the updated weights of all the model parameters
    loss_train.backward()
    optimizer.step()
    tr_loss = loss_train.item()
    
    print('\n=====================================')
    print(f'Epoch: {epoch+1}\t loss: {loss_val}')
    print('=====================================')
    