# Image Classification using a CNN - PyTorch

# Import libraries

In [1]:
#importing necessary libraries 
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 

import os 
import glob 
import numpy as np 
from skimage import io 
from torch.utils.data import Dataset, DataLoader 

from skimage import transform
from torchvision import transforms, utils
from torch.utils.data import random_split

import matplotlib.pyplot as plt

# In case you want to pull data from Google Drive into a Google Colab environment

In [None]:
# Mount Google Drive under /content/drive so the dataset folders used below are readable.
# Intended for a Google Colab runtime, where the google.colab package is provided.
from google.colab import files, drive 
drive.mount('/content/drive')

# Convolutional Neural Network Architecture 

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

class Net(nn.Module):
    """Small CNN that maps single-channel images to 10 class log-probabilities.

    Pipeline: conv1 -> relu -> pool -> conv2 -> relu -> pool -> flatten
    -> FC1 -> relu -> dropout -> FC2 -> log_softmax.

    The first fully connected layer expects 7 * 7 * 64 features, which is what
    a 1x28x28 input produces after the two 2x2 poolings.
    """

    def __init__(self):

        super(Net, self).__init__()

        # 3x3 convolutions with stride 1 and padding 1 preserve height and width
        self.conv_1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # 2x2 pooling with stride 2 halves height and width; reused after each convolution
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # 7 * 7 * 64 is the flattened size of the second pooling layer's output
        self.linear_1 = nn.Linear(7 * 7 * 64, 128)
        self.linear_2 = nn.Linear(128, 10)  # one output per class
        self.dropout = nn.Dropout(p=0.5)  # reduce overfitting
        self.relu = torch.nn.ReLU()  # activation that introduces non-linearity

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10) for a batch of images ``x``."""
        x = self.conv_1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv_2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = x.reshape(x.size(0), -1)  # flatten everything except the batch dimension
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.dropout(x)
        pred = self.linear_2(x)

        # Normalise over the class dimension explicitly; omitting ``dim`` relies on a
        # deprecated implicit choice and emits a warning.
        return F.log_softmax(pred, dim=1)

# Instantiate the network and move its parameters to the selected device
net = Net().to(device)

# Customized MNIST Dataset loading 

In [None]:
class MNISTDataset(Dataset):
    """Dataset over the ``*.jpg`` files of one directory whose name is the digit label.

    Every item is a dict ``{'image': ndarray, 'label': ndarray}`` (before any
    transform), so a DataLoader batch is a dict with batched 'image' and
    'label' entries.

    Args:
        dir: Directory holding the images; its last path component must parse
            as an integer, which is used as the label of every image in it.
        transform: Optional callable applied to the item dict before it is
            returned.
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        self.transform = transform
        # List the directory once instead of on every __len__/__getitem__ call,
        # and sort so that a given index always maps to the same file.
        self.files = sorted(glob.glob(os.path.join(self.dir, '*.jpg')))

    def __len__(self):
        return len(self.files)

    def _digit_label(self):
        """Return the integer label encoded in the directory name."""
        # normpath drops a trailing separator so '.../3/' still yields '3'
        return int(os.path.basename(os.path.normpath(self.dir)).strip())

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # glob already returns paths that include self.dir, so use them as-is;
        # joining them onto self.dir again breaks relative directories.
        img_fname = self.files[idx]
        image = io.imread(img_fname)  # image file -> numpy array
        label = np.array(self._digit_label())

        instance = {'image': image, 'label': label}

        if self.transform:
            instance = self.transform(instance)

        return instance

# Custom transformations: rescaling and conversion to tensor

In [None]:
# 1 transformation
class Rescale(object):
    """Resize the image of a sample dict to a requested size.

    Args:
        output_size: If a tuple, it is used as ``(height, width)`` directly.
            If an int, the smaller of the two image edges is set to that
            value and the other edge is scaled by the same ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the resized image and the unchanged label."""
        image = sample['image']
        label = sample['label']

        # The last two dimensions of the array are taken as height and width
        height, width = image.shape[-2:]

        target = self.output_size
        if isinstance(target, int):
            if height > width:
                scaled = (target * height / width, target)
            else:
                scaled = (target, target * width / height)
        else:
            scaled = target

        out_h, out_w = scaled
        # resize needs integer dimensions; the ratio above may be fractional
        resized = transform.resize(image, (int(out_h), int(out_w)))

        return {'image': resized, 'label': label}

#2 transformation
class ToTensor(object):
    """Convert the numpy image and label of a sample dict into torch tensors."""

    def __call__(self, sample):
        """Return ``{'image': tensor of shape (1, H, W), 'label': tensor}``."""
        image = sample['image']
        label = sample['label']

        # Prepend a channel dimension: (H, W) -> (1, H, W)
        height, width = image.shape[0], image.shape[1]
        with_channel = image.reshape((1, height, width))

        image_tensor = torch.from_numpy(with_channel)
        label_tensor = torch.from_numpy(label)
        return {'image': image_tensor, 'label': label_tensor}

# Dataset Loaders

