In [1]:
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader
import csv
import random
import numpy as np
import cv2
import mimetypes
import fnmatch
import seaborn as sns
from collections import Counter, defaultdict
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import time
from progress.bar import IncrementalBar
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import confusion_matrix
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_auc_score
import io
from pytorch_lightning.callbacks import Callback
from datetime import datetime, date, time
from PIL import Image
import itertools 
from sklearn.model_selection import train_test_split
import torchvision.models as models

from Sources.callbacks import plot_confusion_matrix
from Sources.CoAtNet import CoAtNet
from torch.nn import functional as F
from Sources.callbacks import get_true_classes
from Sources.callbacks import get_predicted_classes
from Sources.callbacks import get_classes_probs
from Sources.callbacks import callback
from Sources.data_tools import ImageDataset

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

labels_map = {
    "Benign": 0,
    "InSitu": 1,
    "Invasive": 2,
}

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = ImageDataset('../Data/burnasyan_Br.csv', 'train_paths.txt', transform = transform)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
train_features, train_labels = next(iter(train_dataloader))
train_features = train_features.to(device)
train_labels = train_labels.to(device)

val_dataset = ImageDataset('../Data/burnasyan_Br.csv', 'val_paths.txt', transform = transform)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=True)
val_features, val_labels = next(iter(val_dataloader))
val_features = val_features.to(device)
val_labels = val_labels.to(device)

Using cuda device


In [3]:
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(512, 3)

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

net.apply(init_weights)

net = net.to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [5]:
logdir = "../Logits/ResNet_logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(logdir)
res_callback = callback()
weight = torch.tensor([0.33, 0.03, 0.63])
weight = weight.pow(-1)
weight = weight.to(device)
criterion = nn.CrossEntropyLoss(weight=weight, reduction='sum')
optimizer = optim.SGD(net.parameters(), lr=1e-5, momentum=0.9)
#optimizer = optim.Adam(net.parameters(), lr=1e-5, betas=(0.9, 0.999))

In [7]:
optimizer = optim.SGD(net.parameters(), lr=1e-6, momentum=0.9)

In [8]:
epochs_num = 300

for epoch in tqdm_notebook(range(25, epochs_num), desc='epochs'):  # loop over the dataset multiple times
    
    res_callback.on_epoch_begin(epoch) 
    
    running_loss = 0.0
    
    classes = []
    true_classes= []

    for data in tqdm_notebook(train_dataloader, desc='one epoch training'):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        true_classes.append(labels)
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        classes.append(nn.Softmax(dim=1)(outputs))
        outputs = outputs.to(device)
        # print(outputs.shape(), labels.shape())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    running_loss /= len(true_classes)
    
    val_classes = []
    val_true_classes = []
    val_loss = 0.0
    
    for data in tqdm_notebook(val_dataloader, desc='validation'):
        # get the inputs; data is a list of [inputs, labels]
        val_inputs, val_labels = data
        val_true_classes.append(val_labels)
        val_inputs = val_inputs.to(device)
        val_labels = val_labels.to(device)
    
        # forward 
        with torch.no_grad():
            val_outputs = net(val_inputs)
            loss = criterion(val_outputs, val_labels)
            val_loss += loss.item()
        val_classes.append(nn.Softmax(dim=1)(val_outputs))
        
    val_loss /= len(val_true_classes)

    res_callback.on_epoch_end(true_classes, classes, val_true_classes, val_classes,
                          ["Benign", "InSitu", "Invasive"],
                          running_loss, val_loss, writer)
    
    if epoch % 50 == 0:
        torch.save(net.state_dict(), "../Logits/SavedNN/Saved_ResNet/" + str(epoch))

print('Finished Training')

epochs:   0%|          | 0/275 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

validation:   0%|          | 0/657 [00:00<?, ?it/s]

one epoch training:   0%|          | 0/1577 [00:00<?, ?it/s]

KeyboardInterrupt: 