In [1]:
from CT_COVID_Preprocess import CTCOVIDDataset
from model_tools import validation, train, load_model_checkpoint
from eval_tools import VGG19_CAM
from eval_tools import plot_train_val, test, display_FP_FN, get_confusion_matrix, plot_ROCAUC_curve, showCAM
from custom_model import CNN_COVIDCT
import torch
import torchvision
import torch.optim as optim
from monai.transforms import Compose, LoadPNG, AddChannel, ScaleIntensity, ToTensor, RandRotate, RandFlip, RandZoom, Resize, RandGaussianNoise
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from collections import OrderedDict
from torchvision import transforms

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
dirCOVID = 'C:/Users/rrsoo/AdvancedStuff/Medical_AI/Classifier/Images-processed-new/CT_COVID'
dirNonCOVID = 'C:/Users/rrsoo/AdvancedStuff/Medical_AI/Classifier/Images-processed-new/CT_NonCOVID'

In [4]:
train_transforms = transforms.Compose([
        LoadPNG(),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(degrees=15, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        RandGaussianNoise(prob = 0.5),
        Resize(spatial_size=(224, 224)),
        ToTensor()
    ])
   

val_transforms = transforms.Compose([
    LoadPNG(),
    AddChannel(),
    ScaleIntensity(),
    ToTensor()
])

In [5]:
BATCH_SIZE = 32
ORIG_RES = False

train_ds = CTCOVIDDataset(dirCOVID, dirNonCOVID, transforms = train_transforms, data = 'train', orig_res = ORIG_RES)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

val_ds = CTCOVIDDataset(dirCOVID, dirNonCOVID, transforms = val_transforms, data = 'val', orig_res = ORIG_RES)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)

test_ds = CTCOVIDDataset(dirCOVID, dirNonCOVID, transforms = val_transforms, data = 'test', orig_res = ORIG_RES)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True)

100%|██████████| 397/397 [00:00<00:00, 476.13it/s]
100%|██████████| 346/346 [00:01<00:00, 328.84it/s]
100%|██████████| 397/397 [00:00<00:00, 477.27it/s]
100%|██████████| 346/346 [00:01<00:00, 328.84it/s]
100%|██████████| 397/397 [00:00<00:00, 463.38it/s]
100%|██████████| 346/346 [00:01<00:00, 324.84it/s]


In [6]:
model_name = 'resnet34' #changeable
PRETRAINED = True #changeable

model = None #Don't touch

if model_name == 'resnet34':
    model = torchvision.models.resnet34(pretrained = PRETRAINED).to(device)
    
    if PRETRAINED:
        for param in model.parameters():
            param.requires_grad = False
            
    layers_resnet = nn.Sequential(OrderedDict([
                ('fc1', nn.Linear(512, 256)),
                ('activation1', nn.ReLU()),
                ('fc2', nn.Linear(256, 128)),
                ('activation2', nn.ReLU()),
                ('fc3', nn.Linear(128, 2)),
                ('out', nn.Sigmoid())
            ])).to(device)
    
    model.fc = layers_resnet

assert(model is not None)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
epochs = 40
optimizer = optim.Adam(model.parameters())
#scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-3, max_lr=1e-2, cycle_momentum=False)
criterion = nn.CrossEntropyLoss()
history = train(model, train_loader, val_loader, optimizer, criterion, epochs, device = device, scheduler = None)

100%|██████████| 19/19 [00:03<00:00,  4.97it/s]
100%|██████████| 3/3 [00:00<00:00,  8.90it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.57it/s]

Training Loss: 13.034444749355316
Training Accuracy: 0.5260504201680672
Validation Loss: 2.0720869302749634
Validation Accuracy: 0.527027027027027


