In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib, matplotlib.pyplot as plt
from PIL import Image 
from skimage import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import transforms, datasets, models
import torch.optim as optim
from collections import Counter
import re
import time
import os
import torch.nn.functional as F
import random

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve, auc, roc_auc_score
from itertools import cycle

In [2]:
if torch.cuda.is_available:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
def setup_seed(seed): 
    torch.manual_seed(seed) 
    torch.cuda.manual_seed_all(seed) 
    np.random.seed(seed) 
    random.seed(seed) 
    torch.backends.cudnn.deterministic = True
    return None
setup_seed(0)

In [4]:
device# = torch.device('cpu')

device(type='cuda')

In [3]:
data = pd.read_csv('DBT_train_resized.csv', index_col = 0)
data.head()

Unnamed: 0,PatientID,StudyUID,Label,View,descriptive_path,classic_path,new_path
0,DBT-P02497,DBT-S00143,0,lcc,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...,Breast-Cancer-Screening-DBT/DBT-P02497/1.2.826...,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...
1,DBT-P02497,DBT-S00143,0,lmlo,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...,Breast-Cancer-Screening-DBT/DBT-P02497/1.2.826...,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...
2,DBT-P02497,DBT-S00143,0,rcc,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...,Breast-Cancer-Screening-DBT/DBT-P02497/1.2.826...,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...
3,DBT-P02497,DBT-S00143,0,rmlo,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...,Breast-Cancer-Screening-DBT/DBT-P02497/1.2.826...,Breast-Cancer-Screening-DBT/DBT-P02497/01-01-2...
4,DBT-P02449,DBT-S05000,0,lcc,Breast-Cancer-Screening-DBT/DBT-P02449/01-01-2...,Breast-Cancer-Screening-DBT/DBT-P02449/1.2.826...,Breast-Cancer-Screening-DBT/DBT-P02449/01-01-2...


In [160]:
# Undersample 0 class
print(Counter(data.Label))
data0 = data[data.Label == 0]
data1 = data[data.Label == 1]
downsample = data0.sample(n = len(data1)*5)
ind = list(downsample.index)
ind = ind + list(data1.index)
newdf = data[data.index.isin(ind)]
print(Counter(newdf.Label))
newdf.to_csv('DBT_train_resized_balanced.csv')

Counter({0: 12760, 1: 143})
Counter({0: 715, 1: 143})


In [8]:
# Oversample 1 class
print('Number of biopsied samples:', np.sum(data['Label'].values))
print('Number of normal samples:', len(data) - np.sum(data['Label'].values))
disease_ratio = np.sum(data['Label'].values) / len(data)
print('Disease Ratio:', disease_ratio)

labels_unique, counts = np.unique(data.Label, return_counts = True)
class_weights = [sum(counts)/c for c in counts]
example_weights = [class_weights[e] for e in data.Label]
sampler = WeightedRandomSampler(example_weights, len(data))

Number of biopsied samples: 143
Number of normal samples: 12760
Disease Ratio: 0.011082693947144074


In [4]:
def get_model(tl_model, train_scratch=False, freeze_weights=False):
    # load the (pretrained) model from torchvision library
    if tl_model == "Resnet18":
        if (train_scratch):
            model_ft = models.resnet18(pretrained=False)
        else:
            model_ft = models.resnet18(pretrained=True)
    if tl_model == "Resnet34":
        if (train_scratch):
            model_ft = models.resnet34(pretrained=False)
        else:
            model_ft = models.resnet34(pretrained=True)
    elif tl_model == "Resnet50":
        if (train_scratch):
            model_ft = models.resnet50(pretrained=False)
        else: 
            model_ft = models.resnet50(pretrained=True)
    elif tl_model == "DenseNet201":
        if (train_scratch):
            model_ft = models.densenet201(pretrained=False)
        else:
            model_ft = models.densenet201(pretrained=True)
    elif tl_model == "DenseNet169":
        if (train_scratch):
            model_ft = models.densenet169(pretrained=False)
        else:
            model_ft = models.densenet169(pretrained=True)
    elif tl_model == "DenseNet121":
        if (train_scratch):
            model_ft = models.densenet121(pretrained=False)
        else:
            model_ft = models.densenet121(pretrained=True)
    else: 
        raise ValueError(f'tl_model={tl_model} is not recognized!')

    # freeze the weights if necessary
    if freeze_weights:
        for param in model_ft.parameters():
            param.requires_grad = False

    # replace last fc layer with binary classification 
    if tl_model in ("DenseNet201", "DenseNet169", "DenseNet121"):
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, 2)
    else:
        num_ftrs = model_ft.fc.in_features
        model_ft.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model_ft.fc = nn.Linear(num_ftrs, 2)      

    return model_ft

