In [45]:
def train_model(image_datasets, path_keys, model, criterion, optimizer, scheduler, num_epochs):

  print("---TRAINING---")
  dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size = 4, shuffle = True, num_workers = 4)
                for x in path_keys}

  best_model_wts = copy.deepcopy(model.state_dict()) 
  best_no_corrects= 0

  classes = [0,1,2]
  class_correct = list(0. for i in range(len(classes)))
  class_total = list(0. for i in range(len(classes)))
  c = 0

  # for validation
  classes2 = [0,1,2]
  class_correct2 = list(0. for i in range(len(classes)))
  class_total2 = list(0. for i in range(len(classes)))
  c2 = 0

  for epoch in range(num_epochs):
    # Set the model to the training mode for updating the weights using # the first portion of training images
    running_loss = 0
    model.train()
    for inputs, labels in dataloaders[path_keys[0]]: # iterate over data
      inputs = inputs.to(device)
      labels = labels.to(device) 
      optimizer.zero_grad()
      with torch.set_grad_enabled(True):
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1) 
        c = (predicted == labels).squeeze()
        loss = criterion(outputs, labels) 
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
  
        for i in range(len(labels)):
          label = labels[i]
          if len(c.size()) != 0:
            class_correct[label] += c[i].item()
          else:
            class_correct[label] += c
          class_total[label] += 1

  # Set the model to the evaluation mode for selecting the best network # based on the number of correctly classified validation images 
   
    model.eval()
    no_corrects = 0
    for inputs2, labels2 in dataloaders[path_keys[1]]: 
      inputs2 = inputs2.to(device)
      labels2 = labels2.to(device)
      with torch.set_grad_enabled(False):
        outputs2 = model(inputs2)
        _, predicted2 = torch.max(outputs2, 1)
        c2 = (predicted2 == labels2).squeeze()
        no_corrects += torch.sum(predicted2 == labels2.data)

        for i in range(len(labels2)):
          label2 = labels2[i]
          if len(c2.size()) != 0:
            class_correct2[label2] += c2[i].item()
          else:
            class_correct2[label] += c2
          class_total2[label2] += 1
    
    if no_corrects > best_no_corrects:
      best_no_corrects = no_corrects
      best_model_wts = copy.deepcopy(model.state_dict())
    
    scheduler.step()
    # Load the weights of the best network
  
  # train accuracy
  print()
  print("TRAIN ACCURACY")
  print()
  for i in range(3):
    if class_total[i] != 0 :
      print('Accuracy of %5s : %2d %%' % (
          classes[i], 100 * class_correct[i] / class_total[i]))
    else: 
      print("Class total is 0!")
  print("Overall accuracy is: " , (class_correct[0] +  class_correct[1] + class_correct[2]) / (class_total[0] + class_total[1] +class_total[2]))    
  
  # print("class_correct[0]", class_correct[0])
  # print("class_correct[1]", class_correct[1])
  # print("class_correct[2]", class_correct[2])

  # print("-------------------------")
    
  # print("class_total[0]", class_total[0])
  # print("class_total[1]", class_total[1])
  # print("class_total[2]", class_total[2])


  print("-------------------------")
  print("-------------------------")
 
  # validation accuracy
  print("VALIDATION ACCURACY")
  print()
  for i in range(3):
    if class_total2[i] != 0 :
      print('Accuracy of %5s : %2d %%' % (
          classes2[i], 100 * class_correct2[i] / class_total2[i]))
    else: 
      print("Class total is 0!")

  print("Overall accuracy is: " , (class_correct2[0] +  class_correct2[1] + class_correct2[2]) / (class_total2[0] + class_total2[1] +class_total2[2]))    
  
  # print("class_correct[0]", class_correct2[0])
  # print("class_correct[1]", class_correct2[1])
  # print("class_correct[2]", class_correct2[2])

  # print("-------------------------")
    
  # print("class_total[0]", class_total2[0])
  # print("class_total[1]", class_total2[1])
  # print("class_total[2]", class_total2[2])

  print()
  print()

  # calc_accuracy(class_total, class_correct) 
  model.load_state_dict(best_model_wts) 
  return model


