In [1]:
# Mount Google Drive inside the Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Switch the working directory to the project folder on Drive; the relative
# paths used later (e.g. "./weights", "./logs") are resolved against it.
import os

PROJECT_DIR = '/content/drive/MyDrive/brain-segmentation-pytorch/'
os.chdir(PROJECT_DIR)

In [3]:
pip install medpy Pillow dill

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
import argparse
import json
import os

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from dataset import BrainSegmentationDataset as Dataset
# from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc

In [5]:
# Hyper-parameters and paths shared by the cells below.
args = dict(
    batch_size=16,         # batch size handed to both DataLoaders
    epochs=100,            # number of iterations of the training loop
    lr=0.0001,             # optimizer learning rate
    device="cuda:0",       # used only when CUDA is available, otherwise CPU
    workers=0,             # DataLoader num_workers
    vis_images=200,        # only referenced by the disabled image logging
    vis_freq=10,           # only referenced by the disabled image logging
    weights="./weights",   # directory created by makedirs(); checkpoints go here
    logs="./logs",         # directory created by makedirs()
    images="./kaggle_3m",  # images_dir passed to the dataset constructor
    image_size=256,        # image_size passed to the dataset constructor
    aug_scale=0.05,        # scale argument of the training transforms
    aug_angle=15,          # angle argument of the training transforms
)

In [6]:
import dill as pickle

def data_loaders(args):
    """Build the training and validation DataLoaders from cached dataset pickles.

    The datasets are read from 'preprosessed_train_data.pkl' and
    'preprosessed_valid_data.pkl' in the current working directory.

    Args:
        args: mapping providing 'batch_size' and 'workers'.

    Returns:
        Tuple ``(loader_train, loader_valid)``. The training loader shuffles
        and drops the last incomplete batch; the validation loader keeps
        every sample and uses the default (non-shuffled) order.
    """

    def _unpickle(path):
        # NOTE(review): unpickling can execute arbitrary code; only load
        # files that were produced by a trusted run of this project.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def worker_init(worker_id):
        # Give every DataLoader worker its own fixed NumPy seed.
        np.random.seed(42 + worker_id)

    train_set = _unpickle('preprosessed_train_data.pkl')
    valid_set = _unpickle('preprosessed_valid_data.pkl')

    # Options common to both loaders.
    shared = dict(
        batch_size=args['batch_size'],
        num_workers=args['workers'],
        worker_init_fn=worker_init,
    )

    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared)
    valid_loader = DataLoader(valid_set, drop_last=False, **shared)
    return train_loader, valid_loader


def datasets(args):
    """Construct the training and validation datasets from the image folder.

    Args:
        args: mapping providing 'images', 'image_size', 'aug_scale' and
            'aug_angle'.

    Returns:
        Tuple ``(train, valid)``. The training dataset receives the
        augmentation transform (flip probability 0.5); the validation
        dataset gets no transform and is built with random_sampling=False.
    """
    images_dir = args['images']
    side = args['image_size']
    augmentation = transforms(
        scale=args['aug_scale'], angle=args['aug_angle'], flip_prob=0.5
    )

    train = Dataset(
        images_dir=images_dir,
        subset="train",
        image_size=side,
        transform=augmentation,
    )
    valid = Dataset(
        images_dir=images_dir,
        subset="validation",
        image_size=side,
        random_sampling=False,
    )
    return train, valid


def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    """Compute one `dsc` score per patient from flat lists of slices.

    Args:
        validation_pred: sequence of predicted slices, grouped by patient in
            increasing patient-index order.
        validation_true: sequence of ground-truth slices in the same order.
        patient_slice_index: sequence whose items carry the patient index in
            position 0; the number of items per index gives the slice count
            of that patient.

    Returns:
        List with the `dsc` value of each patient, in patient-index order.
    """
    # bincount yields how many slices belong to patient 0, 1, 2, ...
    slices_per_patient = np.bincount([entry[0] for entry in patient_slice_index])

    scores = []
    start = 0
    for count in slices_per_patient:
        stop = start + count
        predicted = np.array(validation_pred[start:stop])
        expected = np.array(validation_true[start:stop])
        scores.append(dsc(predicted, expected))
        start = stop
    return scores


def log_loss_summary(logger, loss, step, prefix=""):
    """Report the mean of `loss` to `logger` under the tag ``prefix + "loss"``.

    Args:
        logger: object exposing ``scalar_summary(tag, value, step)``.
        loss: collection of loss values to average.
        step: step value forwarded to the logger.
        prefix: text prepended to the "loss" tag.
    """
    report = logger.scalar_summary
    tag = prefix + "loss"
    report(tag, np.mean(loss), step)


def makedirs(args):
    """Create the 'weights' and 'logs' directories named in `args`.

    Directories that already exist are left untouched.
    """
    for key in ('weights', 'logs'):
        os.makedirs(args[key], exist_ok=True)


In [10]:
# Create the output folders and choose the device (CPU when CUDA is unavailable).
makedirs(args)
device = torch.device("cpu" if not torch.cuda.is_available() else args['device'])

# Freshly built datasets: `hello` is the training set and `hi` the validation
# set; the inspection cells below read from `hello`.
hello, hi = datasets(args)
print('-----')

