In [4]:
#!/usr/bin/env python
# coding: utf-8

import cv2
import numpy as np
import os
import sys
import pandas as pd
from random import shuffle
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, datasets, models
from PIL import Image
import nibabel as nib
import time
import copy
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def create_paths(datapath):
    #     Create paths to all nested images
    imagepaths = []
    for root, dirs, files in os.walk(datapath, topdown=False):
        for name in files:
            imagepaths.append(os.path.join(root, name))
    return imagepaths

def get_label(imagepath, csvpath):
    #     Get the diagnosis label for path
    table = pd.read_csv(csvpath)
    idpath = imagepath.split('/')[13]
    img_id = idpath[idpath.find('_I') + 2:-4]
    group = table.loc[table['Image Data ID'] == int(
        img_id)]["Group"].reset_index(drop=True)[0]
    group_to_label = {'CN': 0, 'MCI': 2, 'AD': 1}
    label = group_to_label[group]
    return label

class ADNI(Dataset):
    def __init__(self, datapath, csvpath, labels = [0, 1, 2], transform=None):
        """
        Args:
            datapath (string): Directory with all the images.
            csvpath (string): Path to CSV
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        all_imagepaths = create_paths(datapath)[:-1]
        self.csvpath = csvpath
        self.imagepaths = [path for path in tqdm(all_imagepaths) if get_label(path, csvpath) in labels] 
        self.transform = transform

    def __len__(self):
        #         Returns the length of the dataset
        return len(self.imagepaths)

    def __getitem__(self, idx):
        #         Returns a tuple of the image and its group/label
        imgsize = 64

        if torch.is_tensor(idx):
            idx = idx.tolist()
        imagepath = self.imagepaths[idx]
        label = get_label(imagepath, csvpath)

        
        #         create imgbatch with three different perspectives
        imgbatch = []
        
        imgdata = nib.load(imagepath).get_fdata()
        if self.transform:
            imgdata = self.transform(imgdata)
            
        imgdata1 = cv2.resize(imgdata[imgdata.shape[0]//2, :, :], (imgsize, imgsize))
        imgdata1 = torch.from_numpy(imgdata1)
        imgdata1 = torch.stack([imgdata1, imgdata1, imgdata1], 0)
        imgbatch.append(imgdata1.reshape(3, imgsize, imgsize))
        
        imgdata2 = cv2.resize(imgdata[:, imgdata.shape[0]//2, :], (imgsize, imgsize))
        imgdata2 = torch.from_numpy(imgdata2)
        imgdata2 = torch.stack([imgdata2, imgdata2, imgdata2], 0)
        imgbatch.append(imgdata1.reshape(3, imgsize, imgsize))
        
        imgdata3 = cv2.resize(imgdata[:, :, imgdata.shape[0]//2], (imgsize, imgsize))
        imgdata3 = torch.from_numpy(imgdata3)
        imgdata3 = torch.stack([imgdata3, imgdata3, imgdata3], 0)
        imgbatch.append(imgdata1.reshape(3, imgsize, imgsize))
        
        sample = (imgbatch, torch.tensor(label))
        return sample


datapath = r"/media/swang/Windows/Users/swang/Downloads/ADNI1_Complete_1Yr_1.5T"
csvpath = r"/media/swang/Windows/Users/swang/Downloads/ADNI1_Complete_1Yr_1.5T_7_08_2020.csv"
dataset = ADNI(datapath, csvpath, [0,1])

lengths = [
    int(len(dataset) * 0.8),
    int(len(dataset) * 0.1),
    int(len(dataset) * 0.1) + 1
]


trainset, valset, testset = random_split(dataset, lengths)
image_datasets = {'train': trainset, 'val': valset, 'test': testset}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
              for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}  


HBox(children=(FloatProgress(value=0.0, max=2294.0), HTML(value='')))




In [15]:
class MultiCNN(nn.Module):
    def __init__(self):
        super(MultiCNN, self).__init__()
        self.fc1 = nn.Linear(6144, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)
        resnet = models.resnet50(pretrained=True).to(device) 
        self.new_resnet1 = nn.Sequential(*list(resnet.children())[:-1])
        self.new_resnet2 = nn.Sequential(*list(resnet.children())[:-1])
        self.new_resnet3 = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x_slices):

        x1 = x_slices[0].to(device, dtype=torch.float) 
        x1 = self.new_resnet1(x1)
        x1 = x1.view(-1, 2048)

        x2 = x_slices[1].to(device, dtype=torch.float) 
        x2 = self.new_resnet2(x2)
        x2 = x2.view(-1, 2048)

        x3 = x_slices[2].to(device, dtype=torch.float) 
        x3 = self.new_resnet3(x3)
        x3 = x3.view(-1, 2048)

        out = torch.cat((x1, x2, x3), dim=-1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)

        return out

In [16]:
train_loss = []
train_accuracy = []
val_loss = []
val_accuracy = []

    
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 100)


        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval() 

            running_loss = 0.0
            running_corrects = 0


            for inputs, labels in tqdm(dataloaders[phase], total = dataset_sizes[phase]//32+1):
                labels = labels.to(device)
            
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)


                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * 32
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            if phase == 'train':
                train_loss.append(epoch_loss)
                train_accuracy.append(epoch_acc)
                
            if phase == 'val':
                val_loss.append(epoch_loss)
                val_accuracy.append(epoch_acc)
                
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))


            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [19]:
model = MultiCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=50)

Epoch 0/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.6912 Acc: 0.5720


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7566 Acc: 0.5424
Epoch 1/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.6537 Acc: 0.6081


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7231 Acc: 0.5424
Epoch 2/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.5727 Acc: 0.6102


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7221 Acc: 0.5508
Epoch 3/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.4384 Acc: 0.7913


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6763 Acc: 0.7203
Epoch 4/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9d06a0fc20>
Traceback (most recent call last):
  File "/home/swang/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 962, in __del__
    self._shutdown_workers()
  File "/home/swang/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 942, in _shutdown_workers
    w.join()
  File "/home/swang/anaconda3/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9d06a0fc20>
Traceback (most recent call last):
  File "/home/swang/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 962, in __del__
    self._shutdown_workers()
  File "/home/swang/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 942, in _shutdown_worker


train Loss: 0.2809 Acc: 0.9449


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7561 Acc: 0.7458
Epoch 5/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.2309 Acc: 0.9216


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6933 Acc: 0.6525
Epoch 6/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.2886 Acc: 0.8856


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.8527 Acc: 0.7203
Epoch 7/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.2013 Acc: 0.9227


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6571 Acc: 0.7458
Epoch 8/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.1187 Acc: 0.9650


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6084 Acc: 0.7881
Epoch 9/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0836 Acc: 0.9873


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5619 Acc: 0.8136
Epoch 10/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0865 Acc: 0.9756


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7489 Acc: 0.7712
Epoch 11/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0642 Acc: 0.9905


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5739 Acc: 0.7542
Epoch 12/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0583 Acc: 0.9883


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6576 Acc: 0.7712
Epoch 13/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0504 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6099 Acc: 0.7797
Epoch 14/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0495 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5945 Acc: 0.7797
Epoch 15/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0500 Acc: 0.9894


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6981 Acc: 0.7627
Epoch 16/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0539 Acc: 0.9915


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6458 Acc: 0.7797
Epoch 17/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0513 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5489 Acc: 0.8136
Epoch 18/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0529 Acc: 0.9894


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6289 Acc: 0.7881
Epoch 19/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0464 Acc: 0.9968


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5735 Acc: 0.8051
Epoch 20/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0378 Acc: 0.9989


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6005 Acc: 0.7881
Epoch 21/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0482 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6186 Acc: 0.7881
Epoch 22/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0498 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5471 Acc: 0.7881
Epoch 23/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0489 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6365 Acc: 0.7966
Epoch 24/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0544 Acc: 0.9862


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6315 Acc: 0.7881
Epoch 25/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0408 Acc: 0.9968


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5938 Acc: 0.7627
Epoch 26/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0456 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6203 Acc: 0.7881
Epoch 27/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0412 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6853 Acc: 0.7881
Epoch 28/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0444 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6289 Acc: 0.7966
Epoch 29/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0451 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6581 Acc: 0.7881
Epoch 30/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0490 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.8623 Acc: 0.7542
Epoch 31/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0458 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6436 Acc: 0.7712
Epoch 32/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0453 Acc: 0.9968


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7322 Acc: 0.7712
Epoch 33/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0431 Acc: 0.9979


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6581 Acc: 0.8136
Epoch 34/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0477 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6302 Acc: 0.7797
Epoch 35/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0473 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6090 Acc: 0.7966
Epoch 36/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0418 Acc: 0.9979


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6762 Acc: 0.7797
Epoch 37/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0434 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5714 Acc: 0.8051
Epoch 38/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0483 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6349 Acc: 0.7797
Epoch 39/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0467 Acc: 0.9936


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6757 Acc: 0.7966
Epoch 40/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0386 Acc: 0.9968


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6249 Acc: 0.7966
Epoch 41/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0444 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5886 Acc: 0.7881
Epoch 42/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0471 Acc: 0.9915


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5941 Acc: 0.8136
Epoch 43/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0488 Acc: 0.9979


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6690 Acc: 0.7966
Epoch 44/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0399 Acc: 0.9968


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.7079 Acc: 0.7712
Epoch 45/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0469 Acc: 0.9926


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.5749 Acc: 0.8051
Epoch 46/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0445 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6026 Acc: 0.8051
Epoch 47/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0529 Acc: 0.9926


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6139 Acc: 0.7797
Epoch 48/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0435 Acc: 0.9947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6383 Acc: 0.7712
Epoch 49/49
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))


train Loss: 0.0438 Acc: 0.9958


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


val Loss: 0.6612 Acc: 0.7881
Training complete in 31m 55s
Best val Acc: 0.813559


In [1]:
plt.figure(figsize=(10,5))
plt.title("Train and Val Loss")
plt.plot(train_loss,label="Train")
plt.plot(val_loss,label="Val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
plt.figure(figsize=(10,5))
plt.title("Train and Val Accuracy")
plt.plot(train_accuracy,label="Train")
plt.plot(val_accuracy,label="Val")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

NameError: name 'plt' is not defined