In [2]:
import os
import segmentation_models_pytorch as smp
from monai.data import DataLoader, Dataset
import torch
from torchmetrics import JaccardIndex
from tqdm import tqdm
from custom_model import CustomModel
from transformers import SegformerConfig, SegformerForSemanticSegmentation
import numpy as np
import pickle

In [3]:
if os.path.exists("kfold_splits.pkl"):
    with open('kfold_splits.pkl', 'rb') as file:
    # Load train and validation data
        folds = pickle.load(file)
else:
    from preprocess import folds
    
if os.path.exists("test_set.pkl"):
    with open('test_set.pkl', 'rb') as file:
    # Load test data
        test_set = pickle.load(file)
else:
    from preprocess import test_set

In [4]:
os.makedirs('./output', exist_ok=True)
root_dir = './output'

In [5]:
gpu = "cuda"
loss = "Jaccard"
model_type = "custom"

In [7]:
if gpu == "cuda":
    device = torch.device("cuda:0")

if loss == "Dice":
    loss_function = smp.losses.DiceLoss(mode="binary", from_logits=False)
elif loss == "Jaccard":
    loss_function = smp.losses.JaccardLoss(mode="binary", from_logits=False)

metric = JaccardIndex(task="binary").to(device)

In [8]:
if model_type == "resnet":
    # Create the UNet model with the specified backbone
    model = smp.Unet(
        encoder_name="resnet34",
        in_channels=3,
        classes=1,
        activation="sigmoid",
        encoder_weights=None
    ).to(device)
elif model_type == "unet":
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True).to(device)
elif model_type == "segformer":
    config = SegformerConfig(num_labels=1)
    model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512", num_labels=1, ignore_mismatched_sizes=True).to(device)
elif model_type == "custom":
    model = CustomModel(
        channels = [64, 128, 256, 512],
        scale = [4, 2, 2, 2],
        num_blocks = [2, 2, 2, 2],
        num_heads = [2, 4, 8, 16],
        mlp_hidden = [256, 512, 1024, 2048],
        dropout = 0.0,                     # change to dropout rate=0.2 
        decoder_hidden = 256,
        attention=True
    ).to(device)
        

In [9]:
# x = torch.randn(1, 3, 128, 128)
# model = CustomModel(
#         channels = [64, 128, 256, 512],
#         scale = [4, 2, 2, 2],
#         num_blocks = [2, 2, 2, 2],
#         num_heads = [2, 4, 8, 16],
#         mlp_hidden = [256, 512, 1024, 2048],
#         dropout = 0.0,                     # change to dropout rate=0.2 
#         decoder_hidden = 256,
#     )
# l =model.encoder(x)
# for i in range(len(l)):
#     print(l[i].shape)
# l.reverse()
# model.decoder(l).shape

torch.Size([1, 64, 32, 32])
torch.Size([1, 128, 16, 16])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 4, 4])


torch.Size([1, 256, 32, 32])

In [10]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total number of parameters: {total_params}")
    print(f"Number of trainable parameters: {trainable_params}")
count_parameters(model)

Total number of parameters: 9243231
Number of trainable parameters: 9243231


In [11]:
model.parameters

<bound method Module.parameters of CustomModel(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): ModuleDict(
        (merge): PatchEmbedding(
          (conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
          (norm): Normalize(
            (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (transformers): EncoderBlock(
          (blocks): ModuleList(
            (0-1): 2 x ModuleDict(
              (self-attention): EncoderAttentionBlock(
                (attention): MultiheadAttention(
                  (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
                )
                (norm): Normalize(
                  (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                )
              )
              (mlp): MLPBlock(
                (dense): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
                

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# use amp to accelerate training
scaler = torch.cuda.amp.GradScaler()
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

In [13]:
os.makedirs(os.path.join(root_dir, model_type), exist_ok=True)
out_dir = os.path.join(root_dir, model_type)

In [14]:
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []


num_CV = 5
max_epochs = 100
epoch_per_CV = max_epochs // num_CV

for i in range(num_CV):

    print("-" * 10)
    print(f"Fold {i}")
    train_ds = folds[i]["train"]
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
    val_ds = folds[i]["val"]
    val_loader = DataLoader(val_ds, batch_size=8, num_workers=4)

    if i < num_CV - 1:
        start_epoch = i * epoch_per_CV
        end_epoch = (i + 1) * epoch_per_CV
    else:
        start_epoch = i * epoch_per_CV
        end_epoch = max_epochs

    for epoch in range(start_epoch, end_epoch):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")

        # Training
        model.train()
        epoch_loss = 0
        step = 0

        for batch_data in tqdm(train_loader, desc="Training"):
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )

            optimizer.zero_grad()

            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                if model_type == "segformer":
                    outputs = torch.nn.functional.sigmoid(torch.nn.functional.interpolate(outputs.logits,128,mode="bilinear",align_corners=False))
                loss = loss_function(outputs, labels)
    
            if torch.isnan(loss) or torch.isnan(torch.flatten(outputs)).any().item():
                print(outputs)
                print(torch.isnan(torch.flatten(outputs)).any())
                print(labels)
                print(loss)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()
            
        # update learning rate
        lr_scheduler.step()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # Validation
        if (epoch + 1) % val_interval == 0:
            model.eval()
            val_iou = 0.0
            num_batches = len(val_loader)

            with torch.no_grad():
                for val_data in tqdm(val_loader, desc="Validation"):
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )

                    val_outputs = model(val_inputs)
                    if model_type == "segformer":
                        val_outputs = torch.nn.functional.sigmoid(torch.nn.functional.interpolate(val_outputs.logits,128,mode="bilinear",align_corners=False))

                    # Calculate IoU
                    iou = metric(val_outputs, val_labels)
                    val_iou += iou.item()

            # Calculate average IoU across all batches
            average_val_iou = val_iou / num_batches
            metric_values.append(average_val_iou)

            if average_val_iou > best_metric:
                best_metric = average_val_iou
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(out_dir, "best_metric_model.pth"))
                print("saved new best metric model")

            print(
                f"current epoch: {epoch + 1} current mean IoU: {average_val_iou:.4f}"
                f"\nbest mean IoU: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"  
            )


----------
Fold 0
----------
epoch 1/100


Training:   0%|          | 0/97 [00:00<?, ?it/s]

Training:   0%|          | 0/97 [00:26<?, ?it/s]


RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same

In [103]:
epoch_loss_array = np.array(epoch_loss_values)
metric_array = np.array(metric_values)
np.savetxt(os.path.join(out_dir, 'epoch_loss_array.csv'), epoch_loss_array, delimiter=',')
np.savetxt(os.path.join(out_dir, 'metric_array.csv'), metric_array, delimiter=',')

In [25]:
model.named_parameters 

<bound method Module.named_parameters of CustomModel(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): ModuleDict(
        (merge): PatchEmbedding(
          (conv): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
          (norm): Normalize(
            (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (transformers): EncoderBlock(
          (blocks): ModuleList(
            (0-1): 2 x ModuleDict(
              (self-attention): AttentionBlock(
                (attention): MultiheadAttention(
                  (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
                )
                (norm): Normalize(
                  (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                )
              )
              (mlp): MLPBlock(
                (dense): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
                (

In [None]:
# inference on test set
model = torch.load(os.path.join(out_dir, "best_metric_model.pth"))  