# Everything the training-loop cell depends on. These definitions used to be
# commented out, which made that cell fail with NameError on a fresh kernel.
# data_loaders() reads the cached dataset pickles, so those files must exist.
loader_train, loader_valid = data_loaders(args)
loaders = {"train": loader_train, "valid": loader_valid}

unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
unet.to(device)

dsc_loss = DiceLoss()
best_validation_dsc = 0.0  # best mean validation DSC seen so far

optimizer = optim.Adam(unet.parameters(), lr=args['lr'])

# Per-epoch mean losses, plotted by the training loop.
train_losses = []
eval_losses = []

step = 0  # number of training batches processed

reading train images...
preprocessing train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...
normalizing train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
normalizing validation volumes...
done creating validation dataset
>>>>> Self.patients <<<<<


AttributeError: ignored

In [12]:
# Peek at the first entry of the training dataset's patient list.
print('>>>>> Self.patients <<<<<')
volume = hello.patients[0]

# One value per line: element type, container type, then the element itself.
print(type(volume), type(hello.patients), volume, sep='\n')



>>>>> Self.patients <<<<<
<class 'str'>
<class 'list'>
TCGA_DU_6400_19830518


In [20]:
# Inspect the first training volume and its mask.
v, m = hello.volumes[0]
# Work on a copy: the thresholding below writes into the array, and doing that
# on the original would silently change the data stored in `hello.volumes`
# and make this cell give different results when re-run.
v = np.array(v, copy=True)

print(f'len : {len(v)}')
print(f'len : {len(v[0])}')
print(f'len : {len(v[0][0])}')
print(f'len : {len(v[-1])}')

# Zero out everything below 10% of the maximum intensity.
v[v < np.max(v) * 0.1] = 0
print('---1---')
print(np.max(v, axis = -1))  # maximum over the last axis
print('---2---')
print(np.max(np.max(v, axis=-1), axis=-1))  # maximum over the last two axes
print('---3---')
print(np.max(np.max(np.max(v, axis=-1), axis=-1), axis=-1))  # last three axes

len : 49
len : 256
len : 256
len : 256
---1---
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 ...

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0

In [None]:
# `volume` (set in an earlier cell) is a patient-id string, so calling
# `.size()` or masking it with NumPy fails. Inspect the actual image array of
# the first training volume instead, on a copy so the dataset is not modified.
volume_array = np.array(hello.volumes[0][0], copy=True)
print(f'Size : {volume_array.shape}')

# Zero out everything below 10% of the maximum intensity.
volume_array[volume_array < np.max(volume_array) * 0.1] = 0
print('---1---')
print(np.max(volume_array, axis = -1))
print('---2---')
print(np.max(np.max(volume_array, axis=-1), axis=-1))
print('---3---')
print(np.max(np.max(np.max(volume_array, axis=-1), axis=-1), axis=-1))


In [None]:
from IPython.display import clear_output

# Training loop: each epoch runs one pass over the training loader followed by
# one pass over the validation loader. Relies on `unet`, `loaders`, `device`,
# `optimizer`, `dsc_loss`, `step`, `best_validation_dsc`, `train_losses` and
# `eval_losses` being defined by an earlier cell.
# Scalar/image logging through `logger` is disabled in this notebook.
for epoch in tqdm(range(args['epochs']), total=args['epochs']):
    loss_train = []
    loss_valid = []
    for phase in ["train", "valid"]:
        if phase == "train":
            unet.train()
        else:
            unet.eval()

        # Per-slice predictions and targets collected during validation.
        validation_pred = []
        validation_true = []

        for data in loaders[phase]:
            if phase == "train":
                step += 1

            x, y_true = data
            x, y_true = x.to(device), y_true.to(device)
            # NOTE(review): presumably the masks are stored as 0/255 and this
            # rescales them to 0/1 — confirm against the dataset.
            y_true = y_true/255

            optimizer.zero_grad()

            # Gradients are tracked only during the training phase.
            with torch.set_grad_enabled(phase == "train"):
                y_pred = unet(x)
                loss = dsc_loss(y_pred, y_true)

                if phase == "valid":
                    loss_valid.append(loss.item())
                    y_pred_np = y_pred.detach().cpu().numpy()
                    validation_pred.extend(
                        [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                    )
                    y_true_np = y_true.detach().cpu().numpy()
                    validation_true.extend(
                        [y_true_np[s] for s in range(y_true_np.shape[0])]
                    )

                if phase == "train":
                    loss_train.append(loss.item())
                    loss.backward()
                    optimizer.step()

        if phase == "valid":
            mean_dsc = np.mean(
                dsc_per_volume(
                    validation_pred,
                    validation_true,
                    loaders["valid"].dataset.patient_slice_index,
                )
            )
            # Keep only the weights with the best mean validation DSC so far.
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(args['weights'], "unet_ours.pt"))

    # Record the epoch means and redraw the loss curves in place.
    train_losses.append(np.mean(loss_train))
    eval_losses.append(np.mean(loss_valid))
    clear_output(True)
    plt.plot(train_losses, label='train loss')
    plt.plot(eval_losses, label='eval loss')
    plt.legend()
    plt.show()
    print('[epoch : {}] Train loss : {}, Eval loss : {}'.format(epoch, round(train_losses[-1], 5), round(eval_losses[-1],5)))

# `{:.4f}` prints four decimals; the previous `{:4f}` only set a minimum width.
print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))