In [1]:
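# Mount Google Drive so the training images under '/content/drive/My Drive/' can be read.
# NOTE(review): this cell contains no mount call (e.g. google.colab's drive.mount('/content/drive')) — add it before running.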

In [2]:
# create a CNN for handwriting recognition problem

import torch
import torch.nn as nn
import torch.nn.functional as F 

# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# is_available must be *called*: the bare function object is always truthy.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class Net(nn.Module):
  """LeNet-style CNN for 1x28x28 digit images.

  Shapes: 1x28x28 -> conv 5x5 (6) -> 6x24x24 -> pool -> 6x12x12
  -> conv 5x5 (16) -> 16x8x8 -> pool -> 16x4x4 -> fc 256 -> 120 -> 84 -> 10.
  forward() returns per-class log-probabilities of shape (batch, 10).
  """

  def __init__(self):
    super().__init__()

    # input image: 1x28x28
    # conv1: 6 filters of 5x5
    self.conv1 = nn.Conv2d(1, 6, 5)
    # conv2: 16 filters of 5x5
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    # 2x2 max pooling with stride 2 (halves each spatial dimension)
    self.pool = nn.MaxPool2d(2, 2)
    # fully connected feed-forward head after the CNN: fc1 120 neurons, fc2 84 neurons, fc3 output
    self.fc1 = nn.Linear(16 * 4 * 4, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Map a (batch, 1, 28, 28) float tensor to (batch, 10) log-probabilities.

    x --> conv1 --> relu --> pool --> conv2 --> relu --> pool --> fully connected
    """
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, self.num_flat_features(x))
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # Normalize over the class dimension explicitly; omitting `dim` relies on a
    # deprecated implicit choice and emits a UserWarning on every call.
    return F.log_softmax(self.fc3(x), dim=1)

  def num_flat_features(self, x):
    """Return the number of features per sample (product of all non-batch dims)."""
    num_features = 1
    for s in x.size()[1:]:
      num_features *= s
    return num_features

# Build the model and move its parameters to the selected device.
# nn.Module.to works in place (and returns the module), so no reassignment is needed.
net = Net()
net.to(device)
print(net)


Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [3]:
import os
import glob
import numpy as np
from skimage import io
from torch.utils.data import Dataset, DataLoader

# create a customized Dataset
# override __init__, __len__, and __getitem__ methods

class MNISTDataset(Dataset):
  """Dataset of the *.jpg images in one folder whose name is the digit label.

  Args:
    dir: folder containing the images; its final path component must parse as
      an int (e.g. '.../trainingset/7'). A trailing slash is tolerated.
    transform: optional callable applied to each {'image', 'label'} sample.
    max_files: keep only the first `max_files` files (after sorting by path);
      None keeps all of them. Defaults to 500, as before.

  Each item is {'image': ndarray from skimage.io.imread, 'label': 0-d ndarray}.
  """

  def __init__(self, dir, transform=None, max_files=500):
    self.dir = dir
    self.transform = transform
    # Glob once (not on every access) and sort, so that index -> file is stable
    # across calls; raw glob order is filesystem-dependent.
    files = sorted(glob.glob(os.path.join(dir, '*.jpg')))
    self.files = files if max_files is None else files[:max_files]
    # The folder name is the class label; normpath strips a trailing separator.
    # Raises ValueError here if the folder name is not an integer.
    self.digit = int(os.path.basename(os.path.normpath(dir)).strip())

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()

    # glob already returned paths that include self.dir, so use them as-is
    # (joining with self.dir again breaks relative directories).
    image = io.imread(self.files[idx])
    label = np.array(self.digit)

    instance = {'image': image, 'label': label}

    if self.transform:
      instance = self.transform(instance)

    return instance

In [4]:
# Custom sample transforms that operate on {'image', 'label'} dicts (Rescale here; ToTensor in the next cell).
from skimage import transform
from torchvision import transforms, utils

class Rescale(object):
  """Resize the image of a {'image', 'label'} sample; the label passes through.

  Args:
    output_size (int or tuple): a tuple (h, w) resizes to exactly that size.
      An int sets the smaller image edge to this value and scales the other
      edge to keep the aspect ratio (truncated to int).
  """

  def __init__(self, output_size):
    assert isinstance(output_size, (int, tuple))
    self.output_size = output_size

  def __call__(self, sample):
    image, label = sample['image'], sample['label']

    # The spatial dims are the first two axes for both (H, W) and (H, W, C)
    # arrays; shape[-2:] would pick (W, C) for a color image.
    h, w = image.shape[:2]
    if isinstance(self.output_size, int):
      if h > w:
        new_h, new_w = self.output_size * h / w, self.output_size
      else:
        new_h, new_w = self.output_size, self.output_size * w / h
    else:
      new_h, new_w = self.output_size

    new_h, new_w = int(new_h), int(new_w)

    # skimage.transform.resize preserves a trailing channel axis that is not
    # listed in output_shape.
    new_image = transform.resize(image, (new_h, new_w))
    return {'image': new_image, 'label': label}

In [5]:
class ToTensor(object):
  """Convert the numpy arrays of a {'image', 'label'} sample to torch tensors.

  A 2-D (H, W) image gets a leading channel axis -> (1, H, W). A 3-D (H, W, C)
  image is transposed to PyTorch's channel-first (C, H, W) layout. Both image
  and label must be numpy arrays (they go through torch.from_numpy, which
  shares memory and keeps the dtype).
  """

  def __call__(self, sample):
    image, label = sample['image'], sample['label']
    if image.ndim == 2:
      image = image.reshape((1, image.shape[0], image.shape[1]))
    else:
      # (H, W, C) -> (C, H, W)
      image = image.transpose((2, 0, 1))
    return {'image': torch.from_numpy(image), 'label': torch.from_numpy(label)}

In [6]:
# create training/validation dataset/dataloader

from torch.utils.data import random_split
from torchvision import transforms

# Root folder with one sub-folder per digit ('0' ... '9'); each folder name is its label.
data_root = '/content/drive/My Drive/trainingset'
batch_size = 48
train_fraction = 0.7
split_seed = 0  # fixed seed -> the same train/validation split on every run

# All per-digit datasets share one (stateless) preprocessing pipeline.
preprocess = transforms.Compose([Rescale(28), ToTensor()])
list_datasets = [MNISTDataset(os.path.join(data_root, str(digit)), transform=preprocess)
                 for digit in range(10)]

dataset = torch.utils.data.ConcatDataset(list_datasets)
print(len(dataset))

train_size = int(len(dataset) * train_fraction)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size],
                                          generator=torch.Generator().manual_seed(split_seed))

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=1)
# Validation order does not affect the metrics, so there is no need to shuffle it.
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=1)



5000


In [7]:
# training and validation
import torch.optim as optim

# Train `net` and report per-digit validation accuracy after every epoch.
# Relies on net, device, train_dataloader and val_dataloader from earlier cells.
epochs = 5
learning_rate = 0.01
log_every = 10   # print the mean training loss over this many batches
n_classes = 10   # digits 0-9

optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-5)
# Net.forward already ends in log_softmax, so NLLLoss is the matching criterion.
# CrossEntropyLoss would apply log_softmax a second time: numerically a no-op
# (log_softmax is idempotent), but redundant and misleading.
criterion = nn.NLLLoss()

for epoch in range(epochs):

  net.train()

  running_loss = 0.0
  for batch_idx, batch in enumerate(train_dataloader):
    inputs = batch['image'].to(device, dtype=torch.float)
    targets = batch['label'].to(device, dtype=torch.long)

    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if (batch_idx + 1) % log_every == 0:
      print('epoch: %d, batch: %d, training loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / log_every))
      running_loss = 0.0

  net.eval()

  correct = torch.zeros(n_classes, dtype=torch.long)
  total = torch.zeros(n_classes, dtype=torch.long)

  with torch.no_grad():
    for batch in val_dataloader:
      images = batch['image'].to(device, dtype=torch.float)
      labels = batch['label'].to(device, dtype=torch.long)
      predicted_labels = net(images).argmax(dim=1)
      # Per-digit tallies without a Python loop over individual samples.
      total += torch.bincount(labels, minlength=n_classes).cpu()
      correct += torch.bincount(labels[predicted_labels == labels], minlength=n_classes).cpu()

  for digit in range(n_classes):
    if total[digit] == 0:
      # A small or unlucky split can leave a digit out of the validation set.
      print('\t Validation accuracy for digit: %d, n/a (no samples)' % digit)
    else:
      print('\t Validation accuracy for digit: %d, %.2f' % (digit, 100.0 * correct[digit].item() / total[digit].item()))




epoch: 1, batch: 10, training loss: 2.172
epoch: 1, batch: 20, training loss: 1.207
epoch: 1, batch: 30, training loss: 0.746
epoch: 1, batch: 40, training loss: 0.532
epoch: 1, batch: 50, training loss: 0.473
epoch: 1, batch: 60, training loss: 0.455
epoch: 1, batch: 70, training loss: 0.445
	 Validation accuracy for digit: 0, 90.07
	 Validation accuracy for digit: 1, 97.28
	 Validation accuracy for digit: 2, 83.77
	 Validation accuracy for digit: 3, 87.97
	 Validation accuracy for digit: 4, 79.62
	 Validation accuracy for digit: 5, 92.26
	 Validation accuracy for digit: 6, 99.31
	 Validation accuracy for digit: 7, 95.39
	 Validation accuracy for digit: 8, 83.66
	 Validation accuracy for digit: 9, 90.58
epoch: 2, batch: 10, training loss: 0.165
epoch: 2, batch: 20, training loss: 0.311
epoch: 2, batch: 30, training loss: 0.340
epoch: 2, batch: 40, training loss: 0.256
epoch: 2, batch: 50, training loss: 0.207
epoch: 2, batch: 60, training loss: 0.221
epoch: 2, batch: 70, training loss