In [1]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary
from torchmetrics.classification import Accuracy
from sklearn.metrics import confusion_matrix
import plotly.graph_objs as go

import sys
sys.path.append('../')

from modules.helpers import *
from modules.datasets import *
from modules.train_utils import train_model
from modules.dataloaders import *
from modules.test_utils import test_model

### Resume the wandb run and load the saved checkpoint

In [2]:
import wandb

# Re-attach to the wandb run that recorded training. The test results below
# are presumably logged by test_model to this same run (see the
# "send info to wandb experiment" section).
# resume="must": per the wandb docs, init fails if this run id does not
# already exist, instead of silently starting a new run.
WANDB_PROJECT = "bacteria-classification-whole-dataset"
WANDB_RUN_ID = "183l05jm"
wandb.init(project=WANDB_PROJECT, id=WANDB_RUN_ID, resume="must")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mramith[0m (use `wandb login --relogin` to force relogin)


In [3]:
# Checkpoint written by the training run. It holds the model weights under
# 'state_dict' (loaded further down) and the training config under 'cfg'.
# NOTE(review): this is the epoch-7 checkpoint, while cfg['epochs'] is 15;
# confirm this is the model you intend to evaluate.
path = "/n/home12/ramith/FYP/bacteria-classification/results/Strain Classification - Resnet 181645133846.9762254/latest_model_epoch-7.pth"

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a node
# without CUDA. load_state_dict later copies the weights onto whatever device
# the model has been moved to, so nothing else needs to change.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
saved = torch.load(path, map_location="cpu")
cfg = saved['cfg']

In [4]:
# Show the training configuration that was stored inside the checkpoint.
cfg

{'learning_rate': 0.01,
 'epochs': 15,
 'pretrained_resnet': True,
 'img_size': 224,
 'n_classes': 21,
 'label_type': 'class',
 'balanced_mode': False,
 'expand_channels': True,
 'dataset_sizes': {'train': 375424, 'val': 48544, 'test': 48768}}

### Initialize Dataloaders

In [5]:
# NOTE(review): data_dir is not passed to get_bacteria_dataloaders or used by
# any other visible cell; confirm it is still needed.
data_dir = '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria/'

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Number of output classes the checkpointed model was trained with.
n_c = cfg['n_classes']

In [6]:
# Build the train/val/test loaders with the preprocessing options the model
# was trained with (all read from the checkpoint's cfg).
# NOTE(review): the positional 32 and 10 are presumably batch size and worker
# count; confirm against modules.dataloaders.get_bacteria_dataloaders.
train_loader, val_loader, test_loader, dataset_sizes = get_bacteria_dataloaders(
    cfg['img_size'], 32, 10,
    label_type=cfg['label_type'],
    balanced_mode=cfg['balanced_mode'],
    expand_channels=cfg['expand_channels'],
)
dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

# Human-readable labels for the two known label schemes. Any other class
# count falls back to the plain integer class indices.
# NOTE(review): the 21-class list spells 'B subtilis' while the 5-class list
# uses 'B. subtilis'; left unchanged so labels match previously logged runs.
KNOWN_CLASS_NAMES = {
    21: ['Acinetobacter', 'B subtilis', 'E. coli K12', 'S. aureus',
         'E. coli (CCUG17620)', 'E. coli (NCTC13441)', 'E. coli (A2-39)',
         'K. pneumoniae (A2-23)', 'S. aureus (CCUG35600)',
         'E. coli (101)', 'E. coli (102)', 'E. coli (104)',
         'K. pneumoniae (210)', 'K. pneumoniae (211)', 'K. pneumoniae (212)', 'K. pneumoniae (240)',
         'Acinetobacter K12-21', 'Acinetobacter K48-42', 'Acinetobacter K55-13',
         'Acinetobacter K57-06', 'Acinetobacter K71-71'],
    5: ['Acinetobacter', 'B. subtilis', 'E. coli', 'K. pneumoniae', 'S. aureus'],
}
class_names = KNOWN_CLASS_NAMES.get(n_c, list(range(n_c)))

Dataset type train label type: class -> Loaded 375443 images
Dataset type val label type: class -> Loaded 48562 images
Dataset type test label type: class -> Loaded 48790 images


### Initialize the model architecture (ResNet-18)

In [7]:
# ResNet-18 backbone with its final fully-connected layer resized to n_c outputs.
# Pretrained ImageNet weights are deliberately NOT requested. Every parameter
# and buffer is replaced by the strict load_state_dict(saved['state_dict'])
# further down, so downloading them would only add a network dependency,
# startup time and (on newer torchvision) a `pretrained=` deprecation warning.
# cfg['pretrained_resnet'] still records how the model was initialised for training.
model_ft = models.resnet18()

num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n_c)

