In [1]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary
from torchmetrics.classification import Accuracy
from sklearn.metrics import confusion_matrix
import plotly.graph_objs as go

import sys
sys.path.append('../')

from modules.helpers import *
from modules.datasets import *
from modules.train_utils import train_model
from modules.dataloaders import *
from modules.test_utils import test_model

### Initialize wandb and load model

In [2]:
path = "/n/home12/ramith/FYP/bacteria-classification/results/GramStrain - Resnet 181645133407.8305595/latest_model_epoch-7.pth"
saved = torch.load(path)
cfg   = saved['cfg']

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

n_c = cfg['n_classes']

class_names = [x for x in range(0, n_c)]

if(n_c == 21):
    class_names = ['Acinetobacter','B subtilis','E. coli K12','S. aureus','E. coli (CCUG17620)','E. coli (NCTC13441)','E. coli (A2-39)','K. pneumoniae (A2-23)','S. aureus (CCUG35600)','E. coli (101)','E. coli (102)','E. coli (104)','K. pneumoniae (210)','K. pneumoniae (211)','K. pneumoniae (212)','K. pneumoniae (240)','Acinetobacter K12-21','Acinetobacter K48-42','Acinetobacter K55-13','Acinetobacter K57-06','Acinetobacter K71-71']
elif(n_c == 5):
    class_names = ['Acinetobacter', 'B. subtilis', 'E. coli', 'K. pneumoniae', 'S. aureus']


In [4]:
cfg

{'learning_rate': 0.01,
 'epochs': 15,
 'pretrained_resnet': True,
 'img_size': 224,
 'n_classes': 2,
 'label_type': 'gram_strain',
 'balanced_mode': False,
 'expand_channels': True,
 'dataset_sizes': {'train': 375424, 'val': 48544, 'test': 48768}}

In [5]:
model_ft = models.resnet18(pretrained=cfg['pretrained_resnet'])

num_ftrs = model_ft.fc.in_features

model_ft.fc = nn.Linear(num_ftrs, n_c)

model_ft = model_ft.to(device)

In [6]:
model = model_ft
model.load_state_dict(saved['state_dict']);
model.eval();

### Initialize Class Specific Dataloaders

In [7]:
data_dir = '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria/'

In [48]:
dataloaders = {}

N = 63
_, _, _, dataset_sizes =  get_bacteria_dataloaders(cfg['img_size'], N , 10, label_type = cfg['label_type'], balanced_mode = False, expand_channels = cfg['expand_channels'])


for i in range(0, cfg['n_classes']):
    print("=====")
    dataloaders[str(i)], _ =  get_bacteria_eval_dataloaders(cfg['img_size'], N , 10, label_type = cfg['label_type'] ,expand_channels = cfg['expand_channels'], isolate_class = i)


Dataset type train label type: gram_strain -> Loaded 375443 images
Dataset type val label type: gram_strain -> Loaded 48562 images
Dataset type test label type: gram_strain -> Loaded 48790 images
=====
Dataset type test label type: gram_strain -> All files = 48790
Loaded 43520 images only from class 0
=====
Dataset type test label type: gram_strain -> All files = 48790
Loaded 5270 images only from class 1


In [49]:
dataset_sizes['test'] = dataset_sizes['test']/N

In [50]:
# for inputs, labels in dataloaders["1"]:
#     print(".", end = '')

#     inputs = inputs.to(device,dtype=torch.float)
#     labels = labels.to(device)

#     outputs = model(inputs)
#     _, preds = torch.max(outputs, 1)
    
#     print(preds)
#     print(torch.mode(preds,0)[0])
#     print(torch.mode(labels,0)[0])
    
#     break

In [51]:

criterion = nn.CrossEntropyLoss()

In [52]:
dataset_sizes

{'train': 375417, 'val': 48544, 'test': 774.0952380952381}

In [53]:
from modules.test_utils import test_model_in_groups

In [54]:
test_model_in_groups(model_ft, [dataloaders, dataset_sizes, class_names] , criterion, n_classes = cfg['n_classes'] , device = device, cfg = cfg)

starting group testing..
New class eval - 0
New class eval - 1
test Loss: 6.3536 Acc: 0.9986
[1.0, 1.0]

testing complete in 3m 15s
Test Acc: 0.998585


[[690, 0], [0, 83]]

In [55]:
# for inputs, labels in dataloaders[str(0)]: #take a batch of data from each class
#     print(".", end = '')

#     inputs = inputs.to(device,dtype=torch.float)
#     labels = labels.to(device)

#     outputs = model(inputs)
#     _, preds = torch.max(outputs, 1)
#     loss = criterion(outputs, labels)

#     preds = torch.mode(preds, 0)[0]
#     preds = torch.reshape(preds, (-1,))
    
#     labels = torch.mode(labels, 0)[0]
#     labels = torch.reshape(labels, (-1,))
#     print(preds)
#     print(labels)
#     break

In [56]:
# g = torch.Tensor(1)