100%|██████████| 19/19 [00:02<00:00,  6.67it/s]
100%|██████████| 3/3 [00:00<00:00,  8.87it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 12.409207582473755
Training Accuracy: 0.6033613445378151
Validation Loss: 1.969240963459015
Validation Accuracy: 0.6081081081081081


100%|██████████| 19/19 [00:02<00:00,  6.65it/s]
100%|██████████| 3/3 [00:00<00:00,  8.87it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 11.463574707508087
Training Accuracy: 0.719327731092437
Validation Loss: 1.7630200386047363
Validation Accuracy: 0.7297297297297297


100%|██████████| 19/19 [00:02<00:00,  6.64it/s]
100%|██████████| 3/3 [00:00<00:00,  8.85it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 10.535346299409866
Training Accuracy: 0.7563025210084033
Validation Loss: 1.5425951182842255
Validation Accuracy: 0.7837837837837838


100%|██████████| 19/19 [00:02<00:00,  6.63it/s]
100%|██████████| 3/3 [00:00<00:00,  8.74it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 10.198698043823242
Training Accuracy: 0.7663865546218488
Validation Loss: 1.62639981508255
Validation Accuracy: 0.7162162162162162


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 9.612675935029984
Training Accuracy: 0.8134453781512605
Validation Loss: 1.4094567596912384
Validation Accuracy: 0.8648648648648649


100%|██████████| 19/19 [00:02<00:00,  6.64it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 9.894443899393082
Training Accuracy: 0.7697478991596639
Validation Loss: 1.5214513540267944
Validation Accuracy: 0.7837837837837838


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 9.713624984025955
Training Accuracy: 0.7781512605042017
Validation Loss: 1.466100662946701
Validation Accuracy: 0.8378378378378378


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.80it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 9.375973969697952
Training Accuracy: 0.8134453781512605
Validation Loss: 1.380761295557022
Validation Accuracy: 0.8378378378378378


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.72it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 9.384544312953949
Training Accuracy: 0.8033613445378152
Validation Loss: 1.4797623753547668
Validation Accuracy: 0.8243243243243243


100%|██████████| 19/19 [00:02<00:00,  6.63it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 9.331850230693817
Training Accuracy: 0.8151260504201681
Validation Loss: 1.3515446186065674
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.87it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 9.394055277109146
Training Accuracy: 0.8184873949579832
Validation Loss: 1.4652563631534576
Validation Accuracy: 0.8783783783783784


100%|██████████| 19/19 [00:02<00:00,  6.64it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 9.340785533189774
Training Accuracy: 0.8100840336134454
Validation Loss: 1.3532218039035797
Validation Accuracy: 0.8783783783783784


100%|██████████| 19/19 [00:02<00:00,  6.63it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.43it/s]

Training Loss: 8.748438984155655
Training Accuracy: 0.8504201680672269
Validation Loss: 1.2937892973423004
Validation Accuracy: 0.8513513513513513


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.64it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 8.48038199543953
Training Accuracy: 0.865546218487395
Validation Loss: 1.31028813123703
Validation Accuracy: 0.8783783783783784


100%|██████████| 19/19 [00:02<00:00,  6.63it/s]
100%|██████████| 3/3 [00:00<00:00,  8.74it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 8.779923766851425
Training Accuracy: 0.853781512605042
Validation Loss: 1.529621034860611
Validation Accuracy: 0.7972972972972973


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.59it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 8.730936884880066
Training Accuracy: 0.8571428571428571
Validation Loss: 1.2050215303897858
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 8.798154294490814
Training Accuracy: 0.8487394957983193
Validation Loss: 1.3831575810909271
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.52it/s]

Training Loss: 8.511326730251312
Training Accuracy: 0.8621848739495799
Validation Loss: 1.575732409954071
Validation Accuracy: 0.7837837837837838


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 8.308182924985886
Training Accuracy: 0.8840336134453781
Validation Loss: 1.1732455492019653
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 8.353963136672974
Training Accuracy: 0.8722689075630252
Validation Loss: 1.2199146747589111
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.74it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 8.472912341356277
Training Accuracy: 0.8672268907563025
Validation Loss: 1.2155371308326721
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.80it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 8.181353509426117
Training Accuracy: 0.8857142857142857
Validation Loss: 1.2429781556129456
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.85it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.47it/s]

Training Loss: 7.961827427148819
Training Accuracy: 0.8907563025210085
Validation Loss: 1.2498685121536255
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.80it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.47it/s]

Training Loss: 8.100045323371887
Training Accuracy: 0.8873949579831932
Validation Loss: 1.2956326603889465
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.74it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 7.535557955503464
Training Accuracy: 0.9193277310924369
Validation Loss: 1.2510221898555756
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.62it/s]
100%|██████████| 3/3 [00:00<00:00,  8.72it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 7.9076297879219055
Training Accuracy: 0.8991596638655462
Validation Loss: 1.5566852986812592
Validation Accuracy: 0.8243243243243243


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 7.6365067064762115
Training Accuracy: 0.9176470588235294
Validation Loss: 1.3157033622264862
Validation Accuracy: 0.8918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.79it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 8.152100771665573
Training Accuracy: 0.880672268907563
Validation Loss: 1.2956292629241943
Validation Accuracy: 0.8783783783783784


100%|██████████| 19/19 [00:02<00:00,  6.58it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 8.989959001541138
Training Accuracy: 0.8252100840336134
Validation Loss: 1.2277001440525055
Validation Accuracy: 0.918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.85it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 7.777432203292847
Training Accuracy: 0.9025210084033614
Validation Loss: 1.1589094996452332
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 7.524742692708969
Training Accuracy: 0.9176470588235294
Validation Loss: 1.2167813777923584
Validation Accuracy: 0.918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.58it/s]
100%|██████████| 3/3 [00:00<00:00,  8.72it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.43it/s]

Training Loss: 7.4331523180007935
Training Accuracy: 0.9226890756302522
Validation Loss: 1.2893367111682892
Validation Accuracy: 0.918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.39it/s]

Training Loss: 7.486196905374527
Training Accuracy: 0.9176470588235294
Validation Loss: 1.219035416841507
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.47it/s]

Training Loss: 7.867693454027176
Training Accuracy: 0.9025210084033614
Validation Loss: 1.3193175792694092
Validation Accuracy: 0.8783783783783784


100%|██████████| 19/19 [00:02<00:00,  6.61it/s]
100%|██████████| 3/3 [00:00<00:00,  8.75it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.60it/s]

Training Loss: 7.424097418785095
Training Accuracy: 0.9260504201680673
Validation Loss: 1.3125568628311157
Validation Accuracy: 0.9054054054054054


100%|██████████| 19/19 [00:02<00:00,  6.60it/s]
100%|██████████| 3/3 [00:00<00:00,  8.67it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.51it/s]

Training Loss: 7.49569171667099
Training Accuracy: 0.9176470588235294
Validation Loss: 1.1858757734298706
Validation Accuracy: 0.918918918918919


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.77it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 7.060862749814987
Training Accuracy: 0.9428571428571428
Validation Loss: 1.2645426988601685
Validation Accuracy: 0.9324324324324325


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]
100%|██████████| 3/3 [00:00<00:00,  8.82it/s]
  5%|▌         | 1/19 [00:00<00:02,  6.55it/s]

Training Loss: 7.155577540397644
Training Accuracy: 0.9378151260504202
Validation Loss: 1.228758990764618
Validation Accuracy: 0.9324324324324325


100%|██████████| 19/19 [00:02<00:00,  6.58it/s]
100%|██████████| 3/3 [00:00<00:00,  8.69it/s]

Training Loss: 7.732251077890396
Training Accuracy: 0.9008403361344538
Validation Loss: 1.224666714668274
Validation Accuracy: 0.918918918918919





In [2]:
plot_train_val(history, epochs, (10, 10))

NameError: name 'plot_train_val' is not defined

In [1]:
model_dir = 'resnetforpaper.pth'

checkpoints = {
    'model' : model,
    'input_shape' : (224, 224, 3),
    'batch_size' : 32,
    'epochs' : 40,
    'train_acc' : 91,
    'val_acc' : 90,
    'state_dict' : model.state_dict()
}

torch.save(checkpoints, model_dir)

NameError: name 'model' is not defined