In [1]:
from dataclasses import dataclass

import torch
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode
from utils import *
import glob
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msetupishe[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
@dataclass(frozen=True)
class DatasetConfig:
    """Central place for dataset-related constants used by the notebook.

    Only the annotated names (IMAGE_SIZE, BACKGROUND_CLS_ID, DATASET_PATH)
    are dataclass fields. USED_CLASSES and AUGS carry no annotation, so the
    dataclass machinery ignores them and they remain ordinary class
    attributes, read below as ``DatasetConfig.USED_CLASSES`` /
    ``DatasetConfig.AUGS``. ``frozen=True`` blocks attribute assignment on
    instances; it does not make these lists immutable.
    """

    # Original label ids that are kept; the position in this list becomes the
    # new contiguous class id (see create_class_mapping).
    USED_CLASSES = [0, 1, 2, 3, 6, 7, 8, 9]

    # NOTE(review): the datasets built later in this notebook pass 224 as the
    # image size and never read IMAGE_SIZE — confirm which value is intended.
    IMAGE_SIZE: tuple[int,int] = (448, 448) # W, H
    # Not referenced by any cell visible in this notebook.
    BACKGROUND_CLS_ID: int = 0
    # Not referenced by any cell visible in this notebook.
    DATASET_PATH: str = 'data/default_dataset'
    # Training-time augmentations, handed to OrgansDataset via ``augs=``.
    # NOTE(review): albumentations documents rotate_limit in degrees, so 0.15
    # is an almost imperceptible rotation — confirm this was not meant as 15.
    AUGS = [A.HorizontalFlip(p=0.5), 
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(scale_limit=0.12, rotate_limit=0.15, shift_limit=0.12, p=0.5),
            ]


In [4]:
def create_class_mapping(used_classes):
    """Build a lookup array translating original label ids to new class ids.

    The array has one slot per entry of ``labels_dict``. The i-th element of
    ``used_classes`` is mapped to ``i``; every original id that is not listed
    stays at 0.

    Args:
        used_classes: iterable of original label ids to keep, in the order
            that defines their new ids.

    Returns:
        Integer ``np.ndarray`` indexed by original label id.
    """
    lookup = np.zeros(len(labels_dict), dtype=int)
    next_id = 0
    for source_id in used_classes:
        lookup[source_id] = next_id
        next_id += 1
    return lookup

In [5]:
# Module-level lookup table (original label id -> new class id).
# OrgansDataset.__getitem__ reads this global by name, so it must be defined
# before any dataset item is fetched.
class_mapping = create_class_mapping(DatasetConfig.USED_CLASSES)

In [6]:
class_mapping

array([0, 1, 2, 3, 0, 0, 4, 5, 6, 7, 0, 0, 0])

In [7]:
class OrgansDataset(Dataset):
    """Segmentation dataset over ``*img.npy`` files and their label files.

    Every file matching ``<dataset_path>/**/*img.npy`` is paired with the
    label path returned by ``img2label``. Items are returned as an
    ``(image, label)`` pair after normalisation, class-id remapping through
    the module-level ``class_mapping`` array, optional augmentation, resizing
    to ``img_size x img_size`` and conversion to tensors.

    Args:
        dataset_path: root directory that is searched recursively.
        img_size: side length passed to ``A.Resize`` for both height and width.
        augs: optional list of albumentations transforms applied before the
            resize; the list itself is not modified.
        cache: when True, all arrays are loaded in ``__init__`` and kept in
            memory; otherwise only paths are stored and files are loaded on
            each access.
        clip_min: forwarded to ``normalize`` as ``min_val``.
        clip_max: forwarded to ``normalize`` as ``max_val``.
    """

    def __init__(self,
                dataset_path: str, 
                img_size: int,
                augs: List | None = None,
                cache: bool = False,
                clip_min: int | None = None,
                clip_max: int | None = None,
                ):
        super().__init__()
        self.use_cache = cache
        self.img_size = img_size
        self.images = []
        self.labels = []
        self.clip_min = clip_min
        self.clip_max = clip_max

        # glob returns paths in an unspecified, filesystem-dependent order;
        # sorting keeps sample indices stable between runs and machines.
        img_paths = sorted(glob.glob(dataset_path + '/**/*img.npy', recursive=True))
        for img_path in img_paths:
            lbl_path = img2label(img_path)
            if self.use_cache:
                self.images.append(load_npy(img_path))
                # Load the label file here: previously the image was loaded a
                # second time, so cached datasets trained on the image as mask.
                self.labels.append(load_npy(lbl_path))
            else:
                self.images.append(img_path)
                self.labels.append(lbl_path)

        # User augmentations run first; resize and tensor conversion always
        # run last so every sample leaves with the same spatial size.
        transforms = list(augs) if augs is not None else []
        transforms.extend([
            A.Resize(self.img_size, self.img_size, always_apply=True),
            ToTensorV2(always_apply=True)
        ])
        self.transforms = A.Compose(transforms)

    def __len__(self):
        """Return the number of image/label pairs found."""
        return len(self.images)

    def __getitem__(self, index: int):
        """Return the transformed ``(image, label)`` pair at ``index``.

        The label is returned as a ``torch.long`` tensor holding remapped
        class ids.
        """
        image = self.images[index]
        label = self.labels[index]

        if not self.use_cache:
            # Without caching the stored entries are paths, not arrays.
            image = load_npy(image)
            label = load_npy(label)

        image = normalize(image,
                          min_val=self.clip_min,
                          max_val=self.clip_max,
                          )
        # Append a trailing channel axis — assumes the stored image is 2-D
        # (H, W); TODO confirm against the data files.
        image = np.expand_dims(image, 2)
        # Remap original label ids to new class ids via the global table.
        label = class_mapping[label.astype(int)]
        transformed = self.transforms(image=image, mask=label)
        image, label = transformed["image"], transformed["mask"].to(torch.long)
        return image, label

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim

In [9]:
class UNet(nn.Module):
    """Four-level U-Net built from 3x3 conv / batch-norm / ReLU stages.

    The encoder widens 32 -> 64 -> 128 -> 256 channels with a 2x2 max-pool
    after each stage, the bottleneck goes 256 -> 512 -> 256, and the decoder
    upsamples by 2 (nearest neighbour), concatenates the matching encoder
    output along the channel axis and narrows back down. The last decoder
    stage ends in a bare convolution, so the network returns raw per-class
    scores with the same spatial size as the input.

    Because of the four pool/upsample pairs, input height and width must be
    divisible by 16 for the skip concatenations to line up.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the returned score map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Pooling is kept as a function plus kwargs, applied in forward().
        self.pooling = F.max_pool2d
        self.pool_params = {"kernel_size":2, "stride":2}

        # encoder (downsampling)
        self.enc_conv0 = self._conv_stage(in_channels, 32, 32)
        self.enc_conv1 = self._conv_stage(32, 64, 64)
        self.enc_conv2 = self._conv_stage(64, 128, 128)
        self.enc_conv3 = self._conv_stage(128, 256, 256)

        # bottleneck
        self.bottleneck_conv = self._conv_stage(256, 512, 256)

        # decoder (upsampling); input widths are doubled by the skip concat
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.dec_conv0 = self._conv_stage(512, 256, 128)
        self.dec_conv1 = self._conv_stage(256, 128, 64)
        self.dec_conv2 = self._conv_stage(128, 64, 32)
        self.dec_conv3 = self._conv_stage(64, 32, out_channels, head=True)

    @staticmethod
    def _conv_stage(c_in, c_mid, c_out, head=False):
        """Build Conv-BN-ReLU-Conv[-BN-ReLU] as an ``nn.Sequential``.

        With ``head=True`` the trailing batch-norm and ReLU are omitted so the
        stage ends in a plain convolution.
        """
        layers = [
            nn.Conv2d(in_channels=c_in, out_channels=c_mid, kernel_size=3,
                      padding=1),
            nn.BatchNorm2d(c_mid),
            nn.ReLU(),
            nn.Conv2d(in_channels=c_mid, out_channels=c_out, kernel_size=3,
                      padding=1),
        ]
        if not head:
            layers.extend([nn.BatchNorm2d(c_out), nn.ReLU()])
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for a batch ``x`` of shape (N, in_channels, H, W)."""
        encoders = (self.enc_conv0, self.enc_conv1, self.enc_conv2, self.enc_conv3)
        decoders = (self.dec_conv0, self.dec_conv1, self.dec_conv2, self.dec_conv3)

        # Encoder: remember each stage's output, then halve the resolution.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)
            features = self.pooling(features, **self.pool_params)

        features = self.bottleneck_conv(features)

        # Decoder: deepest skip connection is consumed first.
        for decode in decoders:
            merged = torch.cat((self.upsample(features), skips.pop()), 1)
            features = decode(merged)

        return features

In [10]:
def dice_coef_loss(predictions, ground_truths, num_classes=2, dims=(1, 2), smooth=1e-8, ce=False):
    """Smooth Dice loss, with cross-entropy added only when ``ce`` is True.

    Args:
        predictions: raw class scores with the class axis at dim 1; softmax is
            applied over that axis and the result is permuted with
            ``(0, 2, 3, 1)``, so a 4-D tensor is required.
        ground_truths: integer class-id tensor accepted by ``F.one_hot``; its
            one-hot form must broadcast against the permuted predictions.
        num_classes: number of classes used for the one-hot encoding.
        dims: axes of the permuted (class-last) tensors that are summed when
            forming the Dice ratio.
        smooth: constant added to numerator and denominator.
        ce: when True, ``F.cross_entropy(predictions, ground_truths)`` is
            added to the Dice loss.

    Returns:
        Scalar tensor: ``1 - mean(dice)``, plus cross-entropy if requested.
    """
    target_one_hot = F.one_hot(ground_truths, num_classes=num_classes)
    # Move the class axis last so it lines up with the one-hot encoding.
    probabilities = torch.softmax(predictions, dim=1).permute(0, 2, 3, 1)

    overlap = torch.sum(probabilities * target_one_hot, dim=dims)
    total = torch.sum(probabilities, dim=dims) + torch.sum(target_one_hot, dim=dims)

    dice_scores = (2.0 * overlap + smooth) / (total + smooth)
    dice_loss = 1.0 - dice_scores.mean()
    if not ce:
        return dice_loss
    return dice_loss + F.cross_entropy(predictions, ground_truths)

In [11]:
import torch

In [12]:
a = torch.arange(2*3*5).reshape(2, 3, 5)/10
a

tensor([[[0.0000, 0.1000, 0.2000, 0.3000, 0.4000],
         [0.5000, 0.6000, 0.7000, 0.8000, 0.9000],
         [1.0000, 1.1000, 1.2000, 1.3000, 1.4000]],

        [[1.5000, 1.6000, 1.7000, 1.8000, 1.9000],
         [2.0000, 2.1000, 2.2000, 2.3000, 2.4000],
         [2.5000, 2.6000, 2.7000, 2.8000, 2.9000]]])

In [13]:
F.softmax(a, dim=1)


tensor([[[0.1863, 0.1863, 0.1863, 0.1863, 0.1863],
         [0.3072, 0.3072, 0.3072, 0.3072, 0.3072],
         [0.5065, 0.5065, 0.5065, 0.5065, 0.5065]],

        [[0.1863, 0.1863, 0.1863, 0.1863, 0.1863],
         [0.3072, 0.3072, 0.3072, 0.3072, 0.3072],
         [0.5065, 0.5065, 0.5065, 0.5065, 0.5065]]])

In [14]:
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassF1Score

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [16]:
def fit_epoch(model, train_loader, criterion, optimizer, metric):
    """Train ``model`` for one pass over ``train_loader``.

    Batches are moved to the module-level ``device``. The criterion is called
    as ``criterion(outputs, labels, num_classes=...)`` with the number of
    entries in ``DatasetConfig.USED_CLASSES``.

    Args:
        model: network being optimised; switched to train mode.
        train_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss callable accepting a ``num_classes`` keyword.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        metric: callable ``metric(preds, labels)`` returning a tensor, where
            ``preds`` are softmax probabilities over dim 1.

    Returns:
        ``(train_loss, train_metric)``: the sample-weighted mean loss as a
        float, and the mean of the per-batch metric values as a NumPy value.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_metric = 0
    processed_data = 0

    count = 0
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels, num_classes=len(DatasetConfig.USED_CLASSES))
        loss.backward()
        optimizer.step()

        # The metric is bookkeeping only: compute it without autograd so the
        # softmax does not build an unused graph after the update.
        with torch.no_grad():
            preds = F.softmax(outputs, dim=1)
            running_metric += metric(preds, labels)

        # Weight the batch loss by batch size so a smaller last batch does not
        # skew the epoch average.
        running_loss += loss.item() * inputs.size(0)
        processed_data += inputs.size(0)
        count += 1

    train_loss = running_loss / processed_data
    # The metric is averaged per batch (not per sample), unlike the loss.
    train_metric = running_metric.cpu().numpy() / count
    return train_loss, train_metric

In [17]:
from sklearn.model_selection import train_test_split

In [18]:
@torch.no_grad()
def eval_epoch(model, val_loader, criterion, metric):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Batches are moved to the module-level ``device``; the criterion receives
    ``num_classes=len(DatasetConfig.USED_CLASSES)``.

    Returns:
        ``(val_loss, val_metric)``: the sample-weighted mean loss as a float,
        and the mean of the per-batch metric values (computed in float64) as a
        NumPy value.
    """
    model.eval()

    loss_total = 0.0
    metric_total = 0
    samples_seen = 0
    batches_seen = 0

    for batch_inputs, batch_labels in tqdm(val_loader):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        batch_len = batch_inputs.size(0)

        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels, num_classes=len(DatasetConfig.USED_CLASSES))
        probabilities = F.softmax(logits, dim=1)

        # Loss is weighted by batch size; the metric is averaged per batch.
        loss_total = loss_total + batch_loss.item() * batch_len
        metric_total = metric_total + metric(probabilities, batch_labels)
        samples_seen = samples_seen + batch_len
        batches_seen = batches_seen + 1

    val_loss = loss_total / samples_seen
    mean_metric = metric_total.double() / batches_seen
    return val_loss, mean_metric.cpu().numpy()

