This image-segmentation challenge was unusual: labeled training data was scarce, while unlabeled data was abundant. I initially built my solution around a Mix Vision Transformer (MiT), but discussions on the forum led me to change my approach. To cope with the limited data, I created patches from the original images and used several pretrained MaxViT models from timm at different image sizes. With limited time to test my hypotheses, I kept the configurations that worked best for my most successful model and ensembled several models trained on different training/validation splits.

Given the limited data, I also adopted an overfitting strategy, since the holdout and test sets came from the same distribution as the training set. I started with DiceLoss, but my MaxViT model proved too strong for it, so I tried Lovász loss, which worked well. To preserve the spatial layout and aspect ratio of the images, I used padding, and for inference I used MONAI's sliding-window inference.

## Imports 

In [1]:
import os
import random
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import PIL.Image as Image
from glob import glob
import cv2

import torch
from torch.utils.data import DataLoader

import segmentation_models_pytorch as smp

from data_preparation import *
from datasets import *
from training import *
from utils import *
from model import CreateModel
from pseudo_data import CreatePseudoSamples

import warnings
# Silence all Python warnings raised by the libraries used in this notebook.
warnings.filterwarnings(action='ignore')

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so errors are
# reported at the call that caused them.
# NOTE(review): this is a debugging aid that slows GPU execution; consider
# dropping it for full training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

## Model configs

The code below imports several model configurations from the `model_config` module and organizes them into a dictionary called `model_configs_dict`. The dictionary maps short descriptive keys to their configuration objects, making it easy to access and manage multiple model configurations for training.

In [2]:
from model_config import (
    maxxvit_rmlp_small_rw_256_exp_1, maxxvitv2_rmlp_base_rw_384_patch_256_exp_2, 
    maxxvitv2_rmlp_base_rw_224_exp_3, maxxvitv2_rmlp_base_rw_384_exp_4, 
    maxvit_rmlp_base_rw_224_scse_exp_5, maxvit_rmlp_base_rw_384_exp_6, 
    maxxvitv2_rmlp_base_rw_384_pseudo_exp_7, maxxvitv2_rmlp_base_rw_384_exp_8, 
    maxxvit_rmlp_small_rw_256_pseudo_exp_9, maxxvitv2_rmlp_base_rw_224_scse_exp_10, 
    maxxvit_rmlp_small_rw_256_patch_224_pseudo_exp_11
)

# Map a short experiment key to its model configuration object.
# NOTE: configs whose names contain "pseudo" presumably need pseudo-labelled
# samples produced by earlier runs, so train them after those models.
model_configs_dict = dict(
    maxxvit_exp_1=maxxvit_rmlp_small_rw_256_exp_1,
    maxxvitv2_exp_2=maxxvitv2_rmlp_base_rw_384_patch_256_exp_2,
    maxxvitv2_exp_3=maxxvitv2_rmlp_base_rw_224_exp_3,
    maxxvitv2_exp_4=maxxvitv2_rmlp_base_rw_384_exp_4,
    maxvit_exp_5=maxvit_rmlp_base_rw_224_scse_exp_5,
    maxvit_exp_6=maxvit_rmlp_base_rw_384_exp_6,
    maxxvitv2_exp_7=maxxvitv2_rmlp_base_rw_384_pseudo_exp_7,
    maxxvitv2_exp_8=maxxvitv2_rmlp_base_rw_384_exp_8,
    maxxvit_exp_9=maxxvit_rmlp_small_rw_256_pseudo_exp_9,
    maxxvitv2_exp_10=maxxvitv2_rmlp_base_rw_224_scse_exp_10,
    maxxvit_exp_11=maxxvit_rmlp_small_rw_256_patch_224_pseudo_exp_11,
)

Define all directory paths here, and select from the configurations above the one for the model you want to train.
##### NOTE - Some models require pseudo-labeled samples for training, so train the models in a specific order.

In [3]:
# Paths and Configuration
# -- input data --
train_data_dir = 'stranger-sections-2-train-data'
unlabeled_data_dir = 'stranger-sections-2-unlabeled-data/'
test_root_dir = 'stranger-sections-2-test-data/stranger-sections-2-test-data/'

# -- artefacts produced by this notebook --
patches_data_dir = 'patches'          # patch images/masks cut from the training data
weights_dir = 'weights'               # state dicts written by the training cell
pred_output_dir = 'model_predictions' # per-experiment .npy prediction folders

