In [1]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary
from torchmetrics.classification import Accuracy
from sklearn.metrics import confusion_matrix
import plotly.graph_objs as go

import sys
sys.path.append('../')

from modules.helpers import *
from modules.datasets import *
from modules.train_utils import train_model
from modules.dataloaders import *
from modules.test_utils import test_model

### Initialize wandb and load model

In [2]:
path = "/n/home12/ramith/FYP/bacteria-classification/results/Strain Classification - Resnet 181645133846.9762254/latest_model_epoch-7.pth"
saved = torch.load(path)
cfg   = saved['cfg']

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

n_c = cfg['n_classes']

class_names = [x for x in range(0, n_c)]

if(n_c == 21):
    class_names = ['Acinetobacter','B subtilis','E. coli K12','S. aureus','E. coli (CCUG17620)','E. coli (NCTC13441)','E. coli (A2-39)','K. pneumoniae (A2-23)','S. aureus (CCUG35600)','E. coli (101)','E. coli (102)','E. coli (104)','K. pneumoniae (210)','K. pneumoniae (211)','K. pneumoniae (212)','K. pneumoniae (240)','Acinetobacter K12-21','Acinetobacter K48-42','Acinetobacter K55-13','Acinetobacter K57-06','Acinetobacter K71-71']
elif(n_c == 5):
    class_names = ['Acinetobacter', 'B. subtilis', 'E. coli', 'K. pneumoniae', 'S. aureus']


In [4]:
cfg

{'learning_rate': 0.01,
 'epochs': 15,
 'pretrained_resnet': True,
 'img_size': 224,
 'n_classes': 21,
 'label_type': 'class',
 'balanced_mode': False,
 'expand_channels': True,
 'dataset_sizes': {'train': 375424, 'val': 48544, 'test': 48768}}

In [5]:
model_ft = models.resnet18(pretrained=cfg['pretrained_resnet'])

num_ftrs = model_ft.fc.in_features

model_ft.fc = nn.Linear(num_ftrs, n_c)

model_ft = model_ft.to(device)

In [6]:
model = model_ft
model.load_state_dict(saved['state_dict']);
model.eval();

### Initialize Class Specific Dataloaders

In [7]:
data_dir = '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria/'

In [8]:
dataloaders = {}

N = 31
_, _, _, dataset_sizes =  get_bacteria_dataloaders(cfg['img_size'], N , 10, label_type = cfg['label_type'], balanced_mode = False, expand_channels = cfg['expand_channels'])


for i in range(0, cfg['n_classes']):
    print("=====")
    dataloaders[str(i)], _ =  get_bacteria_eval_dataloaders(cfg['img_size'], N , 10, label_type = cfg['label_type'] ,expand_channels = cfg['expand_channels'], isolate_class = i)


Dataset type train label type: class -> Loaded 375443 images
Dataset type val label type: class -> Loaded 48562 images
Dataset type test label type: class -> Loaded 48790 images
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 1561 images only from class 0
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 1404 images only from class 1
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 1668 images only from class 2
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 1937 images only from class 3
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 3114 images only from class 4
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 1396 images only from class 5
=====
Dataset type test; dataloder will have label type: class -> All files = 48790
Loaded 2030 

In [9]:
dataset_sizes['test'] = dataset_sizes['test']/N

In [10]:
# for inputs, labels in dataloaders["1"]:
#     print(".", end = '')

#     inputs = inputs.to(device,dtype=torch.float)
#     labels = labels.to(device)

#     outputs = model(inputs)
#     _, preds = torch.max(outputs, 1)
    
#     print(preds)
#     print(torch.mode(preds,0)[0])
#     print(torch.mode(labels,0)[0])
    
#     break

In [11]:

criterion = nn.CrossEntropyLoss()

In [12]:
dataset_sizes

{'train': 375441, 'val': 48544, 'test': 1573.1612903225807}

In [13]:
from modules.test_utils import test_model_in_groups

In [14]:
test_model_in_groups(model_ft, [dataloaders, dataset_sizes, class_names] , criterion, n_classes = cfg['n_classes'] , device = device, cfg = cfg)

starting group testing..
New class eval - 0
New class eval - 1
New class eval - 2
New class eval - 3
New class eval - 4
New class eval - 5
New class eval - 6
New class eval - 7
New class eval - 8
New class eval - 9
New class eval - 10
New class eval - 11
New class eval - 12
New class eval - 13
New class eval - 14
New class eval - 15
New class eval - 16
New class eval - 17
New class eval - 18
New class eval - 19
New class eval - 20
test Loss: 36.2081 Acc: 0.9351
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5076923370361328, 1.0, 1.0, 1.0, 0.7209302186965942, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.3333333432674408, 1.0, 1.0]

testing complete in 12m 49s
Test Acc: 0.935060


[[50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 53, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 62, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 100, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 45, 29, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 50, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 93, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 68, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 112, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 85, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 73, 0, 0, 0, 

In [15]:
# for inputs, labels in dataloaders[str(0)]: #take a batch of data from each class
#     print(".", end = '')

#     inputs = inputs.to(device,dtype=torch.float)
#     labels = labels.to(device)

#     outputs = model(inputs)
#     _, preds = torch.max(outputs, 1)
#     loss = criterion(outputs, labels)

#     preds = torch.mode(preds, 0)[0]
#     preds = torch.reshape(preds, (-1,))
    
#     labels = torch.mode(labels, 0)[0]
#     labels = torch.reshape(labels, (-1,))
#     print(preds)
#     print(labels)
#     break

In [33]:
# g = torch.Tensor(1)