<a href="https://colab.research.google.com/github/Pheonix10101/PRCV_p_5/blob/main/new_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

"""
Project 5: Recognition using Deep Networks

Author: Samruddhi Raut

This file contains the following tasks:
Task 1
F: Test the network on new inputs

"""

import os
import torch
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.io import read_image
from matplotlib import pyplot as plt
import CNN


# Seed PyTorch's global random number generator with a fixed value (42)
# so that repeated runs of this script start from the same RNG state.
torch.manual_seed(42)


class NumDraw_data(Dataset):
    """Map-style dataset of images listed in a CSV annotations file.

    The CSV is read with ``pandas.read_csv`` and accessed positionally:
    column 0 holds the image file name (joined onto ``img_dir``) and
    column 1 holds the label for that image.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Store the image directory and optional transforms, then load the CSV.

        Args:
            annotations_file: Path (or buffer) accepted by ``pandas.read_csv``.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(annotations_file)

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        The image is read from disk with ``read_image`` and cast to float;
        each transform is applied only when it is set (truthy).
        """
        annotations = self.img_labels
        file_name = annotations.iloc[idx, 0]
        image = read_image(os.path.join(self.img_dir, file_name)).float()
        label = annotations.iloc[idx, 1]
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label


def main(model_path='samruddhi_neural.pt',
         annotations_file='/content/num_drawn_dataset.csv',
         img_dir='/content/num_drawn_dataset'):
    """Evaluate the saved network on the hand-drawn digit set and plot predictions.

    Loads a ``CNN.NeuralNetwork`` state dict, runs it over every image listed
    in the annotations CSV, prints the average cross-entropy loss and the
    accuracy, and shows each image titled with its predicted class.

    Args:
        model_path: Path to the saved ``state_dict`` file.
        annotations_file: CSV listing image file names and labels.
        img_dir: Directory containing the images named in the CSV.
    """
    loaded_net = CNN.NeuralNetwork()
    # map_location keeps loading working on CPU-only hosts even when the
    # checkpoint was saved from a GPU; the model is only ever run on CPU here.
    loaded_net_state_dict = torch.load(model_path, map_location='cpu')
    loaded_net.load_state_dict(loaded_net_state_dict)

    hand_write_dataset = NumDraw_data(annotations_file=annotations_file,
                                      img_dir=img_dir)
    hand_write_dataloader = DataLoader(dataset=hand_write_dataset,
                                       batch_size=CNN.BATCH_SIZE_TEST,
                                       shuffle=False,
                                       num_workers=4)

    set_len = len(hand_write_dataset)
    if set_len == 0:
        # Avoid dividing by zero below and plotting an empty figure.
        print('\nHandwrite test set is empty; nothing to evaluate.')
        return

    # Set model to evaluation mode.
    loaded_net.eval()
    test_loss = 0.0
    correct = 0
    imgs = []
    predictions = []

    # Gradients are not needed for inference; backward() is never called here.
    with torch.no_grad():
        for data, target in hand_write_dataloader:
            output = loaded_net(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            # .item() keeps `correct` a plain int so the report below does not
            # depend on how tensors are formatted.
            correct += pred.eq(target.view_as(pred)).sum().item()
            imgs.append(data)
            predictions.append(pred)

    test_loss /= set_len

    print('\nTest over handwrite test set: Avg.loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, set_len, 100. * correct / set_len))

    # Flatten the per-batch lists so plotting needs no batch/offset arithmetic.
    all_imgs = torch.cat(imgs)
    all_preds = torch.cat(predictions)
    shown = all_imgs.shape[0]

    # Keep the original 4x3 layout for up to 12 images and add rows beyond
    # that, since requesting a subplot index past rows*cols raises ValueError.
    n_cols = 3
    n_rows = max(4, (shown + n_cols - 1) // n_cols)

    plt.figure()
    for idx in range(shown):
        plt.subplot(n_rows, n_cols, idx + 1)
        # Channel 0 is displayed as the grayscale image.
        plt.imshow(all_imgs[idx][0], cmap='gray', interpolation='none')
        plt.title("Prediction: %d" % all_preds[idx].item())
        plt.xticks([])
        plt.yticks([])
    # Lay out once after all subplots exist instead of once per subplot.
    plt.tight_layout()
    # plt.show() blocks in scripts and renders inline in notebooks, whereas
    # Figure.show() returns immediately and warns on non-GUI backends.
    plt.show()


# Run the evaluation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()