# -- previously trained model weights (read by the multi-model prediction cell) --
model_checkpoints = 'model_checkpoints'

# Choose model configuration: a key of model_configs_dict
config = model_configs_dict['maxxvit_exp_1']

In [4]:
# Seed the random number generators from the active config for reproducibility.
# NOTE(review): seed_everything comes from one of the star imports at the top;
# confirm which libraries (random / numpy / torch / CUDA) it actually seeds.
seed_everything(config.seed)

The next cell splits the data into training and validation sets. Depending on the model configuration, it also generates pseudo labels for selected images from the provided unlabeled data.

In [8]:
# Split the training images into training and validation lists for this config.
# NOTE(review): for pseudo-label configs, prepare_data presumably also creates
# pseudo labels from unlabeled_data_dir; confirm in data_preparation.
train_list, valid_list = prepare_data(config, train_dir=train_data_dir, unlabeled_dir=unlabeled_data_dir)

 This block of code generates patches from the original images and their corresponding labels after padding both. Padding is used to ensure that the patches extracted from the images and labels maintain their spatial alignment and integrity.

In [None]:
# Cut the train/validation images and masks into patches under patches_data_dir
# and get back the patch file names for each split.
# NOTE(review): the dataset cell reads patches from
# f'{patches_data_dir}/{config.exp_name}', so create_patches presumably writes
# there; confirm in data_preparation.
train_patch_names, valid_patch_names = create_patches(config, train_dir=train_data_dir, output_dir=patches_data_dir, train_list=train_list, val_list=valid_list)

This block builds the training and validation datasets from the patches generated for the configured experiment, and wraps them in dataloaders for batch processing during training and validation.

In [10]:
# Both splits read patches from the same per-experiment folder; it must match
# the location the patch-creation step wrote to.
patch_root_dir = f'{patches_data_dir}/{config.exp_name}'

# Training split: config.train_transform, shuffled every epoch.
train_dataset = StrangerSectionsDataset(
    root_dir=patch_root_dir,
    filenames_list=train_patch_names,
    transforms=config.train_transform,
    mode='train',
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
)

# Validation split: config.test_transform, fixed order.
val_dataset = StrangerSectionsDataset(
    root_dir=patch_root_dir,
    filenames_list=valid_patch_names,
    transforms=config.test_transform,
    mode='val',
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
)

Define the model, loss, optimizer and learning-rate schedulers here.

In [8]:
# Network for the active configuration.
model = CreateModel(config)

# Multi-class Lovasz loss, computed per image.
criterion = smp.losses.LovaszLoss(mode='multiclass', per_image=True)

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Two schedulers drive the same optimizer:
#   * cosine annealing with T_max = epochs * batches-per-epoch, which implies
#     train() presumably steps it once per batch;
#   * ReduceLROnPlateau, presumably stepped from evaluate().
# NOTE(review): both rescale the same learning rate; confirm the interaction is intended.
total_steps = config.epochs * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=config.min_lr)
scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=config.patience)

Model parameters: 126_202_236


In [9]:
# Drop the notebook's names for the split and patch lists. The datasets may
# still hold their own references, so memory is only freed if they do not.
del train_list, valid_list
del train_patch_names, valid_patch_names

The following code block iterates through the training epochs, performing training and validation in each epoch. The model's final state is saved at the end of training, and the CUDA cache is cleared to manage memory.

In [None]:
# Train for config.epochs epochs, validating after every epoch.
for epoch in range(1, config.epochs + 1):
    epoch_start = time.time()
    print(f'Starting epoch: {epoch}')

    # The cosine scheduler is handed to train(), the plateau scheduler to evaluate().
    training_loss = train(model, train_loader, criterion, optimizer, config.device, scheduler)
    val_results = evaluate(model, val_loader, criterion, config.device, scheduler1)

    epoch_seconds = time.time() - epoch_start

    # evaluate() returns a mapping whose 'loss' and 'jaccard' entries are indexable.
    valid_loss, valid_jaccard = val_results['loss'][0], val_results['jaccard'][0]

    print(f'Epoch: {epoch}, Time: {epoch_seconds:.2f}')
    print(f'Train Loss: {training_loss:.4f}, Val Loss: {valid_loss:.4f}, Jaccard: {valid_jaccard:.4f}\n')

# Persist the final-epoch weights (there is no best-epoch checkpointing).
if not os.path.exists(weights_dir):
    os.makedirs(weights_dir, exist_ok=True)

