In [None]:
from __future__ import print_function, division

import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary
from torchmetrics.classification import Accuracy
from sklearn.metrics import confusion_matrix
import plotly.graph_objs as go

import sys
sys.path.append('../')

from modules.helpers import *
from modules.datasets import *
from modules.train_utils import train_model
from modules.dataloaders import *

In [None]:
import wandb


# Experiment configuration. The same dict is passed to wandb.init(config=...)
# and to train_model / test_model below, so every tunable lives here.
cfg = dict(
    learning_rate=0.01,                 # SGD learning rate
    epochs=15,                          # forwarded to train_model as num_epochs
    pretrained_resnet=True,             # forwarded as `pretrained` to models.resnet18
    img_size=224,                       # image size requested from get_bacteria_dataloaders
    n_classes=2,                        # output width of the replacement fc head
    label_type="antibiotic_resistant",
    balanced_mode=False,
    expand_channels=True,               # presumably yields 3-channel inputs for the ResNet stem — confirm
)


In [None]:
# NOTE(review): absolute cluster path; it is never passed to get_bacteria_dataloaders
# in this notebook, so the loader presumably locates the data on its own — confirm.
data_dir = '/n/holyscratch01/wadduwage_lab/D2NN_QPM_classification/datasets/bacteria/'

# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Short alias for the number of output classes, used by later cells.
n_c = cfg['n_classes']

In [None]:
# Build the train/val/test loaders. The helper also returns `dataset_sizes`,
# which train_model / test_model consume alongside the loaders.
# NOTE(review): the positional 32 and 10 are forwarded unchanged — presumably
# batch size and worker count; confirm against modules.dataloaders.
dataloaders = {}
(
    dataloaders['train'],
    dataloaders['val'],
    dataloaders['test'],
    dataset_sizes,
) = get_bacteria_dataloaders(
    cfg['img_size'], 32, 10,
    label_type=cfg['label_type'],
    balanced_mode=cfg['balanced_mode'],
    expand_channels=cfg['expand_channels'],
)

# By default classes are identified by their integer label index.
class_names = list(range(n_c))

# The 21-class setting has human-readable strain names.
if n_c == 21:
    class_names = [
        'Acinetobacter', 'B subtilis', 'E. coli K12', 'S. aureus',
        'E. coli (CCUG17620)', 'E. coli (NCTC13441)', 'E. coli (A2-39)',
        'K. pneumoniae (A2-23)', 'S. aureus (CCUG35600)',
        'E. coli (101)', 'E. coli (102)', 'E. coli (104)',
        'K. pneumoniae (210)', 'K. pneumoniae (211)', 'K. pneumoniae (212)', 'K. pneumoniae (240)',
        'Acinetobacter K12-21', 'Acinetobacter K48-42', 'Acinetobacter K55-13',
        'Acinetobacter K57-06', 'Acinetobacter K71-71',
    ]

# Fail fast if the name list and the classifier width ever drift apart.
assert len(class_names) == n_c, f"expected {n_c} class names, got {len(class_names)}"

In [None]:
# Store the split sizes in the config so they travel with the W&B run config and train/test helpers.
cfg.update(dataset_sizes=dataset_sizes)

In [None]:
# Unique run name: fixed prefix plus the current Unix timestamp. The " - "
# separator keeps the "18" of the model name from running into the timestamp digits.
exp_name = f"ARP - Resnet 18 - {time.time()}"

# Start the W&B run with `cfg` as its config. The entity can be overridden via the
# WANDB_ENTITY environment variable; it defaults to the original account.
wandb.init(
    project="antibiotic-resistance-prediction",
    name=exp_name,
    config=cfg,
    entity=os.environ.get("WANDB_ENTITY", "ramith"),
)

In [None]:
# ResNet-18 backbone; pretrained weights are requested when the run config says so.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
# left as-is because the installed torchvision version is not visible here.
model_ft = models.resnet18(pretrained=wandb.config['pretrained_resnet'])

# Swap the final fully-connected layer for an n_c-way classifier with the same input width.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=n_c)

# Move the parameters to the selected compute device.
model_ft = model_ft.to(device)

In [None]:
# Layer-by-layer summary for a single input. The spatial size follows cfg['img_size']
# instead of a hard-coded 224 so it stays in sync with the dataloaders. The 3 channels
# assume cfg['expand_channels'] yields 3-channel images — TODO confirm for other settings.
summary(model_ft, (3, cfg['img_size'], cfg['img_size']), device=device.type)

In [None]:
# Ensure the model is on the compute device (keeps this cell safe to re-run on its own).
model_ft = model_ft.to(device)

# Standard multi-class loss over raw logits.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over every parameter: backbone and new head are both trained.
optimizer_ft = optim.SGD(
    params=model_ft.parameters(),
    lr=cfg['learning_rate'],
    momentum=0.9,
)

# Multiply the LR by 0.1 once every 7 scheduler steps
# (presumably one step per epoch inside train_model — confirm).
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer_ft,
    step_size=7,
    gamma=0.1,
)

In [None]:
# Register the model with the active W&B run. No options are passed, so what
# gets logged, and how often, is left to wandb's defaults.
wandb.watch(model_ft)

In [None]:
 model_ft = train_model(model_ft, [dataloaders, dataset_sizes, class_names] , criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs =  cfg['epochs'], n_classes = cfg['n_classes'] , device = device, exp_name = exp_name, cfg = cfg)

In [None]:
from modules.test_utils import test_model

In [None]:
# Evaluate the trained model with test_model; the result is kept as `conf`
# (presumably a confusion matrix — confirm against modules.test_utils).
conf = test_model(
    model_ft,
    [dataloaders, dataset_sizes, class_names],
    criterion,
    n_classes=cfg['n_classes'],
    device=device,
    cfg=cfg,
)