In [5]:
def get_model(tl_model, train_scratch=False, freeze_weights=False):
    # load the (pretrained) model from torchvision library
    if tl_model == "resnet18":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
    if tl_model == "resnet34":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
    elif tl_model == "resnet50":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    elif tl_model == "densenet201":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet201', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet201', pretrained=True)
    elif tl_model == "densenet169":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet169', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet169', pretrained=True)
    elif tl_model == "densenet121":
        if (train_scratch):
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=False)
        else:
            model_ft = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True)
    else: 
        raise ValueError(f'tl_model={tl_model} is not recognized!')

    # freeze the weights if necessary
    if freeze_weights:
        for param in model_ft.parameters():
            param.requires_grad = False

    # replace last fc layer with binary classification 
    if tl_model in ("densenet201", "densenet169", "densenet121"):
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, 2)
    else:
        num_ftrs = model_ft.fc.in_features
        model_ft.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        model_ft.fc = nn.Linear(num_ftrs, 2)      

    return model_ft

In [6]:
train_transforms = transforms.Compose([
                                    #transforms.ToPILImage(),
                                    transforms.ToTensor(),
                                    transforms.RandomRotation(20),
                                    transforms.RandomResizedCrop(224, scale=(0.8, 1.2)),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ColorJitter(hue=0.5, saturation=0.5, contrast=0.5),
                                    transforms.GaussianBlur(7, sigma=(0.1, 1.0)),
                                    #transforms.Resize((224,224))
                                ])
val_transforms = transforms.Compose([transforms.ToTensor(),
                                     #transforms.Resize((224,224))
                                     ])

class DBT_Dataset(Dataset):
    def __init__(self, df_path, train = False):
        self.df = pd.read_csv(df_path)
        self.train = train
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx, -1]
        img = io.imread(img_name, as_gray=True)
        #img = np.asarray(img)
        img.astype(float)
        
        img = img - np.mean(img)
        img = img / np.maximum(np.std(img), 10**(-5))
        
        if self.train:
            img_tens = train_transforms(img)
        else:
            img_tens = val_transforms(img)
            
        label = self.df['Label'].iloc[idx]
        #label = self.df.loc[idx,'Label'].astype('int')
        label = torch.tensor(label, dtype=torch.long)
        sample = (img_tens.float(), label)
        
        return sample

In [9]:
bs = 16
train_df_path = 'DBT_train_resized.csv'
val_df_path = 'DBT_val_resized.csv'
test_df_path = 'DBT_test_resized.csv'

train_loader = DataLoader(DBT_Dataset(train_df_path, train=True), batch_size=bs, num_workers=8, pin_memory=True, sampler=sampler, drop_last=True)
val_loader = DataLoader(DBT_Dataset(val_df_path), batch_size=bs, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
test_loader = DataLoader(DBT_Dataset(test_df_path), batch_size=bs, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)



In [11]:
# pull pre-trained model
tl_model = "resnet34"
train_scratch = False
freeze_weights = False
model = get_model(tl_model, train_scratch, freeze_weights)

Using cache found in /home/ss14383/.cache/torch/hub/pytorch_vision_v0.10.0


In [51]:
# other models
model = models.vgg16(pretrained=True)
model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 2)

In [13]:
# training params
learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
lambda_func = lambda epoch: 0.5 ** epoch
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 30, gamma = 0.1)#lr_lambda=lambda_func,)
loss_fn = nn.CrossEntropyLoss()
epochs = 10
save_path = 'test_model_balanced3'