model_path = os.path.join(weights_dir, f"unetplusplus_{config.exp_name}.pth")
torch.save(model.state_dict(), model_path)
print(f"Model for config '{config.exp_name}' saved to {model_path}")

# Return cached, unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

## Inference

In [None]:
# File names of every test image, sorted so predictions follow a stable order.
test_image_dir = os.path.join(test_root_dir, 'image')
test_images_list = sorted(os.listdir(test_image_dir))

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Test images are only converted to tensors (no augmentation), and batches are
# kept in file order so predictions line up with test_images_list.
transform = A.Compose([ToTensorV2()])
inference_dataset = Inference_dataset(test_root_dir, test_images_list, transform)
inference_loader = DataLoader(
    inference_dataset,
    batch_size=config.test_batch_size,
    num_workers=config.num_workers,
    shuffle=False,
)

In [None]:
# Rebuild the architecture for the active config and load the weights the
# training cell saved for it.
model = CreateModel(config)

model_path = os.path.join(weights_dir, f"unetplusplus_{config.exp_name}.pth")
# map_location='cpu' lets a checkpoint saved from a GPU load on any host (and
# avoids an extra transient GPU copy); the inference cell then moves the model
# to config.device.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

The code sets up sliding-window inference with MONAI's `SlidingWindowInferer`, using the parameters from the model configuration. Predicted masks are generated, post-processed with `clear_predictions` to remove noise, and saved as `.npy` files in the prediction output directory.

In [None]:
# MONAI sliding-window inferer; window size, windows per batch, overlap,
# blending mode and padding mode all come from the config.
inferer = SlidingWindowInferer(
    roi_size=config.roi_size,
    sw_batch_size=config.sw_batch_size,
    overlap=config.overlap,
    mode=config.mode,
    padding_mode=config.padding_mode,
)

model.eval()
model.to(config.device)

# One argmax mask per test image, with the source path kept alongside it.
predictions, filenames = [], []
with torch.no_grad():
    for batch_images, batch_paths in tqdm(inference_loader):
        scores = inferer(inputs=batch_images.to(config.device), network=model)
        # argmax over dim 1 (presumably the class channel) -> per-pixel label
        predictions.extend(torch.argmax(scores, dim=1).cpu().numpy())
        filenames.extend(batch_paths)

output_dir = os.path.join(pred_output_dir, f'{config.exp_name}')
if not os.path.exists(output_dir):
    os.makedirs(output_dir, exist_ok=True)

# Post-process all masks together, then save each as <image id>_pred.npy.
cleared_predictions = clear_predictions(predictions)
for src_path, mask in zip(filenames, cleared_predictions):
    img_id = os.path.splitext(os.path.basename(src_path))[0]
    np.save(os.path.join(output_dir, f"{img_id}_pred.npy"), mask)

print(f"Prediction for {config.exp_name} saved to {output_dir}")

##### NOTE - The commented-out code below loops over all model configurations and trains them in a single run. However, this can lead to out-of-memory issues, so it is better to train each model configuration separately.

##### NOTE - Scroll down to make ensemble submission. 

In [4]:
#def train_and_evaluate(config, train_dir=None, unlabeled_dir=None, patches_dir=None):
#    # Seed everything for reproducibility
#    seed_everything(config.seed)
#    
#    train_list, valid_list = prepare_data(config, train_dir=train_dir, unlabeled_dir=unlabeled_dir)
#
#    # Create patches for training and validation
#    train_patch_names, valid_patch_names = create_patches(config, train_dir=train_dir, output_dir=patches_dir, train_list=train_list, val_list=valid_list)
#
#    # Create datasets and dataloaders
#    train_dataset = StrangerSectionsDataset(f'patches/{config.exp_name}', train_patch_names, transforms=config.train_transform, mode='train')
#    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True)
#
#    val_dataset = StrangerSectionsDataset(f'patches/{config.exp_name}', valid_patch_names, transforms=config.test_transform, mode='val')
#    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)
#    
#    model = CreateModel(config)
#    
#    criterion = smp.losses.LovaszLoss(mode='multiclass', per_image=True)
#    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
#    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=config.min_lr, T_max=config.epochs * len(train_loader))
#    scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=config.patience)
#
#    for epoch in range(1, config.epochs + 1):
#        start_time = time.time()
#        print(f'Starting epoch: {epoch}')
#        
#        # Training and validation
#        training_loss = train(model, train_loader, criterion, optimizer, config.device, scheduler)
#        val_results = evaluate(model, val_loader, criterion, config.device, scheduler1)
#        
#        elapsed_time = time.time() - start_time
#        
#        valid_loss = val_results['loss'][0]
#        valid_jaccard = val_results['jaccard'][0]
#        
#        print(f'Epoch: {epoch}, Time: {elapsed_time:.2f}')
#        print(f'Train Loss: {training_loss:.4f}, Val Loss: {valid_loss:.4f}, Jaccard: {valid_jaccard:.4f}\n')
#        
#
#    # Save the model
#    weights_dir = "weights"
#    if not os.path.exists(weights_dir):
#        os.makedirs(weights_dir, exist_ok=True)
#    
#    model_path = os.path.join(weights_dir, f"unetplusplus_{config.exp_name}.pth")
#    torch.save(model.state_dict(), model_path)
#    
#    # Explicitly delete the model and optimizer
#    del model
#    del optimizer
#    torch.cuda.empty_cache()