In [61]:
def test(image_datasets, model, path_keys):
 
  c = 0
  inputs = 0
  outputs = 0
  predicted = 0
  class_total = 0
  class_correct = 0

  #Testing classification accuracy for individual classes.

  dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size = 4, shuffle = True, num_workers = 4)
                for x in path_keys}

  classes = [0,1,2]
  class_correct = list(0. for i in range(len(classes)))
  class_total = list(0. for i in range(len(classes)))

  model.eval()
  print("Eval active")
  with torch.no_grad():
    for inputs, labels in dataloaders[path_keys[2]]:
    
      inputs = inputs.to(device)
      labels = labels.to(device) 
      outputs = model(inputs)
      _, predicted = torch.max(outputs, 1)
      c = (predicted == labels).squeeze()
      # class_total, class_correct = check_labels(labels, c)
  
      for i in range(len(labels)):
        label = labels[i]
        if len(c.size()) != 0:
          class_correct[label] += c[i].item()
        else:
          class_correct[label] += c
        class_total[label] += 1

  print("---TEST ACCURACY---")
  for i in range(3):
    if class_total[i] != 0 :
      print('Accuracy of %5s : %2d %%' % (
          classes[i], 100 * class_correct[i] / class_total[i]))
    else: 
      print("Class total is 0!")
  print("Overall accuracy is: " , (class_correct[0] +  class_correct[1] + class_correct[2]) / (class_total[0] + class_total[1] +class_total[2]))    
  

  # calc_accuracy(class_total, class_correct)  

  # print("class_correct[0]", class_correct[0])
  # print("class_correct[1]", class_correct[1])
  # print("class_correct[2]", class_correct[2])

  # print("-------------------------")
    
  # print("class_total[0]", class_total[0])
  # print("class_total[1]", class_total[1])
  # print("class_total[2]", class_total[2])


In [47]:
def data_transforms(is_normalized, mean_arr, std_arr, path_keys):
  if is_normalized:
    return  {     path_keys[0]: transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Normalize(mean = mean_arr[0], std = std_arr[0])
                  ]),
              
                  path_keys[1]: transforms.Compose([
                  transforms.ToTensor(),      
                  transforms.Normalize(mean = mean_arr[1], std = std_arr[1])
                  ]),
                  path_keys[2]: transforms.Compose([
                  transforms.ToTensor(),                                       
                  transforms.Normalize(mean = mean_arr[2], std = std_arr[2])
                  ])
                  }
  else: 
    return {      path_keys[0]: transforms.Compose([
                  transforms.ToTensor(),
           
                 
                  ]),
              
                  path_keys[1]: transforms.Compose([
                  transforms.ToTensor(),
                           
           
                  ]),
                  path_keys[2]: transforms.Compose([
                  transforms.ToTensor(),                                     
        
                  ])
                  }


In [48]:
from google.colab import drive
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [49]:
import torch 
import numpy
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import numpy as np
from skimage import io
import os
from __future__ import print_function, division
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
import time
import copy
from torch import Tensor, autograd
from skimage import color
import cv2
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import pdb

In [50]:
class_no = 3

num_epochs = 4

LEARNING_RATE = 0.1
path_dataset = "/content/drive/MyDrive/COMP448-HW3/dataset_valid_vers"


sub_path_balanced_train_img = "/balanced_train/" 
sub_path_balanced_valid_img = "/balanced_valid/"
sub_path_imbalanced_train_img = "/imbalanced_train/" 
sub_path_imbalanced_valid_img = "/imbalanced_valid/"
sub_path_test_img =  "/test/"

folder_path_balanced_train_img = path_dataset + sub_path_balanced_train_img 
folder_path_balanced_valid_img = path_dataset + sub_path_balanced_valid_img
folder_path_imbalanced_train_img = path_dataset + sub_path_imbalanced_train_img 
folder_path_imbalanced_valid_img = path_dataset + sub_path_imbalanced_valid_img
folder_path_test_img = path_dataset + sub_path_test_img


In [51]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_conv = models.alexnet(pretrained = True)

for param in model_conv.parameters():
  param.requires_grad = False