In [None]:
def model_train(model, train_loader, val_loader, learning_rate, optimizer, scheduler, loss_fn, epochs, save_path):
    start_time = time.time()

    train_loss_return = []
    train_acc_return = []
    val_loss_return = []
    val_acc_return = []
    best_acc = -1
    # Train
    for epoch in range(epochs):
        print('Epoch: {}/{}'.format(epoch, epochs-1))
        print('-'*10)
        pred_list = []
        pred_scores_list = []
        truths_list = []
        loss_list = []
        model.train()
        for idx, (sample) in enumerate(train_loader):
            img = sample[0].to(device)
            labels = sample[1].squeeze(0).to(device)
            outputs = model(img)
            model.zero_grad()
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
            pred_score = nn.Softmax(1)(outputs).cpu().detach().numpy()
            pred_scores_list += pred_score.tolist()
            pred  = np.argmax(pred_score,axis=1)
            pred_list += pred.tolist()
            truths_list += labels.cpu().numpy().tolist()
        scheduler.step(metrics = loss)
        # report performance
        correct_num = (np.array(pred_list) == np.array(truths_list)).sum()
        acc = correct_num/len(truths_list)
        train_acc_return.append(acc)
        train_loss_return.append(np.average(loss_list))
        print('----------Epoch{:2d}/{:2d}----------'.format(epoch+1, epochs))
        print('Train set | Loss: {:6.4f} | Accuracy: {:4.2f}% '.format(np.average(loss_list), acc*100))
        
        # Val
        pred_list = []
        pred_scores_list = []
        truths_list = []
        loss_list = []
        model.eval()
        with torch.no_grad():
            for idx, (sample) in enumerate(val_loader):
                img = sample[0].to(device)
                labels = sample[1].squeeze(0).to(device)
                outputs = model(img)
                loss = loss_fn(outputs, labels)
                loss_list.append(loss.item())
                pred_score = nn.Softmax(1)(outputs).cpu().detach().numpy()
                pred_scores_list += pred_score.tolist()
                pred  = np.argmax(pred_score,axis=1)
                pred_list += pred.tolist()
                truths_list += labels.cpu().numpy().tolist()
            # report performance
            correct_num = (np.array(pred_list) == np.array(truths_list)).sum()
            acc = correct_num/len(truths_list)
            val_acc_return.append(acc)
            val_loss_return.append(np.average(loss_list))
            if acc > best_acc:
                best_acc = acc
                best_model_wts = model.state_dict()
            elapse = time.strftime('%H:%M:%S', time.gmtime(int((time.time() - start_time))))
            print('Val set  | Loss: {:6.4f} | Accuracy: {:4.2f}% | Best ACC: {:6.4f} | time elapse: {:>9}'\
                  .format(np.average(loss_list), acc*100, best_acc*100, elapse))
            save_model(model, best_model_wts, train_loss_return, train_acc_return,\
                       val_loss_return, val_acc_return, save_path=save_path)
            
    return None

def save_model(model, best_model_wts, train_loss_return,train_acc_return,\
               val_loss_return, val_acc_return, save_path):
    state = {'best_model_wts':best_model_wts, 'model':model, \
             'train_loss':train_loss_return, 'train_acc':train_acc_return,\
             'val_loss':val_loss_return, 'val_acc':val_acc_return}
    torch.save(state, save_path+'.pt')
    return None

In [9]:
setup_seed(0)
model = model.to(device)
model_train(model, train_loader, val_loader, learning_rate, optimizer, scheduler, loss_fn, epochs, save_path)

Epoch: 0/9
----------
----------Epoch 1/10----------
Train set | Loss: 0.0799 | Accuracy: 98.87% 
Val set  | Loss: 0.0691 | Accuracy: 98.97% | Best ACC: 98.9685 | time elapse:  00:12:22
Epoch: 1/9
----------
----------Epoch 2/10----------
Train set | Loss: 0.0657 | Accuracy: 98.89% 
Val set  | Loss: 0.1037 | Accuracy: 98.97% | Best ACC: 98.9685 | time elapse:  00:15:12
Epoch: 2/9
----------
----------Epoch 3/10----------
Train set | Loss: 0.0654 | Accuracy: 98.89% 
Val set  | Loss: 0.0794 | Accuracy: 98.97% | Best ACC: 98.9685 | time elapse:  00:18:19
Epoch: 3/9
----------
----------Epoch 4/10----------
Train set | Loss: 0.0651 | Accuracy: 98.89% 
Val set  | Loss: 0.0785 | Accuracy: 98.97% | Best ACC: 98.9685 | time elapse:  00:20:56
Epoch: 4/9
----------
----------Epoch 5/10----------
Train set | Loss: 0.0636 | Accuracy: 98.89% 
Val set  | Loss: 0.0744 | Accuracy: 98.97% | Best ACC: 98.9685 | time elapse:  00:23:32
Epoch: 5/9
----------
----------Epoch 6/10----------
Train set | Loss:

