# Import packages

In [1]:
%pip install torch torchvision torchaudio
%pip install -r requirements.txt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import random
import argparse
from PIL import Image
from sklearn.model_selection import train_test_split
import albumentations as albu
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils as utils
import os
import cv2
import math
%pip install git+https://github.com/Po-Hsun-Su/pytorch-ssim.git
import pytorch_ssim
from torchinfo import summary
import ssl
ssl._create_default_https_context = ssl._create_unverified_context # for downloading pretrained encoder weights

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm


Collecting git+https://github.com/Po-Hsun-Su/pytorch-ssim.git
  Cloning https://github.com/Po-Hsun-Su/pytorch-ssim.git to /tmp/pip-req-build-vxqctjn2
  Running command git clone --filter=blob:none --quiet https://github.com/Po-Hsun-Su/pytorch-ssim.git /tmp/pip-req-build-vxqctjn2
  Resolved https://github.com/Po-Hsun-Su/pytorch-ssim.git to commit 3add4532d3f633316cba235da1c69e90f0dfb952
  Preparing metadata (setup.py) ... [?25ldone
[?25hNote: you may need to restart the kernel to use updated packages.


# Settings for training

In [2]:
SEED = 2024
GPU_ID = 6 # Select GPU to train and test on
EPOCHS = 100
LR = 1e-3
DECAY = 1e-3
OPTIM = 'adam'
L2REG = 1e-5
ENCODER = 'se_resnet152'
ETA = 1
PROB = 0.4

In [3]:
# Set device to GPU if available, else CPU
DEVICE = torch.device(f'cuda:{GPU_ID}' if torch.cuda.is_available() else 'cpu')
print(f"Device set to {DEVICE}.")

Device set to cuda:6.


# Helper functions

In [4]:
# Sets random seed for reproducibility (Default = 2024)
def set_random_seed(seed=2024):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")

set_random_seed(SEED)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

# Calculate the learning rate based on a linear warmup and a cosine decay.
def get_lr(epoch_num, warmup_epochs, total_epochs, init_lr, min_lr):
    if epoch_num < warmup_epochs:
        # Linear warmup
        lr = min_lr + (init_lr - min_lr) * epoch_num / warmup_epochs
    else:
        # Cosine decay
        decay_progress = (epoch_num - warmup_epochs) / (total_epochs - warmup_epochs)
        lr = min_lr + (init_lr - min_lr) * (1 + math.cos(math.pi * decay_progress)) / 2
    return lr

# Log the results of the training session
def log_results(filename, config, best_score, best_epoch, final_val_score):
    with open(filename, 'a') as file:
        file.write(f"Model Configuration and Training Settings:\n")
        for key, value in config.items():
            file.write(f"{key}: {value}\n")
        file.write(f"Best Validation F-Score: {best_score} (Epoch: {best_epoch + 1})\n")
        file.write(f"Final Validation F-Score: {final_val_score}\n")
        file.write("--------------------------------------------------\n")

Random seed set to 2024


# Dataset, augmentation, and preprocessing

In [5]:
# Custom dataset for images and corresponding masks
class NavImgDataset(Dataset):
    def __init__(self, img_paths, label_paths, preprocessing=None, augmentation=None):
        self.img_paths = img_paths
        self.label_paths = label_paths
        self.preprocessing = preprocessing
        self.augmentation = augmentation

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img = cv2.imread(self.img_paths[index])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = cv2.imread(self.label_paths[index], cv2.IMREAD_GRAYSCALE)
        label = np.expand_dims(label, axis=-1)
        label = label / 255.0
        if self.augmentation:
            sample = self.augmentation(image=img, mask=label)
            img, label = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=img, mask=label)
            img, label = sample['image'], sample['mask']
        return img, label

# Data augmentation (Flip, blur and noise)
def get_training_augmentation():
    train_transform = [
        # Flips
        albu.OneOf(
            [
                albu.HorizontalFlip(p=1),
                albu.VerticalFlip(p=1),
            ],
            p=0.5,
        ),
        # Blur and noise
        albu.OneOf(
            [
                albu.Defocus(p=1),
                albu.GaussNoise(p=1),
                albu.MedianBlur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
                albu.ZoomBlur(p=1)
            ],
            p=PROB
        ),
        # Padding to ensure each dimension is divisible by 32
        albu.PadIfNeeded(min_height=256, min_width=448, always_apply=True, border_mode=0),
        # Contrast enhancement
        albu.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True)
    ]
    return albu.Compose(train_transform)

