In [None]:
import torch
import itertools
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
import os

In [None]:
# Fail fast when no CUDA device is visible. The check uses torch (the framework this notebook
# trains with) instead of TensorFlow + the Colab-only %tensorflow_version magic, which fails
# outside Colab and initialises a second framework on the same GPU.
if not torch.cuda.is_available():
    raise SystemError('GPU device not found')
print('Found GPU at: cuda:0 ({})'.format(torch.cuda.get_device_name(0)))

Found GPU at: /device:GPU:0


In [None]:
# Google Drive can only be mounted from Colab. The data paths used below live under /scratch,
# so when google.colab is unavailable, skip the mount explicitly instead of aborting Run All.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print('google.colab not available; skipping Drive mount')
else:
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
DATA_DIR = '/scratch/liy31/myjupyter/capstone/train_data/'
LABEL_DIR = '/scratch/liy31/myjupyter/capstone/train_label/'
# Slice indices 0..47 are generated for every volume; the assertion at the end of this cell
# checks that all of those files actually exist.
SLICES_PER_VOLUME = 48

train_data = os.listdir(DATA_DIR)
train_label = os.listdir(LABEL_DIR)

# Volume key = first four '_'-separated fields of a file name; its first 13 characters
# (e.g. '121919_Myo089') decide which split the volume belongs to.
selected_data = set(['_'.join(name.split('_')[:4]) for name in train_data])


def _slice_pairs(prefixes):
    """Return (data_file, target_file) names for every slice of every volume whose key starts with one of `prefixes`.

    Volumes are visited in sorted order so the list order does not depend on set iteration
    order (which varies between Python processes for strings).
    """
    return [
        (volume + '_data_' + str(idx) + '.npy', volume + '_target_' + str(idx) + '.npy')
        for volume in sorted(selected_data)
        if volume[:13] in prefixes
        for idx in range(SLICES_PER_VOLUME)
    ]


# The prefix sets are disjoint, so each prefix contributes slices to at most one split.
train_list = _slice_pairs({'121919_Myo089', '121919_Myo253', '121919_Myo368'})
validation_list = _slice_pairs({'121919_Myo208', '121919_Myo388'})
test_list = _slice_pairs({'121919_Myo231', '121919_Myo511'})

# Fail here, naming the offending files, rather than later inside a DataLoader worker.
_available_data, _available_labels = set(train_data), set(train_label)
_missing = [pair for pair in train_list + validation_list + test_list
            if pair[0] not in _available_data or pair[1] not in _available_labels]
assert not _missing, 'missing slice files, e.g. {}'.format(_missing[:3])

In [None]:
from torchvision import transforms
# Training-time augmentation. CustomDataset applies this pipeline to the *input only* — the
# target mask never passes through it — so it must not contain geometric transforms.
# RandomRotation(3) used to sit here and rotated the image but not its mask, misaligning every
# augmented training pair. To reintroduce rotation, rotate input and target by the same angle
# inside the dataset.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.03, contrast=0.03, saturation=0.03, hue=0.03),
    transforms.ToTensor(),
])

class CustomDataset(Dataset):
    """Slice-level segmentation dataset backed by one ``.npy`` file per input and per target.

    Each item of ``file_list`` is a ``(data_file, target_file)`` pair of file names. The data
    file is read from ``data_dir`` and the target from ``label_dir``. Both arrays get a leading
    channel dimension via ``unsqueeze(0)``.

    ``transform`` (if truthy) is applied to the input only. A geometric transform passed here
    would therefore misalign image and mask.
    """

    def __init__(self, file_list, transform,
                 data_dir='/scratch/liy31/myjupyter/capstone/train_data/',
                 label_dir='/scratch/liy31/myjupyter/capstone/train_label/'):
        self.file_list = file_list
        self.transform = transform
        self.data_dir = data_dir
        self.label_dir = label_dir

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, target)`` tensors for the ``idx``-th file pair."""
        entry = self.file_list[idx]
        image = torch.from_numpy(np.load(os.path.join(self.data_dir, entry[0]))).unsqueeze(0)
        target = torch.from_numpy(np.load(os.path.join(self.label_dir, entry[1]))).unsqueeze(0)

        if self.transform:
            image = self.transform(image)

        return (image, target)

In [None]:
# Augmentation is for training only. The evaluation splits keep the same
# ToPILImage -> ToTensor conversion as the training transform but with no random jitter, and
# they are not shuffled, so validation metrics (and checkpoint selection) are reproducible.
eval_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])


def _make_loader(file_list, tfm, shuffle):
    """Wrap `file_list` in a CustomDataset and a DataLoader with this notebook's batch settings."""
    return DataLoader(CustomDataset(file_list, tfm), batch_size=16, shuffle=shuffle,
                      num_workers=8, drop_last=False)


