# Model

In [1]:
# Import comet_ml at the top of your file
from comet_ml import Experiment

# Create the Comet experiment.
# The API key is intentionally NOT hard-coded in the notebook: it is read from
# the COMET_API_KEY environment variable. When the variable is unset, None is
# passed and comet_ml falls back to its own configuration lookup
# (e.g. a .comet.config file).
# NOTE(review): the key that used to be committed in this cell is exposed in
# the notebook history and should be revoked/rotated.
import os

experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="tumor-segmentation",
    workspace="prakhar-agarwal-byte",
)

COMET INFO: Couldn't find a Git repository in '/home/jovyan' nor in any parent directory. You can override where Comet is looking for a Git Patch by setting the configuration `COMET_GIT_DIRECTORY`
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/prakhar-agarwal-byte/tumor-segmentation/b0d2e993579e4fec8935be62676d34cb



In [2]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_channels``
    to ``out_channels``. The convolutions carry no bias term.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_inputs in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(
                        stage_inputs,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        # Stored as one Sequential named ``conv`` so the parameter names
        # (conv.0.weight, conv.1.weight, ...) stay the same for checkpoints.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        return self.conv(x)

In [3]:
class UNET(nn.Module):
    """U-Net style encoder/decoder built from ``DoubleConv`` blocks.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the returned map (raw scores; no
            activation is applied by this module).
        features: channel widths of the encoder levels, first level first.
            Any non-empty iterable of ints is accepted. The decoder mirrors it
            in reverse order and the bottleneck uses twice the last width.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super().__init__()
        # The default is an immutable tuple (not a shared mutable list); take a
        # private copy so any iterable works and the caller's object is untouched.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET: one DoubleConv per level, chaining channel counts.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: entries alternate (transposed conv, DoubleConv).
        # The DoubleConv takes feature*2 channels because the skip connection
        # is concatenated with the upsampled map before it is applied.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final 1x1 conv output."""
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # stored before pooling
            x = self.pool(x)

        x = self.bottleneck(x)
        # Deepest skip first, matching the order of the decoder stages.
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # The 2x2/stride-2 pooling floors odd sizes, so the upsampled map
            # can be smaller than its skip connection; only the spatial dims
            # matter for the concatenation below.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)


In [4]:
def test():
    """Smoke test: a single-channel UNET must return a map shaped like its input.

    The 161x161 input is odd-sized, so the upsampled maps do not all line up
    with their skip connections and the resize branch of the decoder is hit.
    """
    x = torch.randn((3, 1, 161, 161))
    preds = UNET(in_channels=1, out_channels=1)(x)
    assert preds.shape == x.shape

# Dataset

In [5]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

In [6]:
class MRIDataset(Dataset):
    """Image/mask pairs read from two directories.

    Every entry of ``image_dir`` is treated as an image. Its mask is looked up
    in ``mask_dir`` under the same name with ``_mask`` inserted before the file
    extension (``foo.tif`` -> ``foo_mask.tif``).

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries
            (e.g. an albumentations ``Compose``).
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting gives every
        # index a stable meaning across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``, transformed if a transform is set."""
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        # Insert "_mask" before the extension only, instead of replacing every
        # ".tif" occurrence inside the file name.
        stem, extension = os.path.splitext(image_name)
        mask_path = os.path.join(self.mask_dir, f"{stem}_mask{extension}")

        # Context managers close the underlying files once the pixel data has
        # been copied into numpy arrays.
        with Image.open(img_path) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)
        # Masks are stored with 255 for foreground; map that to 1.0.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask


# Utils