In [None]:
# Training/validation loaders: one dataset per digit folder, merged, then split 70/30
batch_size = 10

# The images are stored in 10 folders named 0..9, one per digit
list_datasets = [
    MNISTDataset(
        '/content/drive/My Drive/trainingset/' + str(digit),
        transform=transforms.Compose([Rescale(28), ToTensor()]),
    )
    for digit in range(10)
]

# Single dataset holding every training-set instance
dataset = torch.utils.data.ConcatDataset(list_datasets)

# 70% of the instances go to training, the remainder to validation
train_size = int(len(dataset) * 0.7)
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=1)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=1)

# Display one training instance 

In [None]:
# Fetch a single batch from the training loader
batch = next(iter(train_dataloader))
inputs, targets = batch['image'].to(device), batch['label'].to(device)

# Show the first image of the batch. The tensor may live on the GPU, and
# .numpy() only works on CPU tensors, so copy it to the CPU first.
plt.imshow(inputs[0].cpu().numpy().squeeze(), cmap='gray')

# Training and Validation

In [None]:
epochs = 10
learning_rate = 0.001
num_classes = 10
optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-5)
# The network already returns log-probabilities (log_softmax), so the matching
# multiclass loss is NLLLoss; CrossEntropyLoss would apply log_softmax a second time.
criterion = nn.NLLLoss()

# Per-epoch mean losses, used for the train/val loss graph
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):

    ############# TRAIN ###############
    net.train()
    running_loss = 0.0

    for batch in train_dataloader:
        inputs = batch['image'].to(device, dtype=torch.float)
        targets = batch['label'].to(device, dtype=torch.long)

        # Training pass
        optimizer.zero_grad()

        output = net(inputs)
        loss = criterion(output, targets)

        # Backpropagate, then update the weights
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch mean losses; max() guards against an empty loader
    train_loss = running_loss / max(len(train_dataloader), 1)
    train_loss_list.append(train_loss)
    print("Epoch {} - Training loss: {}".format(epoch+1, train_loss))

    ########### TESTING over validation set #############
    net.eval()

    # Per-class counters, indexed by the true label
    correct = [0] * num_classes
    total = [0] * num_classes
    val_loss = 0.0

    with torch.no_grad():

        for batch in val_dataloader:
            inputs = batch['image'].to(device, dtype=torch.float)
            targets = batch['label'].to(device, dtype=torch.long)
            pred_outputs = net(inputs)

            # Sum (not mean) over the batch so the total can be divided by the
            # number of samples below; reduction='sum' replaces the deprecated
            # size_average=False.
            val_loss += F.nll_loss(pred_outputs, targets, reduction='sum').item()

            pred_targets = pred_outputs.argmax(dim=1)

            for target, predicted in zip(targets.tolist(), pred_targets.tolist()):
                total[target] += 1
                correct[target] += int(predicted == target)

    # Average the per-class accuracies over the classes that actually occur in the
    # validation split; a class with no samples would otherwise divide by zero.
    class_accs = [100 * c / t for c, t in zip(correct, total) if t > 0]
    final_val_acc = sum(class_accs) / len(class_accs) if class_accs else 0.0

    # Mean loss per validation sample, so it is on the same scale as the training loss
    val_loss_list.append(val_loss / max(sum(total), 1))

    print("Epoch {} - Val Acc: {}".format(epoch+1, final_val_acc))

# Train and Validation loss graph 

In [None]:
# Plot the per-epoch training and validation losses on one figure
fig = plt.figure(figsize=(20, 10))
epoch_axis = np.arange(1, epochs + 1)
plt.plot(epoch_axis, train_loss_list, label="Train loss")
plt.plot(epoch_axis, val_loss_list, label="Validation loss")
# Epochs are on the x axis and loss values on the y axis
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title("Loss Plots")
plt.legend(loc='upper right')
plt.show()

# Loading test data and evaluating the trained model (the only part that uses the torchvision datasets API)

In [None]:
from torchvision import datasets

# Download (if needed) and load the MNIST test split through torchvision
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.Resize(28), transforms.ToTensor()]))
test_dataloader = torch.utils.data.DataLoader(mnist_testset, batch_size=32, shuffle=False)

net.eval()

# (image, predicted label) for every correctly classified test image
results = list()
num_correct = 0

# No gradients are needed for evaluation
with torch.no_grad():
    for image, label in test_dataloader:
        # The model lives on `device`, so the batch has to be moved there as well
        image, label = image.to(device), label.to(device)

        # The network outputs log-probabilities; argmax over the class dimension
        # gives the predicted digit without an extra softmax.
        predicted = net(image).argmax(dim=1)
        hits = predicted == label
        num_correct += int(hits.sum().item())

        # Keep only the individual image that was classified correctly, not the whole batch
        for i in torch.nonzero(hits).flatten().tolist():
            results.append((image[i].cpu(), predicted[i].cpu()))

# Fraction of correctly classified samples (divide by samples, not by batches)
test_accuracy = num_correct / len(mnist_testset)
print('Test accuracy {:.8f}'.format(test_accuracy))

# END