In [19]:
from tqdm import tqdm

In [20]:
def train(model, epochs, train_loader, val_loader, optim, criterion, metric, learning_rate=0.001):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: network to optimise.
        epochs: number of passes over ``train_loader``.
        train_loader: batches passed to ``fit_epoch``.
        val_loader: batches passed to ``eval_epoch``.
        optim: optimizer class or factory, called as
            ``optim(model.parameters(), lr=learning_rate)``. Inside this
            function the name shadows any module imported as ``optim``.
        criterion: loss callable forwarded to both epoch functions.
        metric: metric callable forwarded to both epoch functions.
        learning_rate: learning rate handed to the optimizer factory.

    Returns:
        List with one ``(train_loss, train_metric, val_loss, val_metric)``
        tuple per epoch.

    No learning-rate scheduling or best-checkpoint tracking is performed; the
    model is left with the weights of the final epoch.
    """
    opt = optim(model.parameters(), lr=learning_rate)

    history = []
    for epoch in range(epochs):
        train_loss, train_metric = fit_epoch(model, train_loader, criterion, opt, metric)
        val_loss, val_metric = eval_epoch(model, val_loader, criterion, metric)
        history.append((train_loss, train_metric, val_loss, val_metric))

        # The four leading spaces of the second literal are part of the
        # established log format; keep them so output stays identical.
        print(
            f"\nEpoch {epoch:03d} train_loss: {train_loss:0.4f} "
            f"    val_loss {val_loss:0.4f} train_metric {train_metric:0.4f} val_metric {val_metric:0.4f}"
        )

    return history

In [21]:
# Single-channel input; one output channel per retained class.
model = UNet(1, len(DatasetConfig.USED_CLASSES)).to(device)
# Macro-averaged F1 over the remapped classes, kept on the same device as the model.
metric = MulticlassF1Score(num_classes=len(DatasetConfig.USED_CLASSES), average="macro").to(device)
# Passed uninstantiated: train() builds the optimizer from model.parameters().
optimiser = optim.Adam
# Called with the default ce=False, i.e. Dice loss only.
criterion = dice_coef_loss

In [22]:
# NOTE(review): both datasets resize to 224 rather than DatasetConfig.IMAGE_SIZE,
# and the roots are hard-coded instead of using DatasetConfig.DATASET_PATH.
train_dataset = OrgansDataset('data/train', 224, augs = DatasetConfig.AUGS)
# Reshuffle every epoch; without it the DataLoader default (shuffle=False)
# feeds batches in the same file order each epoch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Validation: no augmentations and a fixed order.
val_dataset = OrgansDataset('data/val', 224)
val_loader = DataLoader(val_dataset, batch_size=16)

In [23]:
train(model, 100, train_loader, val_loader, optimiser, criterion, metric)

100%|██████████| 115/115 [00:10<00:00, 10.69it/s]
100%|██████████| 33/33 [00:01<00:00, 25.91it/s]



Epoch 000 train_loss: 0.8125     val_loss 0.7857 train_metric 0.1942 val_metric 0.2430


100%|██████████| 115/115 [00:10<00:00, 11.15it/s]
100%|██████████| 33/33 [00:01<00:00, 25.41it/s]



Epoch 001 train_loss: 0.7246     val_loss 0.7452 train_metric 0.3221 val_metric 0.3233


100%|██████████| 115/115 [00:10<00:00, 11.12it/s]
100%|██████████| 33/33 [00:01<00:00, 25.15it/s]



Epoch 002 train_loss: 0.6644     val_loss 0.8004 train_metric 0.4214 val_metric 0.2717


100%|██████████| 115/115 [00:10<00:00, 11.07it/s]
100%|██████████| 33/33 [00:01<00:00, 26.16it/s]



Epoch 003 train_loss: 0.6368     val_loss 0.6453 train_metric 0.4679 val_metric 0.4623


100%|██████████| 115/115 [00:10<00:00, 11.10it/s]
100%|██████████| 33/33 [00:01<00:00, 25.78it/s]



Epoch 004 train_loss: 0.6092     val_loss 0.6168 train_metric 0.5137 val_metric 0.5241


100%|██████████| 115/115 [00:10<00:00, 11.06it/s]
100%|██████████| 33/33 [00:01<00:00, 25.82it/s]



Epoch 005 train_loss: 0.5954     val_loss 0.6056 train_metric 0.5385 val_metric 0.5398


100%|██████████| 115/115 [00:10<00:00, 11.18it/s]
100%|██████████| 33/33 [00:01<00:00, 25.96it/s]



Epoch 006 train_loss: 0.5808     val_loss 0.6227 train_metric 0.5660 val_metric 0.5225


100%|██████████| 115/115 [00:10<00:00, 11.13it/s]
100%|██████████| 33/33 [00:01<00:00, 25.24it/s]



Epoch 007 train_loss: 0.5709     val_loss 0.6078 train_metric 0.5803 val_metric 0.5436


100%|██████████| 115/115 [00:10<00:00, 11.16it/s]
100%|██████████| 33/33 [00:01<00:00, 25.21it/s]



Epoch 008 train_loss: 0.5587     val_loss 0.6010 train_metric 0.6042 val_metric 0.5583


100%|██████████| 115/115 [00:10<00:00, 11.10it/s]
100%|██████████| 33/33 [00:01<00:00, 25.96it/s]



Epoch 009 train_loss: 0.5527     val_loss 0.5575 train_metric 0.6104 val_metric 0.6233


100%|██████████| 115/115 [00:10<00:00, 11.04it/s]
100%|██████████| 33/33 [00:01<00:00, 25.73it/s]



Epoch 010 train_loss: 0.5421     val_loss 0.5936 train_metric 0.6360 val_metric 0.5836


 35%|███▍      | 40/115 [00:03<00:06, 11.01it/s]