In [1]:
import configparser
from pathlib import Path

from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import distance
from sklearn.model_selection import train_test_split
import torch
from torch import optim, Tensor
from torch.nn import (
    functional as F,
    Module,
    utils
)
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc

In [2]:
import sys
import os

sys.path.append('/kaggle/input')
import model_dr_unet, model_dense_unet, model_unet

In [3]:
class TqdmExtraFormat(tqdm):
    """tqdm progress bar that exposes an extra ``{total_time}`` format field.

    ``total_time`` linearly extrapolates the full run time from the time
    elapsed so far and the number of iterations completed.  When the total is
    unknown (``None`` or 0) the estimate collapses to zero.
    """

    @property
    def format_dict(self):
        fields = super().format_dict
        done = max(fields["n"], 1)  # avoid dividing by zero before the first update
        expected = fields["total"] or 0
        estimate = fields["elapsed"] * expected / done
        fields.update(total_time=self.format_interval(estimate))
        return fields

In [4]:
def plot(img, mask_true, mask_pred, idx):
    """Draw one row per batch element: input slice, ground truth, prediction.

    Each argument is squeezed along axis 1 (``np.squeeze(..., axis=1)``), so a
    singleton channel axis at position 1 is expected, e.g. ``(B, 1, H, W)``.

    Args:
        img: Input batch. Not modified; it is scaled by its batch maximum for
            display.
        mask_true: Ground-truth batch with the same leading shape as ``img``.
        mask_pred: Predicted batch with the same leading shape as ``img``.
        idx: Batch index, shown in the ground-truth column title as ``ID nnn``.

    Returns:
        The created ``matplotlib.figure.Figure``. The caller owns it and is
        responsible for closing it.
    """
    img = np.squeeze(img, axis=1)
    mask_true = np.squeeze(mask_true, axis=1)
    mask_pred = np.squeeze(mask_pred, axis=1)

    # Scale into a NEW array: np.squeeze returns a view, so an in-place divide
    # would overwrite the caller's buffer (e.g. the tensor behind img.numpy()).
    # Skip the divide for an all-zero batch instead of producing NaNs.
    max_value = img.max()
    if max_value != 0:
        img = img / max_value

    num_slice = img.shape[0]
    # squeeze=False keeps axs 2-D even when the batch holds a single element;
    # otherwise axs[i][0] would index into a lone Axes object and fail.
    fig, axs = plt.subplots(num_slice, 3, squeeze=False)
    for i in range(num_slice):
        row = axs[i]
        row[0].imshow(img[i], cmap='gray')
        row[0].set_title(f'z = {i}')
        row[1].imshow(mask_true[i], cmap='gray')
        row[1].set_title(f'ID {idx:03d}\nGround truth')
        row[2].imshow(mask_pred[i], cmap='gray')
        row[2].set_title('AI generated')
        for ax in row:
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    return fig

In [5]:
from pathlib import Path

import nibabel as nib
import numpy as np

class SliceDataset:
    """Map-style dataset that yields individual 2-D slices from a list of volumes.

    The last axis of every image volume is concatenated, so a global index
    ``k`` addresses one slice of one volume. Images are read with nibabel and
    masks are ``.npy`` arrays with the same slice layout. Each returned slice
    is downsampled by taking every second pixel and must end up 256x256, so
    the source volumes are presumably 512x512 in-plane.

    Args:
        img_pathes: Image volume paths, one per volume.
        mask_pathes: Mask paths aligned element-wise with ``img_pathes``.
        intensity_min: Lower bound of the intensity window (see ``windowing``).
        intensity_max: Upper bound of the intensity window.
    """

    def __init__(self, img_pathes: list[Path], mask_pathes: list[Path], intensity_min, intensity_max) -> None:
        self.img_pathes = img_pathes
        self.mask_pathes = mask_pathes
        # nib.load only parses the header here, so reading .shape is cheap.
        self.slices = [nib.load(p).shape[-1] for p in self.img_pathes]
        self.cum_slices = np.cumsum(self.slices)
        self.intensity_min = intensity_min
        self.intensity_max = intensity_max

    def _locate(self, index: int) -> tuple[int, int]:
        """Translate a global slice index into ``(volume index, slice index)``."""
        path_index = int(np.searchsorted(self.cum_slices, index, side='right'))
        if path_index == 0:
            slice_index = index
        else:
            slice_index = index - self.cum_slices[path_index - 1]
        return path_index, int(slice_index)

    def _load_mask_slice(self, path_index: int, slice_index: int) -> np.ndarray:
        """Read one full-resolution 2-D mask slice."""
        # mmap_mode='r' pages in only the requested slice instead of the whole volume.
        return np.asarray(np.load(self.mask_pathes[path_index], mmap_mode='r')[:, :, slice_index])

    def __getitem__(self, index: int):
        """Return ``(img, mask)`` as float32 arrays of shape ``(1, 256, 256)``."""
        path_index, slice_index = self._locate(index)

        mask = self._load_mask_slice(path_index, slice_index)[::2, ::2]
        assert mask.shape == (256, 256), "Resized mask shape does not match desired shape"
        mask = mask[np.newaxis, ...]

        # Slicing nibabel's array proxy reads (and scales) just this slice;
        # get_fdata() would read the full volume on every call. Values match
        # get_fdata() up to float rounding of any scale factors.
        proxy = nib.load(self.img_pathes[path_index]).dataobj
        img = np.asarray(proxy[:, :, slice_index], dtype=np.float64)[::2, ::2]
        assert img.shape == (256, 256), "Resized image shape does not match desired shape"
        img = windowing(img, self.intensity_min, self.intensity_max)[np.newaxis, ...]

        return img.astype(np.float32), mask.astype(np.float32)

    def filter_samples(self, index):
        """Return True if the mask slice at global ``index`` has any nonzero value."""
        path_index, slice_index = self._locate(index)
        return np.sum(self._load_mask_slice(path_index, slice_index)) > 0

    def __len__(self):
        """Total number of slices across all volumes (0 when there are none)."""
        return int(self.cum_slices[-1]) if len(self.cum_slices) else 0