In [2]:
## Paths and Configuration
#train_data_dir = 'D:/CODE/competitions/StrangerSection2/stranger-sections-2-train-data/'
#unlabeled_data_dir = 'D:/CODE/competitions/StrangerSection2/stranger-sections-2-unlabeled-data/'
#patches_data_dir = 'patches'
#
#
## training and evaluation
#def main():
#
#    for exp_name, config in model_configs_dict.items():
#        try:
#            print(f"Training model configuration: {exp_name}\n")
#            train_and_evaluate(config, train_dir=train_data_dir, unlabeled_dir=unlabeled_data_dir, patches_dir=patches_data_dir)
#        except Exception as e:
#            print(f"Error occurred while training {exp_name}: {e}")
#        finally:
#            # Clearing CUDA cache
#            torch.cuda.empty_cache()
#
#if __name__ == '__main__':
#    main()

The following code block generates test-set predictions for the trained models in the `model_checkpoints` directory.

NOTE - Code to generate predictions for test set

In [5]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Inference
def predict_all_models(config, test_root_dir=None, weights_path=None, output_dir=None):
    """Predict test-set masks with one trained model and save them as ``.npy`` files.

    Builds the model for ``config``, loads ``unetplusplus_{config.exp_name}.pth``
    from ``weights_path`` and runs MONAI sliding-window inference over every
    image listed in ``<test_root_dir>/image``. The argmax masks are
    post-processed with ``clear_predictions``, and one ``<image_id>_pred.npy``
    file per image is written to ``<output_dir>/<config.exp_name>/``.

    Args:
        config: Experiment configuration. Reads ``exp_name``, ``test_batch_size``,
            ``num_workers``, ``roi_size``, ``sw_batch_size``, ``overlap``,
            ``mode``, ``padding_mode`` and ``device``.
        test_root_dir: Test-data root containing an ``image`` sub-directory.
        weights_path: Directory holding the saved state-dict files.
        output_dir: Root directory for the per-experiment prediction folders.

    Raises:
        ValueError: If ``test_root_dir``, ``weights_path`` or ``output_dir`` is None.
    """
    # The None defaults exist only for signature compatibility; all paths are required.
    required = {'test_root_dir': test_root_dir, 'weights_path': weights_path, 'output_dir': output_dir}
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(f"predict_all_models() missing required path(s): {', '.join(missing)}")

    # Every test image, in a stable (sorted) order.
    test_images_list = sorted(os.listdir(os.path.join(test_root_dir, 'image')))

    transform = A.Compose([ToTensorV2()])
    inference_dataset = Inference_dataset(test_root_dir, test_images_list, transform)
    inference_loader = DataLoader(inference_dataset, batch_size=config.test_batch_size, num_workers=config.num_workers, shuffle=False)

    inferer = SlidingWindowInferer(roi_size=config.roi_size, sw_batch_size=config.sw_batch_size, overlap=config.overlap, mode=config.mode, padding_mode=config.padding_mode)

    model = CreateModel(config)
    model_path = os.path.join(weights_path, f"unetplusplus_{config.exp_name}.pth")
    # Load onto the CPU first so GPU-saved checkpoints also load on CPU-only
    # hosts; the model is moved to config.device right after.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    model.to(config.device)

    predictions = []
    filenames = []
    with torch.no_grad():
        for image, file_path in tqdm(inference_loader):
            image = image.to(config.device)
            outputs = inferer(inputs=image, network=model)
            # argmax over dim 1 (presumably the class channel) -> per-pixel label
            predicted_masks = torch.argmax(outputs, dim=1)
            predictions.extend(predicted_masks.cpu().numpy())
            filenames.extend(file_path)

    exp_output_dir = os.path.join(output_dir, f'{config.exp_name}')
    os.makedirs(exp_output_dir, exist_ok=True)

    cleared_predictions = clear_predictions(predictions)
    for filename, pred in zip(filenames, cleared_predictions):
        img_id = os.path.splitext(os.path.basename(filename))[0]
        output_name = os.path.join(exp_output_dir, f"{img_id}_pred.npy")
        np.save(output_name, pred)