In [7]:
import torch
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename`` with torch.save."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Only the model weights are restored; any other entries of the checkpoint
    mapping are ignored.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def get_loaders(
        train_dir,
        train_maskdir,
        val_dir,
        val_maskdir,
        batch_size,
        train_transform,
        val_transform,
        num_workers=4,
        pin_memory=True,
):
    """Build the training and validation ``DataLoader`` objects.

    Both loaders wrap an ``MRIDataset`` and share ``batch_size``,
    ``num_workers`` and ``pin_memory``; only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``
    """
    # Settings common to both loaders.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    train_loader = DataLoader(
        MRIDataset(
            image_dir=train_dir,
            mask_dir=train_maskdir,
            transform=train_transform,
        ),
        shuffle=True,
        **shared_options,
    )

    val_loader = DataLoader(
        MRIDataset(
            image_dir=val_dir,
            mask_dir=val_maskdir,
            transform=val_transform,
        ),
        shuffle=False,
        **shared_options,
    )

    return train_loader, val_loader

def check_accuracy(loader, model, mode, epoch, device="cuda"):
    """Evaluate ``model`` on ``loader``; print and log pixel accuracy and Dice score.

    Predictions are ``sigmoid(model(x)) > 0.5``. Accuracy is the percentage of
    matching pixels over the whole loader; the Dice score is computed per batch
    and averaged over the batches. Both values are printed and sent to the
    module-level Comet ``experiment`` as ``"{mode}_accuracy"`` and
    ``"{mode}_dice_score"`` at step ``epoch``.

    NOTE: the epsilon is only in the Dice denominator, so a batch in which
    neither the prediction nor the target has any positive pixel contributes a
    Dice of 0.

    Args:
        loader: iterable of ``(x, y)`` batches; ``y`` gets a channel axis added
            at dim 1 before being compared with the predictions.
        model: network to evaluate. Its train/eval mode is restored on exit.
        mode: prefix for the logged metric names (e.g. "training").
        epoch: step value attached to the logged metrics.
        device: device the batches are moved to.

    Returns:
        ``(accuracy_percent, dice_score)`` as Python floats; ``(0.0, 0.0)`` if
        the loader yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    num_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # Accumulate as tensors; conversion to Python numbers happens
                # once after the loop.
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                dice_total += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
                num_batches += 1
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

    # Plain Python scalars: safe to format, and what log_metric should receive
    # (rather than 0-dim, possibly CUDA, tensors).
    num_correct = int(num_correct)
    accuracy = 100.0 * num_correct / num_pixels if num_pixels else 0.0
    dice_score = float(dice_total) / num_batches if num_batches else 0.0

    print(
        f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}"
    )
    print(f"Dice score: {dice_score}")
    experiment.log_metric(f"{mode}_dice_score", dice_score, step=epoch)
    experiment.log_metric(f"{mode}_accuracy", accuracy, step=epoch)
    return accuracy, dice_score


def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and their targets as PNG files.

    For batch number ``idx`` of ``loader`` two files are written into
    ``folder``: ``pred_{idx}.png`` (``sigmoid(model(x)) > 0.5``) and
    ``{idx}.png`` (the target ``y`` with a channel axis added at dim 1).

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: network used for prediction. Its train/eval mode is restored
            on exit.
        folder: output directory; created if missing. A trailing slash is
            optional.
        device: device the inputs are moved to.
    """
    import os

    # save_image does not create directories; make sure the target exists.
    if folder:
        os.makedirs(folder, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            # os.path.join works with or without a trailing slash on `folder`.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

# Train

In [8]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
NUM_EPOCHS = 50
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
PIN_MEMORY = True
# When True, main() starts from the weights stored in my_checkpoint.pth.tar.
LOAD_MODEL = True
TRAIN_IMG_DIR = "MRI Dataset/train_images"
TRAIN_MASK_DIR = "MRI Dataset/train_mask"
VAL_IMG_DIR = "MRI Dataset/val_images"
VAL_MASK_DIR = "MRI Dataset/val_mask"


# Report multiple hyperparameters using a dictionary.
# Every value is taken from the constants above so the logged configuration
# cannot drift from what the training loop actually uses.
hyper_params = {
    "learning_rate": LEARNING_RATE,
    "epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
}
experiment.log_parameters(hyper_params)

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one pass over ``loader`` with mixed-precision scaling.

    Each batch is moved to ``DEVICE`` (targets are cast to float and get a
    channel axis at dim 1), the forward pass runs under CUDA autocast, and the
    backward/step goes through ``scaler``. The current batch loss is shown on
    the tqdm progress bar.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=DEVICE)
        labels = labels.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast.
        with torch.cuda.amp.autocast():
            batch_loss = loss_fn(model(inputs), labels)

        # Backward pass and parameter update via the gradient scaler.
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Show the latest loss on the progress bar.
        progress.set_postfix(loss=batch_loss.item())


