### GenesisLab-Assignment
CIFAR-10 classification using FishNet  
  
Written code file is based on 
- [2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network](https://github.com/sapjunior/chulacv2018/blob/master/Lab%206%20-%20Convolutional%20Neural%20Network.ipynb)
- [FishNet](https://github.com/kevin-ssy/FishNet/tree/b968f0244827e11201471edd8a979bd85027b991)
- [Deep Learning Zero to All(Pytorch) 10-6 ResNet for cifar10][1]
[1]:https://github.com/deeplearningzerotoall/PyTorch/blob/master/lab-10_6_2_Advance-CNN(ResNet_cifar10).ipynb
- [Pytorch tutorial - TRAINING A CLASSIFIER](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html?highlight=cifar)
  
To understand better, I took a look at
- [[PyTorch] 7. Custom Dataset](https://data-panic.tistory.com/21)
- [pytorch 직접 데이터 로더 만들고 이미지 학습시키기](https://www.kaeee.de/2021/04/29/pytorch-%EB%8D%B0%EC%9D%B4%ED%84%B0-%EB%A1%9C%EB%8D%94-%EB%A7%8C%EB%93%A4%EA%B8%B0.html#dataloader%EC%9D%98-%ED%95%84%EC%9A%94%EC%84%B1)
- [Quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html?highlight=cifar)
- [FishNet: A Versatile Backbone for Image, Region, and Pixel Level Prediction](https://yckim.medium.com/%EC%A0%95%EB%A6%AC-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction-b86a493f114d)

#### Downloading needed things

In [1]:
# Install missing packages for Google Colaboratory
# !pip3 install torchvision

In [2]:
# Cloning FishNet (The paper wrote FishNet code repository at the abstract part)
!git clone https://github.com/kevin-ssy/FishNet.git

Cloning into 'FishNet'...
remote: Enumerating objects: 75, done.[K
remote: Total 75 (delta 0), reused 0 (delta 0), pack-reused 75[K
Unpacking objects: 100% (75/75), done.


#### Import libraries

In [3]:
# based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
# import needed libraries
import sys
import urllib
import tarfile
import pickle
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

#### Check python and pytorch version

In [4]:
print("python version: ", sys.version)
print("pytorch version: ", torch.__version__)

python version:  3.7.12 (default, Sep 10 2021, 00:21:48) 
[GCC 7.5.0]
pytorch version:  1.10.0+cu111


#### Function of reading pickle file

In [5]:
# Based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
# Make function of reading pickle file (cifar-10 has 5 data_batch files, one meta file and one test_batch file)
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

#### Set device option and seed number

In [6]:
# Based on "Deep Learning Zero to All" resource
# Set device option and seed number
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set random seed number to fix the result
torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)

#### Download the cifar-10 dataset and unzip the file

In [7]:
# Based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
# Download CIFAR-10 dataset from https://www.cs.toronto.edu/~kriz/cifar.html
urllib.request.urlretrieve ('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 'cifar-10-python.tar.gz')

# Unzip the file
with tarfile.open('cifar-10-python.tar.gz') as tar:
    tar.extractall()
    
# Show files in extract directory
!ls cifar-10-batches-py/

batches.meta  data_batch_2  data_batch_4  readme.html
data_batch_1  data_batch_3  data_batch_5  test_batch


#### Call the model (FishNet)

In [8]:
# Based on "FishNet" resource
# Import the model based on the GitHub explanation (has fishnet99, fishnet150, fishnet201)
from FishNet.models.net_factory import fishnet150

fishnet = fishnet150()

In [9]:
# print(fishnet)

#### Function of data transformation

In [10]:
# Based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
# Utilize the downloaded dataset
class CIFAR10Loader(data.Dataset):
    def __init__(self, cifarDatasetPath='cifar-10-batches-py/',train=True, transform=None):
      self.transform = transform
      
      # For train data files (data_batch_1~data_batch_5)
      if train:
        self.cifarImages = np.empty((0,3072), dtype=np.uint8)
        self.cifarLabels = np.empty((0,), dtype=np.uint8)

        for batchNo in range(1,6):
          dataDict = unpickle(cifarDatasetPath+'/data_batch_'+str(batchNo))
          self.cifarImages = np.vstack((self.cifarImages, dataDict[b'data']))
          self.cifarLabels = np.hstack((self.cifarLabels, dataDict[b'labels']))
      
      # For test data file
      else:
        dataDict = unpickle(cifarDatasetPath+'/test_batch')
        self.cifarImages = dataDict[b'data']
        self.cifarLabels = np.array(dataDict[b'labels'])
        
      # Transfrom from (x,3072) ==> (32,32,3,x)
      self.cifarImages = self.cifarImages.reshape(-1,3,32,32).transpose(2,3,1,0)

    # __getitem__은 데이터셋이 가지고있는 데이터를 리턴하는 기능을 합니다. ([PyTorch] 7. Custom Dataset resource)
    def __getitem__(self, idx):
        image = self.cifarImages[:,:,:,idx]
        label = self.cifarLabels[idx]
        
        if self.transform:
            image = self.transform(image)

        return image, label

    # cifar10의 경우 트레이닝, 혹은 테스트 셋의 전체 이미지의 갯수가 될 것이다.(pytorch 직접 데이터 로더 만들고 이미지 학습시키기 resource)
    def __len__(self):
        return self.cifarLabels.shape[0]

In [11]:
# Based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
def transformer(image):
    image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.406, 0.456), (0.229, 0.225, 0.224)),
        #transforms.RandomHorizontalFlip(p=0.5),
    ])(image)
    return image

In [12]:
# Based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
cifar10Train = CIFAR10Loader(transform = transformer)
cifar10Test = CIFAR10Loader(train = False, transform = transformer)
cifar10TrainLoader = data.DataLoader(dataset=cifar10Train, batch_size=128,num_workers=4,shuffle=True)
cifar10TestLoader = data.DataLoader(dataset=cifar10Train, batch_size=128,num_workers=4,shuffle=False)

  cpuset_checked))


#### model, optimization and loss function

In [13]:
# based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
# make network, loss function and optimizer
net = fishnet.to(device)
net.train()
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0001)