dataloader = {
    'train': _make_loader(train_list, transform, shuffle=True),
    'validation': _make_loader(validation_list, eval_transform, shuffle=False),
    'test': _make_loader(test_list, eval_transform, shuffle=False),
}

datasize = {'train': len(train_list), 'validation': len(validation_list), 'test': len(test_list)}

In [None]:
def _ratio(numerator, denominator):
    """Return numerator / denominator as a Python float, or NaN when the denominator is zero."""
    denominator = float(denominator)
    return float(numerator) / denominator if denominator else float('nan')


def _run_phase(model, loader, criterion1, criterion2, optimizer, thres, is_train, device):
    """Make one pass over `loader`; return (loss_sum, pos_correct, neg_correct, pos_total, neg_total).

    Autograd is enabled, and the optimizer stepped, only when `is_train` is true. A pixel is
    predicted positive when the raw model output exceeds `thres`. Targets are assumed to be
    binary 0/1 masks (pos_total sums the target values).
    """
    model.train(is_train)  # model.train(False) is exactly model.eval()

    loss_sum = 0.0
    pos_correct = neg_correct = pos_total = neg_total = 0
    for inputs, target in loader:
        inputs = inputs.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            pred = (outputs > thres) * 1
            loss = criterion1(outputs, target) + criterion2(outputs, target)
            if is_train:
                loss.backward()
                optimizer.step()

        pos_mask = target.bool()
        loss_sum += loss.item()
        pos_total += torch.sum(target)
        neg_total += torch.sum(~pos_mask)
        pos_correct += torch.sum(pred[pos_mask])
        neg_correct += torch.sum(~pos_mask) - torch.sum(pred[~pos_mask])
    return loss_sum, pos_correct, neg_correct, pos_total, neg_total


