<a href="https://colab.research.google.com/github/taravatp/Multi_Spectral_Image_Segmentation/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install import_ipynb

In [None]:
!pip install ml_collections

In [None]:
cd /content/drive/MyDrive/Vision_Impulse_Task

/content/drive/MyDrive/Vision_Impulse_Task


In [None]:
import torch
import torch.nn as nn
import import_ipynb

import matplotlib.pyplot as plt
import numpy as np
import cv2
import time
import pickle
from sklearn.metrics import cohen_kappa_score, accuracy_score, f1_score

In [None]:
import vit_seg_configs as configs
from transunet import Trans_Unet
from MSI_dataset import MSI_data
from Unet import unet
from Efficient_Network import ENet

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
print(device)

cuda


# Creating Test Dataloader

In [None]:
# Number of samples per evaluation batch.
BATCHSIZE = 16
test_data = MSI_data(flag='test')
# Evaluation order is kept fixed: the metrics below are computed per batch and
# then averaged, so shuffling would change the batch composition (and therefore
# the reported numbers) from run to run.
test_dataloader = torch.utils.data.DataLoader(test_data,batch_size=BATCHSIZE,shuffle=False)

# Evaluation metrics

In [None]:
def evaluate(ground_truth,predicted_labels,num_classes=3):
  """Compute segmentation metrics for one batch.

  Args:
    ground_truth: tensor of integer class indices. It is flattened, so any
      shape works as long as it holds one label per predicted position.
      Labels are expected to lie in [0, num_classes).
    predicted_labels: tensor of per-class scores with the class axis at
      dim 1; the predicted class is the argmax along that axis.
    num_classes: number of classes used for the metrics (default 3).

  Returns:
    Tuple (overall_accuracy, f1, kappa, miou) where f1 is the
    support-weighted F1 score, kappa is Cohen's kappa and miou is the mean
    intersection-over-union across the classes that occur in the ground
    truth or in the prediction.
  """
  class_labels = list(range(num_classes))

  # Flatten both label maps to 1-D numpy arrays on the CPU.
  ground_truth = ground_truth.reshape(-1).cpu().numpy()

  # Softmax is monotonic, so the argmax of the raw scores is the same class.
  predicted_labels = torch.argmax(predicted_labels.detach(), dim=1)
  predicted_labels = predicted_labels.reshape(-1).cpu().numpy()

  overall_accuracy = accuracy_score(ground_truth, predicted_labels)
  f1 = f1_score(ground_truth, predicted_labels, labels=class_labels, average='weighted')
  kappa = cohen_kappa_score(ground_truth, predicted_labels, labels=class_labels)

  # Vectorised confusion matrix: encode every (true, predicted) pair as a
  # single index and count occurrences; rows are true classes, columns are
  # predicted classes.
  pair_index = ground_truth.astype(np.int64) * num_classes + predicted_labels.astype(np.int64)
  confusion_matrix = np.bincount(pair_index, minlength=num_classes * num_classes)
  confusion_matrix = confusion_matrix.reshape(num_classes, num_classes)

  intersection = np.diag(confusion_matrix)
  union = confusion_matrix.sum(1) + confusion_matrix.sum(0) - intersection

  # A class missing from both ground truth and prediction has union == 0;
  # skip it instead of letting 0/0 turn the mean into NaN.
  present = union > 0
  iou = intersection[present] / union[present].astype(np.float64)
  miou = np.mean(iou) if iou.size else 0.0

  return overall_accuracy, f1, kappa, miou

# Loading models

In [None]:
# Re-import the UNet class under an alias so that binding the *instance* to the
# name `unet` does not destroy the class reference; this keeps the cell safe to
# re-run (otherwise a second run would call the model instance with no inputs).
from Unet import unet as UNetModel

unet = UNetModel().to(device)
enet = ENet(num_classes=3, in_channels=12).to(device)

# Configuration presets provided by vit_seg_configs, keyed by name.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

# All three TransUNet variants share the same architecture configuration;
# they differ only in the checkpoint loaded into them afterwards.
config = CONFIGS['ViT-B_16']
simple_transunet = Trans_Unet(config).to(device)
finetuned_transunet = Trans_Unet(config).to(device)
adversarial_transunet = Trans_Unet(config).to(device)

In [None]:
# map_location remaps the stored tensors onto the active device, so checkpoints
# saved from a GPU session also load on a CPU-only runtime.
unet.load_state_dict(torch.load('models/unet_batch329.pth', map_location=device))
# NOTE(review): ENet is loaded from the same file as UNet above. This looks like
# a copy-paste slip - confirm the intended ENet checkpoint path.
enet.load_state_dict(torch.load('models/unet_batch329.pth', map_location=device))
simple_transunet.load_state_dict(torch.load('models/transUNET9.pth', map_location=device))
finetuned_transunet.load_state_dict(torch.load('models/transunet_finetuning9.pth', map_location=device))
adversarial_transunet.load_state_dict(torch.load('models/TranUNet_GAN_9.pth', map_location=device))

