In [1]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary
from torchmetrics.classification import Accuracy
from sklearn.metrics import confusion_matrix
import plotly.graph_objs as go

import sys
sys.path.append('../')

from modules.helpers import *
from modules.datasets import *
from modules.train_utils import train_model
from modules.dataloaders import *

In [2]:
import wandb


# Run configuration: passed to wandb.init as the run config and read by the
# data, model, optimizer and training cells below.
cfg = dict(
    learning_rate=0.01,        # SGD learning rate
    epochs=15,                 # forwarded to train_model as num_epochs
    pretrained_resnet=True,    # passed as `pretrained` when building resnet18
    img_size=224,              # first argument to get_bacteria_dataloaders
    n_classes=2,               # width of the replacement fc layer
    label_type="gram_strain",  # forwarded to get_bacteria_dataloaders
    balanced_mode=False,       # forwarded to get_bacteria_dataloaders (semantics defined there)
    expand_channels=True,      # forwarded to get_bacteria_dataloaders (semantics defined there)
)


In [3]:
# Dataset root. Defaults to the lab cluster scratch path; set the
# BACTERIA_DATA_DIR environment variable to run elsewhere without editing the notebook.
# NOTE(review): data_dir is not passed to get_bacteria_dataloaders in this notebook —
# confirm whether the loader resolves its own path.
data_dir = os.environ.get('BACTERIA_DATA_DIR', '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria/')
# First GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Number of output classes for the classifier head.
n_c = cfg['n_classes']

In [4]:
# Build train/val/test loaders plus the per-split sizes.
# NOTE(review): the positional 32 and 10 are presumably batch size and worker
# count — confirm against get_bacteria_dataloaders' signature.
dataloaders = {}
(dataloaders['train'],
 dataloaders['val'],
 dataloaders['test'],
 dataset_sizes) = get_bacteria_dataloaders(
    cfg['img_size'], 32, 10,
    label_type=cfg['label_type'],
    balanced_mode=cfg['balanced_mode'],
    expand_channels=cfg['expand_channels'],
)

# Human-readable names exist only for the 21-strain setup; otherwise classes
# are labelled by their integer index.
if n_c == 21:
    class_names = [
        'Acinetobacter', 'B subtilis', 'E. coli K12', 'S. aureus',
        'E. coli (CCUG17620)', 'E. coli (NCTC13441)', 'E. coli (A2-39)',
        'K. pneumoniae (A2-23)', 'S. aureus (CCUG35600)',
        'E. coli (101)', 'E. coli (102)', 'E. coli (104)',
        'K. pneumoniae (210)', 'K. pneumoniae (211)', 'K. pneumoniae (212)', 'K. pneumoniae (240)',
        'Acinetobacter K12-21', 'Acinetobacter K48-42', 'Acinetobacter K55-13',
        'Acinetobacter K57-06', 'Acinetobacter K71-71',
    ]
else:
    class_names = list(range(n_c))

Dataset type train label type: gram_strain -> Loaded 375443 images
Dataset type val label type: gram_strain -> Loaded 48562 images
Dataset type test label type: gram_strain -> Loaded 48790 images


In [5]:
# Record the split sizes in the run config so wandb.init logs them with the other settings.
cfg.update(dataset_sizes=dataset_sizes)

In [6]:
# Run name = model description + wall-clock timestamp so repeated runs stay distinct.
# The " - " separator keeps "Resnet 18" from fusing with the timestamp digits
# (previously produced names like "Resnet 181699...").
exp_name = "GramStrain - Resnet 18 - " + str(time.time())
# Start the wandb run; cfg becomes the run config that later cells read via wandb.config.
wandb.init(project="bacteria-classification-gram-strain", name=exp_name, config=cfg, entity="ramith")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mramith[0m (use `wandb login --relogin` to force relogin)


In [7]:
# ResNet-18 backbone; whether torchvision's pretrained weights are loaded comes from the run config.
use_pretrained = wandb.config['pretrained_resnet']
model_ft = models.resnet18(pretrained=use_pretrained)

# Replace the classifier head with an n_c-way linear layer of the same input width.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n_c)

# Place the whole network on the selected device.
model_ft = model_ft.to(device)

In [8]:
# Layer-by-layer output-shape / parameter-count table for one (3, img_size, img_size) input.
# torchsummary defaults to device="cuda", which fails on the CPU fallback chosen above,
# so pass the device type actually in use; the spatial size follows cfg instead of a literal 224.
summary(model_ft, (3, cfg['img_size'], cfg['img_size']), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [9]:
# Re-asserts device placement (a no-op when the model cell already ran).
model_ft = model_ft.to(device)

# Cross-entropy over the n_c logits produced by model_ft.fc.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over every parameter — nothing is frozen.
optimizer_ft = optim.SGD(
    model_ft.parameters(),
    momentum=0.9,
    lr=cfg['learning_rate'],
)

# Multiply the learning rate by gamma=0.1 every step_size=7 scheduler steps
# (presumably one step per epoch inside train_model — TODO confirm).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, gamma=0.1, step_size=7)

In [10]:
# Register model_ft with the active wandb run so it can be tracked during training;
# what gets logged follows wandb.watch's default arguments.
wandb.watch(model_ft)

[]

In [None]:
 model_ft = train_model(model_ft, [dataloaders, dataset_sizes, class_names] , criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs =  cfg['epochs'], n_classes = cfg['n_classes'] , device = device, exp_name = exp_name, cfg = cfg)

Epoch 0/14
----------
train Loss: 0.1611 Acc: 0.9516
val Loss: 0.1328 Acc: 0.9592
[0.9949410557746887, 0.6086306571960449]

[0.9954017996788025, 0.661982536315918]

Epoch 1/14
----------
train Loss: 0.1325 Acc: 0.9558
val Loss: 0.1246 Acc: 0.9596
[0.9934828281402588, 0.6574302315711975]

[0.9928371906280518, 0.6862298250198364]

Epoch 2/14
----------
train Loss: 0.1211 Acc: 0.9585
val Loss: 0.1206 Acc: 0.9618
[0.9915174841880798, 0.6975088715553284]

[0.9913809299468994, 0.7192482948303223]

Epoch 3/14
----------
train Loss: 0.1138 Acc: 0.9608
val Loss: 0.1053 Acc: 0.9641
[0.9909053444862366, 0.722402811050415]

[0.9939922094345093, 0.7182456851005554]

Epoch 4/14
----------
train Loss: 0.1074 Acc: 0.9627
val Loss: 0.1041 Acc: 0.9650
[0.9909173846244812, 0.7399288415908813]

[0.9906650185585022, 0.7537029981613159]

Epoch 5/14
----------
train Loss: 0.1019 Acc: 0.9644
val Loss: 0.1150 Acc: 0.9611
[0.9907792806625366, 0.7556462287902832]

[0.9871060252189636, 0.7477220892906189]

Epoch 

In [None]:
from modules.test_utils import test_model

In [None]:
# Evaluate the trained model with the same data bundle and loss given to train_model.
# NOTE(review): `conf` is presumably a confusion matrix — confirm against modules.test_utils.
conf = test_model(
    model_ft,
    [dataloaders, dataset_sizes, class_names],
    criterion,
    device=device,
    n_classes=cfg['n_classes'],
    cfg=cfg,
)