def train_model(model, criterion1, criterion2, optimizer, scheduler, thres=0.5, num_epochs=25,
                dataloaders=None, dataset_sizes=None,
                save_dir='/scratch/liy31/myjupyter/capstone/model_aug/'):
    """Train `model` on loss ``criterion1 + criterion2``, tracking boundary/background accuracy.

    Each epoch runs a 'train' phase (optimizer step per batch, then one ``scheduler.step()``)
    and a 'validation' phase. Boundary accuracy (positive pixels) and background accuracy
    (negative pixels) are reported separately; their mean is the overall accuracy. Whenever a
    validation metric reaches a new best, the whole model/optimizer objects plus scheduler and
    loss state are pickled to ``save_dir/unet_fa_<boundary>_ba_<background>.pt``.

    Args:
        thres: raw-output threshold above which a pixel counts as boundary.
        num_epochs: number of train+validation epochs.
        dataloaders: dict with 'train' and 'validation' loaders; defaults to global `dataloader`.
        dataset_sizes: dict of sample counts used to average the loss; defaults to global `datasize`.
        save_dir: directory the checkpoints are written to.

    Returns:
        `model`, with the weights from the epoch of best overall validation accuracy restored.
    """
    from tqdm import tqdm  # local import: the function must not depend on a later cell

    if dataloaders is None:
        dataloaders = dataloader
    if dataset_sizes is None:
        dataset_sizes = datasize
    device = next(model.parameters()).device  # batches go wherever the model lives

    since = time.time()
    init_model_wts = copy.deepcopy(model.state_dict())
    best_model_wts = copy.deepcopy(model.state_dict())
    best_pos_acc = best_neg_acc = best_acc = 0.0

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            loss_sum, pos_correct, neg_correct, pos_total, neg_total = _run_phase(
                model, dataloaders[phase], criterion1, criterion2, optimizer, thres, is_train, device)
            if is_train:
                scheduler.step()

            epoch_loss = loss_sum / dataset_sizes[phase]
            epoch_pos_acc = _ratio(pos_correct, pos_total)
            epoch_neg_acc = _ratio(neg_correct, neg_total)
            epoch_acc = (epoch_pos_acc + epoch_neg_acc) / 2

            print('{} Loss: {:.4f} Boundary Acc: {:.4f} Background Acc: {:.4f} Overall Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_pos_acc, epoch_neg_acc, epoch_acc))

            if is_train:
                continue

            # All three "best" checks target the same file name, so one save per epoch suffices.
            improved = False
            if epoch_pos_acc > best_pos_acc:
                best_pos_acc = epoch_pos_acc
                improved = True
            if epoch_neg_acc > best_neg_acc:
                best_neg_acc = epoch_neg_acc
                improved = True
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                improved = True
            if improved:
                torch.save({'model': model, 'optimizer': optimizer, 'scheduler': scheduler.state_dict(),
                            'loss1': criterion1.state_dict(), 'loss2': criterion2.state_dict(),
                            'thres': thres, 'epoch': epoch, 'initial_state': init_model_wts},
                           os.path.join(save_dir, 'unet_fa_{:.3f}_ba_{:.3f}.pt'.format(epoch_pos_acc, epoch_neg_acc)))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Boundary Acc: {:.4f}, Best Background Acc: {:.4f}, Best Overall Acc: {:.4f}'.format(
        best_pos_acc, best_neg_acc, best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss on channel 0: ``1 - (2*sum(P*T) + smooth) / (sum(P) + sum(T) + smooth)``.

    The whole batch is reduced into a single Dice score (not one score per sample).
    The result is a tensor of shape ``(1,)``.
    """

    def __init__(self, device, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        # Plain constant. If it required grad, autograd would accumulate an unused .grad on
        # it at every backward pass.
        self.one = torch.ones(1, device=device)

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        # Only channel 0 is scored; batch and spatial dims are flattened together.
        y_pred = y_pred[:, 0].contiguous().view(-1)
        y_true = y_true[:, 0].contiguous().view(-1)
        intersection = (y_pred * y_true).sum()
        dsc = (2. * intersection + self.smooth) / (
            y_pred.sum() + y_true.sum() + self.smooth
        )
        return self.one - dsc


In [None]:

class tversky(nn.Module):
    """Tversky loss on channel 0: ``1 - (TP + s) / (TP + alpha*FN + (1-alpha)*FP + s)``.

    TP/FN/FP are soft counts (products of `y_pred` and `y_true`) summed over the whole batch.
    With ``alpha > 0.5``, missed positives (FN) cost more than spurious ones (FP).
    The result is a tensor of shape ``(1,)``.
    """

    def __init__(self, device, alpha=0.7, smooth=1.0):
        super(tversky, self).__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.one = torch.ones(1, device=device)  # constant; needs no gradient

    def forward(self, y_pred, y_true):
        assert y_pred.size() == y_true.size()
        # Ground truth comes from y_true and the prediction from y_pred. These were previously
        # swapped, which exchanged FN and FP and thus inverted the meaning of alpha.
        y_true_pos = y_true[:, 0].contiguous().view(-1)
        y_pred_pos = y_pred[:, 0].contiguous().view(-1)
        true_pos = (y_true_pos * y_pred_pos).sum()
        false_neg = (y_true_pos * (1 - y_pred_pos)).sum()
        false_pos = ((1 - y_true_pos) * y_pred_pos).sum()
        tv = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + self.smooth
        )
        return self.one - tv



In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# U-Net from https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/,
# single-channel in/out, trained from scratch.
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch',
    'unet',
    in_channels=1,
    out_channels=1,
    init_features=32,
    pretrained=False,
).to(device)

# Objective = soft Dice + BCE with positive (boundary) pixels weighted 33x.
# NOTE(review): BCEWithLogitsLoss applies its own sigmoid, while train_model and DiceLoss
# treat the output as probabilities (threshold 0.5, products with the mask). If the hub
# U-Net already ends in a sigmoid, this double-applies it — confirm.
criterion1 = DiceLoss(device)
criterion2 = nn.BCEWithLogitsLoss(reduction='sum', pos_weight=torch.tensor([33.0], device=device))

# Adam at lr 0.01, decayed by 10x every 5 scheduler steps (one step per training epoch).
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

Using cache found in /home/liy31/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


In [None]:
from tqdm import tqdm  # train_model wraps its epoch loop in tqdm

# Full run with train_model's defaults spelled out: threshold 0.5, 25 epochs.
train_model(model, criterion1, criterion2, optimizer, scheduler, thres=0.5, num_epochs=25);

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 0/24
----------
train Loss: 62349.0790 Boundary Acc: 0.8126 Background Acc: 0.8764 Overall Acc: 0.8445
validation Loss: 61858.9055 Boundary Acc: 0.9312 Background Acc: 0.8248 Overall Acc: 0.8780


  4%|▍         | 1/25 [01:39<39:53, 99.73s/it]


Epoch 1/24
----------
train Loss: 62260.1121 Boundary Acc: 0.8151 Background Acc: 0.8778 Overall Acc: 0.8464
validation Loss: 61466.3299 Boundary Acc: 0.8958 Background Acc: 0.8460 Overall Acc: 0.8709


  8%|▊         | 2/25 [03:19<38:13, 99.74s/it]


Epoch 2/24
----------
train Loss: 62131.6857 Boundary Acc: 0.8198 Background Acc: 0.8795 Overall Acc: 0.8497


 12%|█▏        | 3/25 [04:59<36:37, 99.88s/it]

validation Loss: 61497.5682 Boundary Acc: 0.8974 Background Acc: 0.8451 Overall Acc: 0.8712

Epoch 3/24
----------
train Loss: 62057.0416 Boundary Acc: 0.8203 Background Acc: 0.8813 Overall Acc: 0.8508
validation Loss: 61513.8020 Boundary Acc: 0.9175 Background Acc: 0.8398 Overall Acc: 0.8787


 16%|█▌        | 4/25 [06:41<35:07, 100.34s/it]


Epoch 4/24
----------
train Loss: 61939.9777 Boundary Acc: 0.8255 Background Acc: 0.8827 Overall Acc: 0.8541
validation Loss: 61066.8248 Boundary Acc: 0.9030 Background Acc: 0.8543 Overall Acc: 0.8787


 20%|██        | 5/25 [08:22<33:33, 100.68s/it]


Epoch 5/24
----------
train Loss: 61913.4707 Boundary Acc: 0.8272 Background Acc: 0.8828 Overall Acc: 0.8550
validation Loss: 60900.4193 Boundary Acc: 0.8549 Background Acc: 0.8719 Overall Acc: 0.8634


 24%|██▍       | 6/25 [10:03<31:56, 100.87s/it]


Epoch 6/24
----------
train Loss: 61934.0131 Boundary Acc: 0.8252 Background Acc: 0.8828 Overall Acc: 0.8540
validation Loss: 60967.0471 Boundary Acc: 0.8243 Background Acc: 0.8785 Overall Acc: 0.8514


 28%|██▊       | 7/25 [11:46<30:23, 101.29s/it]


Epoch 7/24
----------
train Loss: 61887.9305 Boundary Acc: 0.8273 Background Acc: 0.8834 Overall Acc: 0.8553
validation Loss: 61215.0423 Boundary Acc: 0.9151 Background Acc: 0.8473 Overall Acc: 0.8812


 32%|███▏      | 8/25 [13:29<28:50, 101.77s/it]


Epoch 8/24
----------
train Loss: 61868.7279 Boundary Acc: 0.8289 Background Acc: 0.8834 Overall Acc: 0.8561


 36%|███▌      | 9/25 [15:11<27:12, 102.00s/it]

validation Loss: 60961.4047 Boundary Acc: 0.8933 Background Acc: 0.8598 Overall Acc: 0.8766

Epoch 9/24
----------
train Loss: 61835.3312 Boundary Acc: 0.8280 Background Acc: 0.8845 Overall Acc: 0.8562


 40%|████      | 10/25 [16:57<25:45, 103.04s/it]

validation Loss: 60933.9924 Boundary Acc: 0.8842 Background Acc: 0.8629 Overall Acc: 0.8735

Epoch 10/24
----------
train Loss: 61850.5728 Boundary Acc: 0.8291 Background Acc: 0.8837 Overall Acc: 0.8564


 44%|████▍     | 11/25 [18:39<23:57, 102.71s/it]

validation Loss: 60916.0613 Boundary Acc: 0.8829 Background Acc: 0.8638 Overall Acc: 0.8734

Epoch 11/24
----------
train Loss: 61841.2839 Boundary Acc: 0.8298 Background Acc: 0.8837 Overall Acc: 0.8568


 48%|████▊     | 12/25 [20:22<22:19, 103.03s/it]

validation Loss: 60909.6710 Boundary Acc: 0.8786 Background Acc: 0.8652 Overall Acc: 0.8719

Epoch 12/24
----------
train Loss: 61848.2024 Boundary Acc: 0.8294 Background Acc: 0.8837 Overall Acc: 0.8566


 52%|█████▏    | 13/25 [22:04<20:32, 102.69s/it]

validation Loss: 60898.0980 Boundary Acc: 0.8899 Background Acc: 0.8623 Overall Acc: 0.8761

Epoch 13/24
----------
train Loss: 61878.8834 Boundary Acc: 0.8287 Background Acc: 0.8831 Overall Acc: 0.8559


 56%|█████▌    | 14/25 [23:45<18:44, 102.19s/it]

validation Loss: 60887.4778 Boundary Acc: 0.8915 Background Acc: 0.8621 Overall Acc: 0.8768

Epoch 14/24
----------
train Loss: 61825.5771 Boundary Acc: 0.8329 Background Acc: 0.8832 Overall Acc: 0.8580


 60%|██████    | 15/25 [25:26<16:56, 101.63s/it]

validation Loss: 60890.5794 Boundary Acc: 0.8902 Background Acc: 0.8624 Overall Acc: 0.8763

Epoch 15/24
----------
train Loss: 61848.5990 Boundary Acc: 0.8302 Background Acc: 0.8834 Overall Acc: 0.8568


 64%|██████▍   | 16/25 [27:06<15:11, 101.31s/it]

validation Loss: 60891.8519 Boundary Acc: 0.8933 Background Acc: 0.8615 Overall Acc: 0.8774

Epoch 16/24
----------
train Loss: 61822.4737 Boundary Acc: 0.8324 Background Acc: 0.8834 Overall Acc: 0.8579


In [None]:
# train_model names checkpoints by validation boundary ("fa") and background ("ba") accuracy.
# (0.915, 0.847) is epoch 7, the best overall validation accuracy in the log above; other
# candidates from that run are (0.931, 0.825) and (0.855, 0.872).
CHECKPOINT_DIR = '/scratch/liy31/myjupyter/capstone/model_aug/'
checkpoint_path = os.path.join(CHECKPOINT_DIR, 'unet_fa_{:.3f}_ba_{:.3f}.pt'.format(0.915, 0.847))

# map_location puts every tensor on this kernel's `device`, even if the checkpoint was saved on a GPU.
# The file pickles whole Python objects (model, optimizer), so only load checkpoints you trust.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects pickled modules;
# pass weights_only=False there.
model_load = torch.load(checkpoint_path, map_location=device)



In [None]:
model = model_load['model'].to(device)
model.eval()   # Set model to evaluate mode

# Score predictions exactly as train_model did: threshold the raw model output at the
# checkpoint's `thres`. An extra sigmoid first (as this cell used to do) makes every output
# > 0 a "boundary" pixel, since sigmoid(x) > 0.5 iff x > 0. That is why it reported
# ~0.998 / ~0.21 for a checkpoint saved at 0.915 / 0.847.
thres = model_load.get('thres', 0.5)

running_pos_corrects = 0
running_neg_corrects = 0
running_pos = 0
running_neg = 0

with torch.no_grad():  # evaluation only: no autograd graph needed
    for inputs, target in dataloader['validation']:
        inputs = inputs.to(device)
        target = target.to(device)

        pred = (model(inputs) > thres) * 1
        pos_mask = target.bool()

        # statistics
        running_pos += torch.sum(target)
        running_neg += torch.sum(~pos_mask)
        running_pos_corrects += torch.sum(pred[pos_mask])
        running_neg_corrects += torch.sum(~pos_mask) - torch.sum(pred[~pos_mask])

epoch_pos_acc = running_pos_corrects.double() / running_pos.double()
epoch_neg_acc = running_neg_corrects.double() / running_neg.double()
epoch_acc = (epoch_pos_acc + epoch_neg_acc) / 2
print('Validation Boundary Acc: {:.4f}, Background Acc: {:.4f}, Overall Acc: {:.4f}'.format(
    epoch_pos_acc, epoch_neg_acc, epoch_acc))

Validation Boundary Acc: 0.997564, Background Acc: 0.212516, Overall Acc: 0.605040


In [None]:
# Visual check on one validation batch: for 16 samples, show input / target / prediction.
# Inference runs on `device` without moving the shared `model` (model.cpu() used to leave it
# on the CPU for every later cell). The raw output is thresholded the same way train_model does.
inputs, target = next(iter(dataloader['validation']))
with torch.no_grad():
    output = (model(inputs.to(device)) > model_load.get('thres', 0.5)).float().cpu()

# 6x8 grid: rows 0-2 hold samples 0-7, rows 3-5 hold samples 8-15.
fig, axes = plt.subplots(6, 8, figsize=(24, 18))
for i in range(16):
    row, col = 3 * (i // 8), i % 8
    axes[row, col].imshow(inputs[i].squeeze().numpy())
    axes[row + 1, col].imshow(target[i].squeeze().numpy())
    axes[row + 2, col].imshow(output[i].squeeze().numpy())
for row, label in enumerate(['input', 'target', 'prediction'] * 2):
    axes[row, 0].set_ylabel(label)
plt.show()

In [None]:
os.getcwd()  # kernel's working directory (same value the %pwd magic displays)

'/scratch/liy31'