# Same preprocessing for validation / test set
def get_validation_augmentation():
    test_transform = [
        albu.PadIfNeeded(min_height=256, min_width=448, always_apply=True, border_mode=0),
        albu.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True)
    ]
    return albu.Compose(test_transform)

def get_preprocessing(preprocessing_fn):    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

def load_split_dataset(img_dir, label_dir,  preprocessing_fn, valid_size=.2):
    img_files = sorted([os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')])
    label_files = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.png')])

    # Split into training and validation sets (80-20 split)
    train_imgs, valid_imgs, train_labels, valid_labels = train_test_split(
        img_files, label_files, test_size=valid_size, random_state=2024)

    # Create dataset objects
    train_dataset = NavImgDataset(
        train_imgs, 
        train_labels, 
        preprocessing=get_preprocessing(preprocessing_fn), 
        augmentation=get_training_augmentation()
    )
    valid_dataset = NavImgDataset(
        valid_imgs, 
        valid_labels, 
        preprocessing=get_preprocessing(preprocessing_fn),
        augmentation=get_validation_augmentation()
    )

    return train_dataset, valid_dataset

# Cost function

In [6]:
# A combined loss function
class CombinedLoss(nn.Module):
    def __init__(self, _eta=ETA):
        super(CombinedLoss, self).__init__()
        self.mcc_loss = smp.losses.MCCLoss()
        self.focal_loss = smp.losses.FocalLoss(mode='binary', alpha=0.25, gamma=2.0)
        self.ssim_loss = pytorch_ssim.SSIM(window_size=11,size_average=True)
        self.eta = _eta
        self.__name__ = 'combined_loss'

    def forward(self, outputs, targets):
        mcc_loss = self.mcc_loss(outputs, targets)
        focal_loss = self.focal_loss(outputs, targets)
        ssim_loss = 1 - self.ssim_loss(outputs, targets)
        combined_loss = mcc_loss + self.eta * focal_loss + ssim_loss
        return combined_loss

# Training

In [7]:
# Path to training data
img_path = './train_dataset/img'
labels_path = './train_dataset/mask_img'

# Model parameters
ENCODER = ENCODER
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'

# Constants
INIT_LR = LR  # Initial learning rate
NUM_EPOCHS = EPOCHS
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.1)  # Number of epochs for warmup
MIN_LR = INIT_LR * DECAY  # Minimum learning rate after decay

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=1, 
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

model.to(DEVICE)

# Show model details
summary(model, (1, 3, 256, 448))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Layer (type:depth-idx)                             Output Shape              Param #
Unet                                               [1, 1, 256, 448]          --
├─SENetEncoder: 1-1                                [1, 3, 256, 448]          --
│    └─Sequential: 2-1                             --                        --
│    │    └─Conv2d: 3-1                            [1, 64, 128, 224]         9,408
│    │    └─BatchNorm2d: 3-2                       [1, 64, 128, 224]         128
│    │    └─ReLU: 3-3                              [1, 64, 128, 224]         --
│    │    └─MaxPool2d: 3-4                         [1, 64, 64, 112]          --
│    └─Sequential: 2-2                             [1, 256, 64, 112]         --
│    │    └─SEResNetBottleneck: 3-5                [1, 256, 64, 112]         83,472
│    │    └─SEResNetBottleneck: 3-6                [1, 256, 64, 112]         78,864
│    │    └─SEResNetBottleneck: 3-7                [1, 256, 64, 112]         78,864
│    └─Sequential: 

In [8]:
# Load datasets
BATCH_SIZE = 32
train_dataset, valid_dataset = load_split_dataset(img_path, labels_path, preprocessing_fn, valid_size=.2)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Optimizer
if OPTIM == 'adam':
    optimizer = torch.optim.Adam([ 
        dict(params=model.parameters(), lr=INIT_LR, weight_decay=L2REG),
    ])
elif OPTIM == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=INIT_LR, weight_decay=L2REG)

# Loss and metrics
loss = CombinedLoss(_eta=ETA)
metrics = [utils.metrics.Fscore(beta=0.3 ** 0.5, threshold=0.5)]

# Training and validation 
train_epoch = utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Train model
max_score = 0
cnt = 0

# Initialize lists to store the F-scores for plotting later
train_f_scores = []
valid_f_scores = []