In [None]:
weights_dir = 'weights'

# Directory of trained model weights; this is the one read below.
model_checkpoints = 'model_checkpoints'

test_root_dir = 'stranger-sections-2-test-data/stranger-sections-2-test-data/'
pred_output_dir = 'model_predictions'

# Predict with every configured model. A failure for one model is reported and
# the loop continues with the next; the CUDA cache is cleared after each model.
for exp_name, config in model_configs_dict.items():
    try:
        print(f"Predicting with model configuration: {exp_name}")
        predict_all_models(
            config,
            test_root_dir=test_root_dir,
            weights_path=model_checkpoints,
            output_dir=pred_output_dir,
        )
    except Exception as err:
        print(f"Error occurred while predicting with {exp_name}: {err}")
    finally:
        torch.cuda.empty_cache()

## Ensemble Model Preds

The code block loads all model predictions. The predictions are processed by the `ensemble_predictions` function, which presumably combines them into a final prediction for each image by averaging the predictions from all models.

In [None]:
# Root folder holding one prediction sub-folder per model, and the folder the
# ensembled masks are written to.
all_model_preds_path = 'model_predictions'
ensemble_output_folder = 'submission'

# Only sub-directories hold per-model predictions; a stray file here would make
# load_sub fail when it tries to list it. Sort so models are always loaded in
# the same order (os.listdir order is arbitrary).
preds_list = sorted(
    entry
    for entry in os.listdir(all_model_preds_path)
    if os.path.isdir(os.path.join(all_model_preds_path, entry))
)

def load_sub(pred_path):
    """Load every ``<image_id>_pred.npy`` mask stored in *pred_path*.

    Files are read in sorted name order, so calls on different model folders
    that hold the same image IDs return masks in the same order. Entries that
    do not end with ``_pred.npy`` are ignored.

    Args:
        pred_path: Directory containing one model's prediction files.

    Returns:
        A ``(masks, image_ids)`` tuple of two parallel lists. ``masks`` holds the
        loaded arrays, and ``image_ids`` holds the file names with the
        ``_pred.npy`` suffix removed.
    """
    suffix = '_pred.npy'
    sub_masks = []
    filenames = []
    # os.listdir order is arbitrary; sorting keeps masks aligned across models.
    for name in sorted(os.listdir(pred_path)):
        if not name.endswith(suffix):
            continue
        # Strip only the trailing suffix so IDs that contain '_' stay intact.
        fname = name[:-len(suffix)]
        sub_masks.append(np.load(os.path.join(pred_path, name)))
        filenames.append(fname)
    return sub_masks, filenames

# Load each model's masks. The ensembled masks are saved below by zipping them
# with pred_filenames (taken from the first folder) by position, so every
# folder must list the same images in the same order.
model_preds = []
pred_filenames = None
for pred_folder in preds_list:
    pred_path = os.path.join(all_model_preds_path, pred_folder)
    print(f"Loading predictions from: {pred_path}")
    masks, filename = load_sub(pred_path)
    if pred_filenames is None:
        pred_filenames = filename
    elif filename != pred_filenames:
        raise ValueError(
            f"Predictions in {pred_path} do not cover the same images, in the same "
            f"order, as {preds_list[0]}; refusing to ensemble misaligned masks."
        )
    model_preds.append(masks)

if not model_preds:
    raise RuntimeError(f"No model prediction folders found in {all_model_preds_path}")

# NOTE(review): ensemble_predictions is defined elsewhere (star import);
# presumably it combines the i-th mask of every model into the i-th output.
ensemble_preds = ensemble_predictions(model_preds)

os.makedirs(ensemble_output_folder, exist_ok=True)

for fname, pred in zip(pred_filenames, ensemble_preds):
    output_name = os.path.join(ensemble_output_folder, f"{fname}_pred.npy")
    np.save(output_name, pred)

print(f"\nFinal ensemble saved to {ensemble_output_folder}")