def main():
    """Train the UNET on the MRI dataset and evaluate it after every epoch.

    Builds the train/validation transforms, the model, loss, optimizer and data
    loaders from the module-level hyperparameters. When ``LOAD_MODEL`` is set
    and ``my_checkpoint.pth.tar`` exists, training resumes from it. After each
    epoch a checkpoint is written, accuracy/Dice are reported for both loaders
    and validation predictions are saved under ``saved_images/``.
    """
    import os

    checkpoint_path = "my_checkpoint.pth.tar"
    image_folder = "saved_images/"

    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            # mean 0 / std 1 with max_pixel_value 255: scales pixels to [0, 1].
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        if os.path.isfile(checkpoint_path):
            # map_location lets a checkpoint written on a GPU machine be
            # loaded when only a CPU is available (and vice versa).
            checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
            load_checkpoint(checkpoint, model)
            # The checkpoint also carries the optimizer state; restore it so
            # Adam's running statistics survive the restart.
            if "optimizer" in checkpoint:
                optimizer.load_state_dict(checkpoint["optimizer"])
                # Keep the learning rate configured in this notebook rather
                # than the one stored in the checkpoint.
                for param_group in optimizer.param_groups:
                    param_group["lr"] = LEARNING_RATE
        else:
            # Fresh environment: nothing to resume from yet.
            print(f"=> No checkpoint found at {checkpoint_path}, training from scratch")

    # Gradient scaling is only meaningful on CUDA; disable it explicitly on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    # The image writer does not create directories itself.
    os.makedirs(image_folder, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(train_loader, model, "training", epoch, device=DEVICE)
        check_accuracy(val_loader, model, "validation", epoch, device=DEVICE)

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder=image_folder, device=DEVICE
        )

In [None]:
# Run the training loop defined above. With LOAD_MODEL = True the weights are
# first loaded from my_checkpoint.pth.tar, so that file must already exist.
main()

=> Loading checkpoint


100%|██████████| 60/60 [01:51<00:00,  1.85s/it, loss=0.0232]


=> Saving checkpoint
Got 250358442/250937344 with acc 99.77
Dice score: 0.8803378343582153
Got 6539253/6553600 with acc 99.78
Dice score: 0.8869966864585876


100%|██████████| 60/60 [01:56<00:00,  1.93s/it, loss=0.0207]


=> Saving checkpoint
Got 250414647/250937344 with acc 99.79
Dice score: 0.8969385027885437
Got 6539706/6553600 with acc 99.79
Dice score: 0.8952927589416504


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0176]


=> Saving checkpoint
Got 250398557/250937344 with acc 99.79
Dice score: 0.8888599276542664
Got 6539800/6553600 with acc 99.79
Dice score: 0.8900848627090454


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.015] 


=> Saving checkpoint
Got 250400483/250937344 with acc 99.79
Dice score: 0.8939995169639587
Got 6540295/6553600 with acc 99.80
Dice score: 0.9024801850318909


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0134]


=> Saving checkpoint
Got 250359535/250937344 with acc 99.77
Dice score: 0.8788880705833435
Got 6540014/6553600 with acc 99.79
Dice score: 0.893551230430603


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0125]


=> Saving checkpoint
Got 250442277/250937344 with acc 99.80
Dice score: 0.8997586965560913
Got 6540232/6553600 with acc 99.80
Dice score: 0.8968640565872192


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0111]


=> Saving checkpoint
Got 250299390/250937344 with acc 99.75
Dice score: 0.873737096786499
Got 6538084/6553600 with acc 99.76
Dice score: 0.8850079774856567


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0109]