for i in range(NUM_EPOCHS):
    current_lr = get_lr(i, WARMUP_EPOCHS, NUM_EPOCHS, INIT_LR, MIN_LR)
    # Update optimizer with the current learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr
    
    print(f'\nEpoch: {i+1}, LR: {current_lr}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Record the F-scores
    train_f_scores.append(train_logs['fscore'])
    valid_f_scores.append(valid_logs['fscore'])

    # Save model on best validation score
    if max_score < valid_logs['fscore']:
        cnt = i
        max_score = valid_logs['fscore']
        model_path = f'./model[{model.__class__.__name__.lower()}]_focal_enc[{ENCODER}]_{NUM_EPOCHS}eps_initlr[{INIT_LR}]_decay[{DECAY}]_bs[{BATCH_SIZE}]_wd[{L2REG}]_eta[{ETA}]_pn[{PROB}]_mccloss_BEST.pth'
        torch.save(model, model_path)
        print(f'Model saved at {model_path}!')
# Final model save
final_model_path = f'./model[{model.__class__.__name__.lower()}]_focal_enc[{ENCODER}]_{NUM_EPOCHS}eps_initlr[{INIT_LR}]_decay[{DECAY}]_bs[{BATCH_SIZE}]_wd[{L2REG}]_eta[{ETA}]_pn[{PROB}]_mccloss.pth'
torch.save(model, final_model_path)
print('Training completed!')
print(f"Best Model @ {cnt+1} epochs; F-Score: {max_score}, for {model.__class__.__name__.lower()}")


Epoch: 1, LR: 1e-06
train: 100%|██████████| 108/108 [01:46<00:00,  1.02it/s, combined_loss - 2.212, fscore - 0.05212]
valid: 100%|██████████| 27/27 [00:11<00:00,  2.39it/s, combined_loss - 2.186, fscore - 0.07023]
Model saved at ./model[unet]_focal_enc[se_resnet152]_100eps_initlr[0.001]_decay[0.001]_bs[32]_wd[1e-05]_eta[1]_pn[0.4]_mccloss_BEST.pth!

Epoch: 2, LR: 0.00010090000000000001
train: 100%|██████████| 108/108 [01:47<00:00,  1.00it/s, combined_loss - 2.008, fscore - 0.2076]
valid: 100%|██████████| 27/27 [00:10<00:00,  2.46it/s, combined_loss - 1.878, fscore - 0.3125]
Model saved at ./model[unet]_focal_enc[se_resnet152]_100eps_initlr[0.001]_decay[0.001]_bs[32]_wd[1e-05]_eta[1]_pn[0.4]_mccloss_BEST.pth!

Epoch: 3, LR: 0.00020080000000000003
train: 100%|██████████| 108/108 [01:48<00:00,  1.01s/it, combined_loss - 1.548, fscore - 0.3892]
valid: 100%|██████████| 27/27 [00:11<00:00,  2.38it/s, combined_loss - 0.9512, fscore - 0.515] 
Model saved at ./model[unet]_focal_enc[se_resnet15

In [None]:
# Save train config
config = {
    'Model': model.__class__.__name__.lower(),
    'Encoder': ENCODER,
    'Initial Learning Rate': INIT_LR,
    'Decay': DECAY,
    'Number of Epochs': NUM_EPOCHS,
    'Optimizer': OPTIM,
    'L2 regularization': L2REG,
    'Batch Size': BATCH_SIZE,
    'Eta': ETA,
    'Loss': 'Focal + MCCLoss + SSIMLoss',
    'Probability for blurring and noise': PROB
}

# Find the final validation F-score
final_val_fscore = valid_f_scores[-1] if valid_f_scores else 0

# Log the results to a file
log_results(f"{model.__class__.__name__.lower()}_training_results.txt", config, max_score, cnt, final_val_fscore)

# Testing

In [5]:
class NavImgDataset(Dataset):
    def __init__(self, img_paths, preprocessing=None, augmentation=None):
        self.img_paths = img_paths
        self.preprocessing = preprocessing
        self.augmentation = augmentation

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img = cv2.imread(self.img_paths[index])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.augmentation:
            sample = self.augmentation(image=img, mask=img)
            img, _ = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=img, mask=img)
            img, _ = sample['image'], sample['mask']
        return img

def get_preprocessing(preprocessing_fn):    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

def get_testing_augmentation():
    test_transform = [
        albu.PadIfNeeded(min_height=256, min_width=448, always_apply=True, border_mode=0),
        albu.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), always_apply=True)
    ]
    return albu.Compose(test_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

In [6]:
# Model
ENCODER = 'se_resnet152'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=None, 
    classes=1, 
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
model_path = 'model[unet]_focal_enc[se_resnet152]_100eps_initlr[0.001]_decay[0.001]_bs[32]_wd[1e-05]_eta[1]_pn[0.4]_mccloss_BEST.pth'
model = torch.load(model_path)
model.to(DEVICE)
summary(model, (1, 3, 256, 448))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Layer (type:depth-idx)                             Output Shape              Param #
Unet                                               [1, 1, 256, 448]          --
├─SENetEncoder: 1-1                                [1, 3, 256, 448]          --
│    └─Sequential: 2-1                             --                        --
│    │    └─Conv2d: 3-1                            [1, 64, 128, 224]         9,408
│    │    └─BatchNorm2d: 3-2                       [1, 64, 128, 224]         128
│    │    └─ReLU: 3-3                              [1, 64, 128, 224]         --
│    │    └─MaxPool2d: 3-4                         [1, 64, 64, 112]          --
│    └─Sequential: 2-2                             [1, 256, 64, 112]         --
│    │    └─SEResNetBottleneck: 3-5                [1, 256, 64, 112]         83,472
│    │    └─SEResNetBottleneck: 3-6                [1, 256, 64, 112]         78,864
│    │    └─SEResNetBottleneck: 3-7                [1, 256, 64, 112]         78,864
│    └─Sequential: 

In [8]:
# Path to test data
img_dir = './test_dataset'
out_dir = f'./test_results_{model_path.split(".pth")[0]}'

# Check if the directory exists
if not os.path.exists(out_dir):
    # If the directory does not exist, create it
    os.makedirs(out_dir)
    print(f"Directory created at {out_dir}")
else:
    print(f"Directory already exists at {out_dir}")

# Load datasets
img_files = sorted([os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')])
test_dataset = NavImgDataset(
    img_files, 
    preprocessing=get_preprocessing(preprocessing_fn),
    augmentation=get_testing_augmentation()
)

model.eval()

for i, img_path in enumerate(img_files):
    # Fetch the image a single time
    image = test_dataset[i]
    
    # Convert the image numpy array to a tensor
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    
    # Perform prediction with the model
    with torch.no_grad():  # Ensure no gradients are calculated since we are only predicting
        pr_mask = model(x_tensor)
    
    # Processing the predicted mask for visualization
    pr_mask = pr_mask.squeeze().cpu().numpy()
    
    # Crop the mask to original dimensions
    original_h, original_w = 240, 428
    pad_top, pad_left = 8, 10
    pr_mask = pr_mask[pad_top:pad_top + original_h, pad_left:pad_left + original_w]    

    base_filename = os.path.splitext(os.path.basename(img_path))[0]
    print(f"Processed {base_filename}.")

    mask = Image.fromarray((pr_mask * 255).astype(np.uint8))
    mask.save(os.path.join(out_dir, f"{base_filename}.png"))
print(f"Done.")

Directory already exists at ./test_results_model[unet]_focal_enc[se_resnet152]_100eps_initlr[0.001]_decay[0.001]_bs[32]_wd[1e-05]_eta[1]_pn[0.4]_mccloss_BEST
Processed PRI_RI_2000000.
Processed PRI_RI_2000001.
Processed PRI_RI_2000002.
Processed PRI_RI_2000003.
Processed PRI_RI_2000004.
Processed PRI_RI_2000005.
Processed PRI_RI_2000006.
Processed PRI_RI_2000007.
Processed PRI_RI_2000008.
Processed PRI_RI_2000009.
Processed PRI_RI_2000010.
Processed PRI_RI_2000011.
Processed PRI_RI_2000012.
Processed PRI_RI_2000013.
Processed PRI_RI_2000014.
Processed PRI_RI_2000015.
Processed PRI_RI_2000016.
Processed PRI_RI_2000017.
Processed PRI_RI_2000018.
Processed PRI_RI_2000019.
Processed PRI_RI_2000020.
Processed PRI_RI_2000021.
Processed PRI_RI_2000022.
Processed PRI_RI_2000023.
Processed PRI_RI_2000024.
Processed PRI_RI_2000025.
Processed PRI_RI_2000026.
Processed PRI_RI_2000027.
Processed PRI_RI_2000028.
Processed PRI_RI_2000029.
Processed PRI_RI_2000030.
Processed PRI_RI_2000031.
Processed 