model_ft = model_ft.to(device)

In [8]:
# Multi-class cross-entropy averaged over the batch ('mean' is the default);
# passed to test_model, presumably to compute the reported test loss.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [9]:
# This cell previously held a leftover debug `print` with no analytical
# purpose. It was removed, together with its stale output, so the notebook
# reads cleanly top to bottom.


In [10]:
# `model` is simply a second name for the same network object as `model_ft`.
# Restore the trained weights (strict=True, the default: every key must
# match) and switch to inference mode; train(False) is exactly what eval()
# calls, so BatchNorm layers use their stored running statistics.
# Trailing semicolons keep the return values out of the cell output.
model = model_ft
model.load_state_dict(saved['state_dict'], strict=True);
model.train(False);

### Test the model and log the results to the wandb experiment

In [11]:
# Evaluate on the test split. test_model takes the loaders, split sizes and
# class labels bundled together in a single list. The call is the last
# expression, so its return value is displayed as the cell output.
eval_bundle = [dataloaders, dataset_sizes, class_names]
test_model(
    model_ft,
    eval_bundle,
    criterion,
    n_classes=cfg['n_classes'],
    device=device,
    cfg=cfg,
)

starting testing..
.............................................................................................................................................................................................................................................................................................................................................................................................test Loss: 1.1752 Acc: 0.6354
[0.7160256505012512, 0.69017094373703, 0.7737094759941101, 0.7289623022079468, 0.7901670932769775, 0.4383954107761383, 0.27698373794555664, 0.610784649848938, 0.8562532663345337, 0.5271790027618408, 0.3209232985973358, 0.5983606576919556, 0.7607696652412415, 0.6983705759048462, 0.6498464345932007, 0.6790123581886292, 0.7816718816757202, 0.5928603410720825, 0.29914164543151855, 0.5554474592208862, 0.5869402885437012]


testing complete in 8m 12s
Test Acc: 0.635355


[[1117,
  88,
  11,
  2,
  14,
  32,
  22,
  3,
  0,
  15,
  26,
  3,
  44,
  7,
  12,
  54,
  28,
  9,
  26,
  12,
  19],
 [104,
  969,
  54,
  8,
  20,
  4,
  13,
  21,
  0,
  29,
  38,
  36,
  5,
  28,
  33,
  24,
  46,
  4,
  19,
  5,
  6],
 [7, 33, 1289, 1, 0, 0, 0, 42, 0, 2, 12, 15, 2, 98, 52, 5, 19, 0, 0, 2, 1],
 [4, 4, 0, 1412, 4, 53, 87, 0, 274, 2, 4, 1, 55, 0, 0, 56, 0, 58, 62, 9, 27],
 [53,
  20,
  4,
  0,
  2459,
  62,
  97,
  17,
  0,
  223,
  318,
  4,
  1,
  1,
  0,
  72,
  6,
  6,
  49,
  6,
  23],
 [20, 4, 1, 30, 46, 612, 489, 0, 1, 9, 26, 0, 3, 1, 0, 73, 0, 46, 225, 6, 51],
 [43,
  12,
  0,
  17,
  89,
  202,
  562,
  2,
  0,
  14,
  34,
  1,
  0,
  1,
  0,
  98,
  7,
  52,
  792,
  35,
  28],
 [3,
  26,
  51,
  2,
  11,
  1,
  1,
  1767,
  0,
  148,
  44,
  91,
  68,
  226,
  81,
  58,
  197,
  5,
  5,
  31,
  31],
 [0, 0, 0, 209, 1, 0, 10, 0, 1650, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 [24,
  16,
  3,
  1,
  164,
  6,
  52,
  96,
  0,
  1125,
  165,
  37,
  5,
  15,