=> Saving checkpoint
Got 250410401/250937344 with acc 99.79
Dice score: 0.8973494172096252
Got 6539688/6553600 with acc 99.79
Dice score: 0.8966012001037598


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0127] 


=> Saving checkpoint
Got 250452260/250937344 with acc 99.81
Dice score: 0.9054688215255737
Got 6540615/6553600 with acc 99.80
Dice score: 0.901900053024292


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0104] 


=> Saving checkpoint
Got 250447279/250937344 with acc 99.80
Dice score: 0.9010977745056152
Got 6540653/6553600 with acc 99.80
Dice score: 0.9002535343170166


100%|██████████| 60/60 [01:55<00:00,  1.93s/it, loss=0.0101] 


=> Saving checkpoint
Got 250449479/250937344 with acc 99.81
Dice score: 0.9038371443748474
Got 6539723/6553600 with acc 99.79
Dice score: 0.8917149305343628


 72%|███████▏  | 43/60 [01:23<00:32,  1.93s/it, loss=0.00954]

# Prediction

In [None]:
import shutil

In [None]:
def predict(img_path, img_mask_path, index, folder="results/"):
    """Segment one image with the saved checkpoint and write comparison PNGs.

    A fresh UNET is loaded from ``my_checkpoint.pth.tar`` and applied to the
    image at ``img_path``. Four files are written into ``folder``:

    * ``pred_{index}.png``     - thresholded prediction (sigmoid > 0.5)
    * ``og_{index}.png``       - the resized, [0, 1]-scaled input image
    * ``final_{index}.png``    - the input with predicted pixels painted red
    * ``img_mask_{index}.png`` - the reference mask, re-encoded as PNG

    Args:
        img_path: path of the image to segment.
        img_mask_path: path of the reference mask for that image.
        index: identifier used in the output file names.
        folder: output directory; created if missing. A trailing slash is
            optional.
    """
    import os

    # Neither save_image nor PIL creates directories.
    if folder:
        os.makedirs(folder, exist_ok=True)

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # map_location lets a checkpoint written on a GPU be used on a CPU-only machine.
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)
    model.eval()

    transformer = A.Compose(
            [
                A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
                A.Normalize(
                    mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0,
                ),
                ToTensorV2(),
            ],
        )

    with Image.open(img_path) as source:
        img = np.array(source.convert("RGB"))
    img_tensor = transformer(image=img)["image"].unsqueeze(0).to(device=DEVICE)

    with torch.no_grad():
        preds = torch.sigmoid(model(img_tensor))
        preds = (preds > 0.5).float()

    torchvision.utils.save_image(
        preds, os.path.join(folder, f"pred_{index}.png")
    )
    torchvision.utils.save_image(
        img_tensor.cpu(), os.path.join(folder, f"og_{index}.png")
    )

    # Paint every predicted pixel pure red. The image is scaled to [0, 1], so
    # red is (1, 0, 0). torch.where broadcasts the (1, 1, H, W) prediction mask
    # over the three colour channels, replacing the per-pixel Python loop.
    red = torch.tensor([1.0, 0.0, 0.0], device=img_tensor.device).view(1, 3, 1, 1)
    overlay = torch.where(preds.bool(), red, img_tensor)
    torchvision.utils.save_image(
        overlay.cpu(), os.path.join(folder, f"final_{index}.png")
    )

    # Re-encode the reference mask so the .png file really contains PNG data
    # (a plain file copy would leave the original format under a .png name).
    with Image.open(img_mask_path) as mask_file:
        mask_file.convert("L").save(os.path.join(folder, f"img_mask_{index}.png"))

In [None]:
# Run inference on one validation image and its reference mask; the outputs
# are written to the default "results/" folder with index 8 in their names.
predict("MRI Dataset/val_images/TCGA_DU_7010_19860307_31.tif", "MRI Dataset/val_mask/TCGA_DU_7010_19860307_31_mask.tif", 8)