model_conv.classifier[6] = nn.Linear(4096, 3)
model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model_conv.parameters(), lr=LEARNING_RATE, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  

In [52]:
mean_arr=[0.485, 0.456, 0.406]
std_arr=[0.229, 0.224, 0.225]


In [53]:

balanced_path_keys = [sub_path_balanced_train_img, sub_path_balanced_valid_img, sub_path_test_img]
imbalanced_path_keys = [sub_path_imbalanced_train_img, sub_path_imbalanced_valid_img, sub_path_test_img]

normalized_balanced_image_datasets = {x: datasets.ImageFolder(path_dataset + x, data_transforms(True, mean_arr, std_arr, balanced_path_keys)[x]) for x in balanced_path_keys}
normalized_imbalanced_image_datasets = {x: datasets.ImageFolder(path_dataset + x, data_transforms(True, mean_arr, std_arr, imbalanced_path_keys)[x]) for x in imbalanced_path_keys}
not_normalized_balanced_image_datasets = {x: datasets.ImageFolder(path_dataset + x, data_transforms(False, mean_arr, std_arr, balanced_path_keys)[x]) for x in balanced_path_keys}
not_normalized_imbalanced_image_datasets = {x: datasets.ImageFolder(path_dataset + x, data_transforms(False, mean_arr, std_arr, imbalanced_path_keys)[x]) for x in imbalanced_path_keys}



In [67]:
## TRAINING 
# print("Normalized balanced model: ")
# normalized_balanced_model = train_model(normalized_balanced_image_datasets, balanced_path_keys, model_conv, criterion, optimizer, exp_lr_scheduler, 4)
# print()
# normalized_balanced_model = normalized_balanced_model.to(device)

# print("Normalized imbalanced model: ")
# normalized_imbalanced_model = train_model(normalized_imbalanced_image_datasets, imbalanced_path_keys, model_conv, criterion, optimizer, exp_lr_scheduler, 4)
# print()
# normalized_imbalanced_model = normalized_imbalanced_model.to(device)

# print("Not normalized balanced model: ")
# not_normalized_balanced_model = train_model(not_normalized_balanced_image_datasets, balanced_path_keys, model_conv, criterion, optimizer, exp_lr_scheduler, 4)
# print()
# not_normalized_balanced_model = not_normalized_balanced_model.to(device)

# #  below the case is not included
# print("Not normalized imbalanced model: ")
# not_normalized_imbalanced_model = train_model(not_normalized_imbalanced_image_datasets, imbalanced_path_keys, model_conv, criterion, optimizer, exp_lr_scheduler, 4)
# not_normalized_imbalanced_model = normalized_imbalanced_model.to(device)


Not normalized balanced model: 
---TRAINING---


  cpuset_checked))



TRAIN ACCURACY

Accuracy of     0 : 95 %
Accuracy of     1 : 84 %
Accuracy of     2 : 96 %
Overall accuracy is:  tensor(0.9202, device='cuda:0')
-------------------------
-------------------------
VALIDATION ACCURACY

Accuracy of     0 : 100 %
Accuracy of     1 : 88 %
Accuracy of     2 : 100 %
Overall accuracy is:  0.9607843137254902





In [68]:
# TESTING
print()
## since test data set is same in every dataset, normalized_balanced_image_datasets is used for all test results.
# print("Normalized balanced test: ")
# test(normalized_balanced_image_datasets,normalized_balanced_model, balanced_path_keys)
# print()
# print("Normalized imbalanced test: ")
# test(normalized_balanced_image_datasets,normalized_imbalanced_model, balanced_path_keys)
# print()
print("Not_normalized balanced test: ")
test(normalized_balanced_image_datasets,not_normalized_balanced_model, balanced_path_keys)


## below the case is not included
# test(normalized_balanced_image_datasets,not_normalized_imbalanced_model, balanced_path_keys)


Not_normalized balanced test: 
Eval active


  cpuset_checked))


---TEST ACCURACY---
Accuracy of     0 : 89 %
Accuracy of     1 : 92 %
Accuracy of     2 : 79 %
Overall accuracy is:  0.8819444444444444


In [None]:
# 