<All keys matched successfully>

# Test function

In [None]:
def test(model,test_dataloader):
  """Evaluate `model` on every batch of `test_dataloader`.

  The model is switched to evaluation mode and run without gradient
  tracking; its previous training/eval mode is restored afterwards.

  Args:
    model: torch module that maps a batch of input images to per-class scores.
    test_dataloader: iterable with a length that yields batches whose first
      element is the input image tensor and whose second is the target tensor.

  Returns:
    Tuple (accuracy, f1, kappa, miou). Each value is the unweighted mean of
    the per-batch metric returned by `evaluate`, so a smaller final batch
    counts as much as a full one.
  """
  total_accuracy = 0
  total_f1 = 0
  total_koppa = 0
  total_miou = 0

  # Inference must not use dropout or batch statistics, and must not update
  # batch-norm running averages; remember the mode so it can be restored.
  was_training = model.training
  model.eval()
  try:
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
      for batch in test_dataloader:

        input_image,target_image = batch[0].to(device), batch[1].to(device)
        # Drop singleton dimensions and cast to integer class indices.
        target_image = torch.squeeze(target_image).long()
        prediction = model(input_image)

        accuracy, f1, kappa, miou = evaluate(target_image,prediction)
        total_accuracy += accuracy
        total_f1 += f1
        total_koppa += kappa
        total_miou += miou
  finally:
    model.train(was_training)

  num_batches = len(test_dataloader)
  total_accuracy = total_accuracy/num_batches
  total_f1 = total_f1/num_batches
  total_koppa = total_koppa/num_batches
  total_miou = total_miou/num_batches

  return total_accuracy, total_f1, total_koppa, total_miou

# Testing UNet

In [None]:
# Evaluate the UNet model and print one metric per line.
accuracy, f1, koppa, miou = test(unet, test_dataloader)
for label, value in (('accuracy:', accuracy), ('f1 score:', f1), ('Koppa', koppa), ('MIOU:', miou)):
  print(label, value)

accuracy: 0.8898595645103925
f1 score: 0.8898787464765322
Koppa 0.8030659069064573
MIOU: 0.7584027191968032


# Testing ENet

In [None]:
# Evaluate the ENet model. The cell previously referenced `model`, a name that
# is not defined anywhere in this notebook, so it failed on a fresh kernel.
accuracy, f1, koppa, miou = test(enet,test_dataloader)
print('accuracy:',accuracy)
print('f1 score:',f1)
print('Koppa',koppa)
print('MIOU:',miou)

accuracy: 0.7861962042914497
f1 score: 0.7908785612883183
Koppa 0.6369162405945048
MIOU: 0.6143507476956827


# Testing simple TransUNet

In [None]:
# Evaluate the TransUNet loaded from the plain training checkpoint.
accuracy, f1, koppa, miou = test(simple_transunet, test_dataloader)
for label, value in (('accuracy:', accuracy), ('f1 score:', f1), ('Koppa', koppa), ('MIOU:', miou)):
  print(label, value)

accuracy: 0.8895369164737654
f1 score: 0.8900488022009336
Koppa 0.8012370315359292
MIOU: 0.7609362656663072


# Testing finetuned TransUNet

In [None]:
# Evaluate the TransUNet loaded from the fine-tuning checkpoint.
accuracy, f1, koppa, miou = test(finetuned_transunet, test_dataloader)
for label, value in (('accuracy:', accuracy), ('f1 score:', f1), ('Koppa', koppa), ('MIOU:', miou)):
  print(label, value)

accuracy: 0.8915312967182677
f1 score: 0.8931183696295653
Koppa 0.8083758782969621
MIOU: 0.7622924716658851


# Testing adversarial TransUNet

In [None]:
# Evaluate the TransUNet loaded from the adversarial (GAN) checkpoint.
accuracy, f1, koppa, miou = test(adversarial_transunet, test_dataloader)
for label, value in (('accuracy:', accuracy), ('f1 score:', f1), ('Koppa', koppa), ('MIOU:', miou)):
  print(label, value)

accuracy: 0.8708632198380835
f1 score: 0.8732680313092307
Koppa 0.7720907281791463
MIOU: 0.7300969604235625


# Testing model trained on RGB channels

In [None]:
# NOTE(review): `model` is not defined anywhere in this notebook, so this cell
# depends on leftover kernel state - confirm which model and dataloader the
# RGB experiment is meant to use.
accuracy, f1, koppa, miou = test(model, test_dataloader)
for label, value in (('accuracy:', accuracy), ('f1 score:', f1), ('Koppa', koppa), ('MIOU:', miou)):
  print(label, value)

accuracy: 0.8922978624885466
f1 score: 0.8919412523166158
Koppa 0.8038286920365878
MIOU: 0.7590650291372357