In [None]:
setup_seed(0)
model = model.to(device)
model_train(model, train_loader, val_loader, learning_rate, optimizer, scheduler, loss_fn, epochs, save_path)

Epoch: 0/9
----------
----------Epoch 1/10----------
Train set | Loss: 0.6929 | Accuracy: 51.47% 
Val set  | Loss: 0.6963 | Accuracy: 36.25% | Best ACC: 36.2500 | time elapse:  00:00:53
Epoch: 1/9
----------
----------Epoch 2/10----------
Train set | Loss: 0.6928 | Accuracy: 51.84% 
Val set  | Loss: 0.6933 | Accuracy: 45.60% | Best ACC: 45.5978 | time elapse:  00:01:54
Epoch: 2/9
----------
----------Epoch 3/10----------
Train set | Loss: 0.6927 | Accuracy: 51.81% 
Val set  | Loss: 0.6945 | Accuracy: 42.09% | Best ACC: 45.5978 | time elapse:  00:02:49
Epoch: 3/9
----------
----------Epoch 4/10----------
Train set | Loss: 0.6926 | Accuracy: 51.93% 
Val set  | Loss: 0.6868 | Accuracy: 63.23% | Best ACC: 63.2337 | time elapse:  00:03:44
Epoch: 4/9
----------
----------Epoch 5/10----------
Train set | Loss: 0.6924 | Accuracy: 52.62% 
Val set  | Loss: 0.6899 | Accuracy: 54.24% | Best ACC: 63.2337 | time elapse:  00:04:46
Epoch: 5/9
----------
----------Epoch 6/10----------
Train set | Loss:

In [None]:
# load best weights
path = 'test_model_balanced2.pt'
model.load_state_dict(torch.load(path)['best_model_wts'])

In [None]:
# Evaluation
def model_eval(model, test_loader):
    pred_list = []
    pred_scores_list = []
    truths_list = []
    
    model.eval()
    with torch.no_grad():
        for i, (sample) in enumerate(test_loader):
            img = sample[0].to(device)
            labels = sample[1].to(device)#.squeeze(1).to(device)
            outputs = model(img)
            pred_score = nn.Softmax(1)(outputs).cpu().detach().numpy()
            pred_scores_list += pred_score.tolist()
            pred  = np.argmax(pred_score,axis=1)
            pred_list += pred.tolist()
            truths_list += labels.cpu().numpy().tolist()
    return np.array(pred_list), np.array(pred_scores_list), np.array(truths_list)

In [None]:
pred, pred_scores, truths = model_eval(model, test_loader)

In [None]:
print('Prediction:',pred[:10])
print('Truth:     ',truths[:10])
print('Prediction Probability:\n',pred_scores[:10])

In [None]:
def ROC_curve(y_test, y_score):
    fpr, tpr, roc_auc = {},{},{}
    n_classes = y_test.shape[1]
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    colors = cycle(['aqua', 'darkorange'])
    lw = 2
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                 ''.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.show()
    return None

In [None]:
truth_is_0 = (truths == 0).astype(int)
truth_is_1 = (truths == 1).astype(int)
y_test = np.array([truth_is_0, truth_is_1])
y_test = np.transpose(y_test,(1,0))
print('Truth = 0:',truth_is_0[:10])
print('Truth = 1:',truth_is_1[:10])
print('Label for ROC:')
print(y_test[:,:10])
print('Prediction Probability for ROC:')
print(pred_scores[:,:10])

In [None]:
ROC_curve(y_test,pred_scores)

In [None]:
roc_auc_score(truths,pred)

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    cm = confusion_matrix(y_true, y_pred)
    #classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        pass

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return None

In [None]:
classes = ['Normal','Disease']
plot_confusion_matrix(truths, pred, classes)