def windowing(image, min_value, max_value):
    """Clip intensities to ``[min_value, max_value]`` and rescale them to ``[0, 1]``.

    Args:
        image: Array-like of intensities.
        min_value: Lower window bound; values at or below it map to 0.
        max_value: Upper window bound; values at or above it map to 1.

    Returns:
        A new array with values in ``[0, 1]``; ``image`` is not modified.

    Raises:
        ValueError: If ``max_value <= min_value``. That window would divide by
            zero or invert the range, silently producing NaN/inf or a constant
            image.
    """
    if max_value <= min_value:
        raise ValueError(f'windowing requires min_value < max_value, got {min_value} >= {max_value}')
    clipped = np.clip(image, min_value, max_value)
    return (clipped - min_value) / (max_value - min_value)

class Subset(SliceDataset):
    """View of ``dataset`` restricted to the positions listed in ``indices``.

    ``SliceDataset.__init__`` is deliberately not run: the subset owns no paths
    or slice counts and forwards every lookup to the wrapped dataset.

    Args:
        dataset: Dataset supporting ``__getitem__`` (and ``filter_samples`` if
            that method is called through the subset).
        indices: Sequence of indices into ``dataset``. Position ``i`` of the
            subset maps to ``dataset[indices[i]]``.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def filter_samples(self, idx):
        """Delegate to the wrapped dataset.

        The inherited implementation would raise AttributeError because this
        object has no ``cum_slices``/``mask_pathes`` of its own.
        """
        return self.dataset.filter_samples(self.indices[idx])

In [6]:
def train(config: configparser.ConfigParser, num_epochs: int = 7):
    """Train a U-Net on the filtered train split, validating after every epoch.

    Reads from ``config['train']``:
        ``train_size`` (fraction of volumes used for training, default 0.9),
        ``random_seed``, ``intensity_min``, ``intensity_max`` and ``device``.

    The split is done on volume paths, so slices of one volume never appear in
    both train and validation. Only slices whose mask is non-empty are kept.
    Training stops early once the learning rate falls below 1e-6.

    Args:
        config: Parsed configuration.
        num_epochs: Maximum number of epochs to run (default 7).

    Side effects: deletes any stale ``/kaggle/working/val_metrics.txt`` and
    ``/kaggle/working/temp_model_weights.pth`` before training. Through
    ``train_epoch``/``val_epoch`` it writes checkpoints, metrics and figures.
    """
    print('Training...')

    data_root = Path('/kaggle/input/sarahliu-deeplearning/Full_Data')
    img_train_val = data_root / 'img_train_val'
    mask_train_val_pathes = sorted((data_root / 'mask_train_val').glob('*.npy'))
    # Image paths are derived from mask names so the two lists stay aligned.
    img_train_val_pathes = [img_train_val / p.name.replace('.npy', '.nii') for p in mask_train_val_pathes]

    print(f'len(img_train_val_pathes) = {len(img_train_val_pathes)}')
    print(f'len(mask_train_val_pathes) = {len(mask_train_val_pathes)}')

    train_cfg = config['train']
    # Fraction of volumes used for training; defaults to 0.9 when not configured.
    train_size = train_cfg.getfloat('train_size', fallback=0.9)
    random_seed = train_cfg.getint('random_seed')
    if random_seed is not None:
        # Seed torch too, so weight init and DataLoader shuffling are reproducible.
        torch.manual_seed(random_seed)

    img_train_pathes, img_val_pathes, mask_train_pathes, mask_val_pathes = train_test_split(
        img_train_val_pathes, mask_train_val_pathes, train_size=train_size, random_state=random_seed
    )

    print(f'len(img_train_pathes) = {len(img_train_pathes)}')
    print(f'len(img_val_pathes) = {len(img_val_pathes)}')
    print(f'len(mask_train_pathes) = {len(mask_train_pathes)}')
    print(f'len(mask_val_pathes) = {len(mask_val_pathes)}')

    intensity_min = train_cfg.getint('intensity_min')
    intensity_max = train_cfg.getint('intensity_max')
    trainset = SliceDataset(img_train_pathes, mask_train_pathes, intensity_min, intensity_max)
    valset = SliceDataset(img_val_pathes, mask_val_pathes, intensity_min, intensity_max)
    print(f'len(trainset) = {len(trainset)}')
    print(f'len(valset) = {len(valset)}')

    # Keep only slices whose mask contains foreground.
    train_indices = [idx for idx in range(len(trainset)) if trainset.filter_samples(idx)]
    filtered_trainset = Subset(trainset, train_indices)
    val_indices = [idx for idx in range(len(valset)) if valset.filter_samples(idx)]
    filtered_valset = Subset(valset, val_indices)
    print(f'len(filtered_trainset) = {len(filtered_trainset)}')
    print(f'len(filtered_valset) = {len(filtered_valset)}')

    trainloader = DataLoader(filtered_trainset, batch_size=8, shuffle=True, num_workers=1)
    # No shuffling for validation: val_epoch averages a per-batch Dice, which
    # depends on batch composition. A fixed order keeps scores comparable
    # across epochs and keeps the figure IDs in val_figures.pdf stable.
    valloader = DataLoader(filtered_valset, batch_size=8, shuffle=False, num_workers=1)

    print(f'len(trainloader) = {len(trainloader)}')
    print(f'len(valloader) = {len(valloader)}')

    # Alternatives: model_dr_unet.DRUNet, model_dense_unet.Dense_Net
    model = model_unet.U_Net()

    print(f'model.__class__.__name__ = {model.__class__.__name__}')
    print()

    device = train_cfg['device']
    model.to(device)
    debug = False

    optimizer = optim.Adam(model.parameters())
    # mode='max': the monitored value is the validation Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=1 / 10 ** .5, patience=2, verbose=True)

    # Start from a clean slate. A leftover metrics file would be compared
    # against, and a leftover checkpoint would previously have been resumed
    # from silently.
    work_dir = Path("/kaggle/working")
    for stale in (work_dir / "val_metrics.txt", work_dir / "temp_model_weights.pth"):
        if stale.is_file():
            stale.unlink()

    for epoch in range(1, num_epochs + 1):
        print(f'epoch = {epoch:03d}')

        train_epoch(config, trainloader, model, optimizer)
        torch.cuda.empty_cache()

        val_epoch(config, valloader, model, scheduler)
        torch.cuda.empty_cache()
        print()

        if debug and epoch == 2:
            break
        # The scheduler has shrunk the LR so far that further epochs are pointless.
        if optimizer.param_groups[0]['lr'] < 1e-6:
            break

In [7]:
def val_epoch(config: configparser.ConfigParser, loader: DataLoader, model: Module, scheduler: optim.lr_scheduler.ReduceLROnPlateau):
    """Evaluate ``model`` on ``loader``, keep the best checkpoint and step the scheduler.

    The per-batch score is ``1 - scipy.spatial.distance.dice`` between the
    un-thresholded model output and the binarised ground truth (``mask == 1``).
    One page of input/target/prediction plots per batch goes to
    ``val_figures.pdf``.

    The mean score is compared with the mean stored in
    ``/kaggle/working/val_metrics.txt``. If it is higher, or no checkpoint
    exists yet, the weights are saved to ``/kaggle/working/model_weights.pth``
    and the metrics file is overwritten with this epoch's scores. The file
    therefore always describes the saved (best) checkpoint.
    ``scheduler.step`` receives the mean score every epoch.
    """
    print(f'length of validation loader = {len(loader)}')
    model.eval()
    device = config['train']['device']
    debug = False  # config['train'].getboolean('debug')
    weights_save_path = Path("/kaggle/working/model_weights.pth")
    val_metrics_path = Path("/kaggle/working/val_metrics.txt")
    new_metrics = []

    # Scores of the currently saved checkpoint; [0] when none is recorded yet.
    old_metrics = np.loadtxt(val_metrics_path) if val_metrics_path.is_file() else [0]

    # no_grad: validation must not build an autograd graph (wastes memory/time).
    with torch.no_grad(), PdfPages('val_figures.pdf') as pdf, tqdm(total=len(loader)) as pbar:
        for idx, (img, mask_true) in enumerate(loader):
            mask_pred = model(img.to(device)).cpu().detach().numpy()
            mask_true = (mask_true == 1).numpy()
            dice = 1 - distance.dice(mask_pred.reshape(-1), mask_true.reshape(-1))
            new_metrics.append(dice)
            pbar.update()

            fig = plot(img.numpy(), mask_true, mask_pred, idx)
            fig.set_size_inches(15, 5 * mask_true.shape[0])
            pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

            if debug and idx == 1:
                break

    mean_dice = np.mean(new_metrics)
    print(f'dice = {mean_dice} ± {np.std(new_metrics)}')

    if mean_dice > np.mean(old_metrics) or not weights_save_path.is_file():
        print("Performance improved, saving new weights.")
        torch.save(model.state_dict(), weights_save_path)
        # Record scores only alongside a saved checkpoint, so later epochs are
        # compared with the best model rather than merely the previous epoch.
        np.savetxt(val_metrics_path, new_metrics, fmt='%.5f')

    scheduler.step(mean_dice)

In [8]:
def test(config: configparser.ConfigParser):
    """Evaluate the saved checkpoint on the filtered test split.

    Loads ``/kaggle/working/model_weights.pth`` into a fresh U-Net and runs it
    in eval mode without gradients over every test slice whose mask is
    non-empty. One page of plots per batch is written to ``figures.pdf``.
    The per-batch scores (``1 - scipy.spatial.distance.dice`` on the
    un-thresholded output) are written to ``config['test']['test_metrics']``.
    """
    print('Testing...')

    img_test_dir = Path("/kaggle/input/sarahliu-deeplearning/Full_Data/img_test")
    mask_test_pathes = sorted(Path("/kaggle/input/sarahliu-deeplearning/Full_Data/mask_test").glob('*.npy'))
    # Image paths are derived from mask names so the two lists stay aligned.
    img_test_pathes = [img_test_dir / p.name.replace('.npy', '.nii') for p in mask_test_pathes]
    weights_path = "/kaggle/working/model_weights.pth"

    print(f'len(img_test_pathes) = {len(img_test_pathes)}')
    print(f'len(mask_test_pathes) = {len(mask_test_pathes)}')

    intensity_min = config['train'].getint('intensity_min')
    intensity_max = config['train'].getint('intensity_max')

    testset = SliceDataset(img_test_pathes, mask_test_pathes, intensity_min, intensity_max)
    print(f'len(testset) = {len(testset)}')

    # Keep only slices whose mask contains foreground.
    test_indices = [idx for idx in range(len(testset)) if testset.filter_samples(idx)]
    filtered_testset = Subset(testset, test_indices)
    print(f'len(filtered_testset) = {len(filtered_testset)}')

    testloader = DataLoader(filtered_testset, batch_size=8, shuffle=False, num_workers=1)
    print(f'len(testloader) = {len(testloader)}')

    # Alternatives: model_dr_unet.DRUNet, model_dense_unet.Dense_Net
    model = model_unet.U_Net()

    print(f'model.__class__.__name__ = {model.__class__.__name__}')
    print()

    device = config['train']['device']
    # map_location lets a checkpoint saved on GPU load when device is the CPU.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    # Inference mode: without eval(), any BatchNorm/Dropout layers would run in
    # training mode and distort the test scores.
    model.eval()
    debug = False

    test_metrics_path = Path(config['test']['test_metrics'])
    new_metrics = []

    with torch.no_grad(), PdfPages('figures.pdf') as pdf, tqdm(total=len(testloader)) as pbar:
        for idx, (img, mask_true) in enumerate(testloader):
            mask_pred = model(img.to(device)).detach().cpu().numpy()
            mask_true = (mask_true == 1).numpy()
            dice = 1 - distance.dice(mask_pred.reshape(-1), mask_true.reshape(-1))
            new_metrics.append(dice)
            pbar.update()

            fig = plot(img.numpy(), mask_true, mask_pred, idx)
            fig.set_size_inches(15, 5 * mask_true.shape[0])
            pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

            if debug and idx == 1:
                break

    print(new_metrics)
    print(f'dice = {np.mean(new_metrics)} ± {np.std(new_metrics)}')
    np.savetxt(test_metrics_path, new_metrics, fmt='%.5f')

In [9]:
# Wipe everything under /kaggle/working so a rerun starts from scratch.
# WARNING: irreversible. This also deletes the files this notebook writes there
# (model_weights.pth, temp_model_weights.pth, val_metrics.txt), including the
# checkpoint that test() loads.
!rm -rf /kaggle/working/*
# Free unreferenced Python objects left over from earlier cells.
gc.collect()

0

In [10]:
def train_epoch(config: configparser.ConfigParser, loader: DataLoader, model: Module, optimizer: optim.Adam):
    """Run one training epoch with a batch-level soft Dice loss.

    For each batch the loss is ``1 - (2*sum(p*t) + eps) / (sum(p) + sum(t) + eps)``.
    Here ``p`` is the raw model output, ``t`` is the binarised ground truth
    (``mask == 1``), and the sums run over the whole batch. After the epoch
    the weights are saved to ``/kaggle/working/temp_model_weights.pth`` and the
    mean ± std of the per-batch losses is printed.
    """
    print(f'length of train loader = {len(loader)}')
    model.train()
    device = config['train']['device']
    debug = False
    eps = 1e-6  # smoothing: with both masks empty the ratio is 1, so the loss is 0, not NaN

    loss_values = []

    with tqdm(total=len(loader)) as pbar:
        for batch_idx, (img, mask_true) in enumerate(loader):
            mask_pred = model(img.to(device)).float()
            # The target is data, not a parameter: autograd must not track it.
            target = (mask_true == 1).float().to(device)

            intersection = torch.sum(mask_pred * target)
            union = torch.sum(mask_pred) + torch.sum(target)
            loss = 1. - (2. * intersection + eps) / (union + eps)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_values.append(loss.item())

            pbar.update()
            if debug and batch_idx == 1:
                break

    torch.save(model.state_dict(), Path("/kaggle/working/temp_model_weights.pth"))
    print(f'loss = {np.mean(loss_values)} ± {np.std(loss_values)}')

In [None]:
# Load the experiment configuration and run training.
config_path = '/kaggle/input/sarahliu-deeplearning/config.ini'
config = configparser.ConfigParser()
# ConfigParser.read silently skips missing files (it returns the list of files
# actually read). Fail loudly instead of hitting a KeyError inside train().
if not config.read(config_path):
    raise FileNotFoundError(f'Config file not found: {config_path}')
train(config)

Training...
len(img_train_val_pathes) = 287
len(mask_train_val_pathes) = 287
len(img_train_pathes) = 258
len(img_val_pathes) = 29
len(mask_train_pathes) = 258
len(mask_val_pathes) = 29
len(trainset) = 8266
len(valset) = 923
len(filtered_trainset) = 2148
len(filtered_valset) = 197
len(trainloader) = 269
len(valloader) = 25
model.__class__.__name__ = U_Net

epoch = 001
length of train loader = 269


100%|██████████| 269/269 [03:15<00:00,  1.38it/s]


loss = 0.4046789166209423 ± 0.2564838437991685
length of validation loader = 25


100%|██████████| 25/25 [01:56<00:00,  4.64s/it]


dice = 0.6902853331030738 ± 0.1016419319906613
Performance improved, saving new weights.

epoch = 002
length of train loader = 269


100%|██████████| 269/269 [02:30<00:00,  1.79it/s]


loss = 0.20222872850177012 ± 0.09381241758798027
length of validation loader = 25


100%|██████████| 25/25 [01:51<00:00,  4.46s/it]


dice = 0.7404911279514405 ± 0.15741061612971627
Performance improved, saving new weights.

epoch = 003
length of train loader = 269


100%|██████████| 269/269 [02:30<00:00,  1.79it/s]


loss = 0.1879809966317783 ± 0.09551823410838048
length of validation loader = 25


100%|██████████| 25/25 [01:55<00:00,  4.61s/it]


dice = 0.7457879616353579 ± 0.15172477796068734
Performance improved, saving new weights.

epoch = 004
length of train loader = 269


100%|██████████| 269/269 [02:30<00:00,  1.79it/s]


loss = 0.1784350876471367 ± 0.09065120804486151
length of validation loader = 25


100%|██████████| 25/25 [01:55<00:00,  4.64s/it]


dice = 0.7697837863805664 ± 0.11152807402764905
Performance improved, saving new weights.

epoch = 005
length of train loader = 269


 63%|██████▎   | 170/269 [01:35<00:55,  1.80it/s]

In [None]:
# Evaluate the best checkpoint (/kaggle/working/model_weights.pth) on the test split.
test(config)