#### Train

In [14]:
# based on "2018 Computer Vision Class @ Chulalongkorn University Lab 6 - Convolutional Neural Network" resource
print('== Start Training ==')

bestLoss = float('Inf')
totalEpoch = 20
outputModelFile = 'bestCIFAR10.pth'

for epoch in range(0,totalEpoch):
  totalLoss = 0
  for batchIdx,(image,label) in enumerate(cifar10TrainLoader):
    image, label = image.to(device), label.to(device)
    
    optimizer.zero_grad()
    
    # forward + backward + optimize
    output = net(image)
    loss = criterion(output, label)

    loss.backward()
    optimizer.step()
    totalLoss+=loss.item()
    
  # Loss decrease, we should save the model
  totalLoss = totalLoss/len(cifar10TrainLoader)
  if totalLoss < bestLoss:
    print('Saving..')
    bestLoss = totalLoss
    state = {
        'model': net.state_dict(),
        'loss': bestLoss,
        'epoch': epoch,
    }
    torch.save(state, outputModelFile)
    
    
  print(epoch+1,'/',str(totalEpoch),'Batch Loss:',totalLoss)
  
print('== Training Finish ==')

== Start Training ==


  cpuset_checked))


Saving..
1 / 20 Batch Loss: 1.6521334151172882
Saving..
2 / 20 Batch Loss: 1.1977402046513375
Saving..
3 / 20 Batch Loss: 0.9823411555241441
Saving..
4 / 20 Batch Loss: 0.8959156546141486
Saving..
5 / 20 Batch Loss: 0.7285186133878615
Saving..
6 / 20 Batch Loss: 0.6114533246325715
Saving..
7 / 20 Batch Loss: 0.5144982699238126
Saving..
8 / 20 Batch Loss: 0.41253555137330616
Saving..
9 / 20 Batch Loss: 0.3337600011273723
Saving..
10 / 20 Batch Loss: 0.283700877348023
Saving..
11 / 20 Batch Loss: 0.2299967078525392
Saving..
12 / 20 Batch Loss: 0.19867395245663041
Saving..
13 / 20 Batch Loss: 0.18594511469726063
Saving..
14 / 20 Batch Loss: 0.17142801642265465
Saving..
15 / 20 Batch Loss: 0.16046736099759637
Saving..
16 / 20 Batch Loss: 0.13663208576590966
17 / 20 Batch Loss: 0.14646968798106894
Saving..
18 / 20 Batch Loss: 0.1318818741213635
Saving..
19 / 20 Batch Loss: 0.12874162203305975
Saving..
20 / 20 Batch Loss: 0.12701801150141623
== Training Finish ==


#### Test

In [15]:
# Based on "Pytorch tutorial - TRAINING A CLASSIFIER" and "Deep Learning Zero to All" resources
# Calculate test accuracy
correct = 0
total = 0

with torch.no_grad():
    for data in cifar10TestLoader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = fishnet(images)
        
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

  cpuset_checked))


Accuracy of the network on the 10000 test images: 97 %


In [16]:
# Based on "Pytorch tutorial - TRAINING A CLASSIFIER" resource
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in cifar10TestLoader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = fishnet(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))

  cpuset_checked))


Accuracy for class plane is: 97.4 %
Accuracy for class car   is: 99.0 %
Accuracy for class bird  is: 96.4 %
Accuracy for class cat   is: 96.2 %
Accuracy for class deer  is: 94.7 %
Accuracy for class dog   is: 96.5 %
Accuracy for class frog  is: 97.7 %
Accuracy for class horse is: 97.7 %
Accuracy for class ship  is: 98.7 %
Accuracy for class truck is: 96.5 %
