In [1]:
import torch
import pandas as pd
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch import nn
import multiprocessing as mp
import os
from tqdm import tqdm
from torch import optim
import wandb

In [2]:
multiple_gpus = False
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        multiple_gpus = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
train_csv = pd.read_csv('../../melanoma_data/train.csv')

In [4]:
train_csv['updated_paths'] = train_csv['image_name'].apply(lambda x: '../../melanoma_data/jpeg/train/' + x + '.jpg')

In [5]:
def split_datasets(dataset, test_size = 0.01):
    train, test = train_test_split(dataset, test_size = test_size, random_state=42)
    train, val = train_test_split(train, test_size = test_size, random_state=42)
    return train, val, test

In [6]:
def check_if_exists(path):
    if os.path.exists(path) == False:
        return False
    else:
        return True

In [7]:
with mp.Pool(10) as p:
    returns = list(p.map(check_if_exists, train_csv['updated_paths'].values.tolist()))

In [8]:
count = 0
for x in returns:
    if x == False:
        count += 1

In [9]:
def get_model():
    res50 = torchvision.models.resnet50()
    res50.fc = nn.Linear(2048, 1)
    return res50

In [10]:
model = get_model()

In [11]:
class PrepData(Dataset):
    def __init__(self, csv_file, img_size):
        self.csv_file = csv_file
        self.paths = csv_file['updated_paths'].values.tolist()
        self.labels = csv_file['target'].values.tolist()
        self.img_size = img_size

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        img = torchvision.io.read_file(self.paths[idx])
        img = torchvision.io.decode_jpeg(img)
        img = torchvision.transforms.functional.resize(img, (self.img_size[0], self.img_size[1]))
        img = img / 255
        if self.labels[idx] == 'benign':
            return img, torch.Tensor([0]).float()
        else:
            return img, torch.Tensor([1]).float()

In [12]:
train, val, test = split_datasets(train_csv)

In [13]:
train_set = PrepData(train[:1000], img_size=(256, 256))
val_set = PrepData(val, img_size=(256, 256))
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=10, prefetch_factor=2)
val_loader = DataLoader(val_set, batch_size=32, shuffle=True, num_workers=10, prefetch_factor=2)



In [17]:
# create a loss function here
def bce_loss(inputs, targets):
    sig = nn.Sigmoid()
    bce_loss = nn.BCELoss()
    inputs = sig(inputs)
    loss = bce_loss(inputs, targets)
    return loss

def cal_prec(inputs, targets, thres=0.5):
    # Formula for Prec = TP / (TP + FP)
    sig = nn.Sigmoid()
    inputs = torch.where(sig(inputs) > thres, 1.0, 0.0)
    tp = torch.sum(torch.logical_and(inputs == 1.0, targets == 1.0))
    fp = torch.sum(torch.logical_and(inputs == 1.0, targets == 0.0))
    return tp / (tp + fp)

def cal_rec(inputs, targets, thres=0.5):
    # Formula for Prec = TP / (TP + FN)
    sig = nn.Sigmoid()
    inputs = torch.where(sig(inputs) > thres, 1.0, 0.0)
    tp = torch.sum(torch.logical_and(inputs == 1.0, targets == 1.0))
    fn = torch.sum(torch.logical_and(inputs == 0.0, targets == 1.0))
    return tp / (tp + fn)

In [18]:
def train_function(model, epochs, train_loader, val_loader, load_weights):
    wandb.init(
        project='melanoma-classification'
    )
    data_loaders = {
        'train' : train_loader,
        'val' : val_loader
    }
    # Define optimizer
    # define how model sent to cuda
    if load_weights != None:
        model.load_state_dict(torch.load(load_weights)['model_state_dict'])
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    if next(model.parameters()).is_cuda == False:
        if multiple_gpus == True:
            model = nn.DataParallel(model)
        model = model.to(device)
    for epoch in range(epochs):
        train_loss, train_prec, train_rec = 0.0, 0.0, 0.0
        val_loss, val_prec, val_rec = 0.0, 0.0, 0.0
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss, running_prec, running_rec = 0.0, 0.0, 0.0
            with tqdm(data_loaders[phase], unit='batch') as tepoch:
                for img, label in tepoch:
                    tepoch.set_description(f'Epoch: {epoch}')
                    img = img.to(device)
                    label = label.to(device)                
                    optimizer.zero_grad()
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(img)
                        loss = bce_loss(outputs, label)
                        prec = cal_prec(outputs, label)
                        rec = cal_prec(outputs, label)
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    running_loss += loss.item()
                    running_prec += prec.item()
                    running_rec += rec.item()
                    tepoch.set_postfix(loss = loss.item(), prec = prec.item(), rec = rec.item())
            if phase == 'train':
                train_loss += running_loss / len(data_loaders['train'])
                train_prec += running_prec / len(data_loaders['train'])
                train_rec += running_rec / len(data_loaders['train'])
                print(f'For {phase} phase : {train_loss}')
                print(f'For {phase} phase Prec : {train_prec}')
                print(f'For {phase} phase Rec : {train_rec}')
            else:
                val_loss += val_loss / len(data_loaders['val'])
                val_prec += running_prec / len(data_loaders['val'])
                val_rec += running_rec / len(data_loaders['val'])
                print(f'For {phase} phase : {val_loss}')
                print(f'For {phase} phase Prec : {val_prec}')
                print(f'For {phase} phase Rec : {val_rec}')
        wandb.log({
                    'train_loss' : train_loss,
                    'val_loss' : val_loss,
                    'train_prec' : train_prec,
                    'val_prec' : val_prec,
                    'train_rec' : train_rec,
                    'val_rec' : val_rec
                })

In [19]:
train_function(get_model(), 5, train_loader, val_loader, None)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669513750002807, max=1.0…

Epoch: 0: 100%|██████████| 32/32 [00:39<00:00,  1.23s/batch, loss=0.0157, prec=1, rec=1]


For train phase : 0.09427138129831292
For train phase Prec : 1.0
For train phase Rec : 1.0


Epoch: 0: 100%|██████████| 11/11 [00:13<00:00,  1.21s/batch, loss=0.0699, prec=1, rec=1]


For val phase : 0.0
For val phase Prec : 1.0
For val phase Rec : 1.0


Epoch: 1: 100%|██████████| 32/32 [00:33<00:00,  1.05s/batch, loss=0.00528, prec=1, rec=1]


For train phase : 0.016081289082649164
For train phase Prec : 1.0
For train phase Rec : 1.0


  0%|          | 0/11 [00:04<?, ?batch/s]


KeyboardInterrupt: 