In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under the (read-only) Kaggle input directory.
input_paths = (
    os.path.join(root_dir, file_name)
    for root_dir, _subdirs, file_names in os.walk('/kaggle/input')
    for file_name in file_names
)
for input_path in input_paths:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
!pip install -q monai

In [2]:
import os
import sys
import time
import glob
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

import monai
from monai.data import Dataset, CacheDataset, list_data_collate, decollate_batch
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    RandCropByPosNegLabeld,
    RandRotate90d,
    RandShiftIntensityd,
    ToTensord,
    CropForegroundd,
    ResizeWithPadOrCropd,
    Spacingd,
    RandFlipd,
    RandScaleIntensityd,
    RandAffined,
    Rand3DElasticd,
    RandGaussianNoised,
    RandAdjustContrastd,
    Activations,
    AsDiscrete,
    SqueezeDimd,
    ToNumpyd
)
from monai.networks.nets import UNet, SegResNet
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.visualize import plot_2d_or_3d_image
from monai.utils import set_determinism
from monai.networks.utils import one_hot
from monai.config import print_config


# Deterministic training for reproducibility (MONAI's set_determinism seeds the
# Python, NumPy and PyTorch RNGs).
set_determinism(seed=42)

# Paths: the input dataset is mounted read-only; outputs go to the working directory,
# which Kaggle preserves when the notebook version is saved.
data_dir = "/kaggle/input/mri-segmentation/MRI_dataset"  # Update with your data path
train_output_dir = "/kaggle/working/"
os.makedirs(train_output_dir, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Build the train/validation file lists consumed by the MONAI datasets
def get_data_dicts(data_dir, train_fraction=0.8):
    """Pair image and label files and split them into training and validation lists.

    Expected layout (uncompressed NIfTI, paired by sorted filename order)::

        data_dir/
          imagesTr/   *.nii   -- input MRI volumes
          labelsTr/   *.nii   -- label volumes, one per image

    Images and labels are paired purely by position after sorting, so both folders
    must hold the same number of files whose names sort in the same order. The first
    ``int(num_total * train_fraction)`` pairs become the training set and the rest the
    validation set (deterministic, no shuffling).

    Args:
        data_dir: Root directory containing ``imagesTr`` and ``labelsTr``.
        train_fraction: Fraction of pairs assigned to training (default 0.8).

    Returns:
        tuple[list[dict], list[dict]]: ``(train_files, val_files)``; every entry is
        ``{"image": <path>, "label": <path>}``.

    Raises:
        FileNotFoundError: If no ``*.nii`` files are found in ``imagesTr``.
        ValueError: If the number of images and labels differ; pairing by position
            would otherwise silently mis-pair or drop files.
    """
    images_dir = os.path.join(data_dir, "imagesTr")
    labels_dir = os.path.join(data_dir, "labelsTr")
    all_images = sorted(glob.glob(os.path.join(images_dir, "*.nii")))
    all_labels = sorted(glob.glob(os.path.join(labels_dir, "*.nii")))

    # Fail fast with a clear message instead of producing empty/mis-paired datasets
    # that only break later inside the DataLoader or the loss.
    if not all_images:
        raise FileNotFoundError(f"No '*.nii' images found in {images_dir}")
    if len(all_images) != len(all_labels):
        raise ValueError(
            f"Found {len(all_images)} images but {len(all_labels)} labels; "
            "every image in imagesTr needs exactly one label in labelsTr."
        )

    pairs = [{"image": img_path, "label": label_path}
             for img_path, label_path in zip(all_images, all_labels)]
    num_train = int(len(pairs) * train_fraction)
    train_files, val_files = pairs[:num_train], pairs[num_train:]

    print(f"Training samples: {len(train_files)}, Validation samples: {len(val_files)}")
    return train_files, val_files

In [4]:
# Define transforms
def get_transforms():
    """Build the training and validation MONAI pipelines.

    Both pipelines load the image/label volumes, make them channel-first, resample to
    1.0 x 1.0 x 1.0 spacing (bilinear for the image, nearest for the label) and crop to
    the image foreground.

    Training then rescales image intensity and draws 4 patches of 96^3 per volume with
    a 1:1 positive/negative label-centre ratio. It augments with random flips on each
    spatial axis, intensity scale/shift, small affine, elastic, contrast and Gaussian
    noise transforms.

    Validation pads/crops to a fixed (197, 233, 189) grid and rescales intensity, with
    no augmentation.

    Returns:
        tuple: ``(train_transforms, val_transforms)``, both ``Compose`` objects.
    """
    paired_keys = ["image", "label"]
    interp_modes = ("bilinear", "nearest")  # image, label

    def shared_preprocessing():
        # Fresh transform instances for each pipeline so the two Compose objects
        # never share state.
        return [
            LoadImaged(keys=paired_keys),
            EnsureChannelFirstd(keys=paired_keys),  # Ensure channel first for both
            Spacingd(keys=paired_keys, pixdim=(1.0, 1.0, 1.0), mode=interp_modes),
            CropForegroundd(keys=paired_keys, source_key="image"),
        ]

    training_tail = [
        ScaleIntensityd(keys=["image"]),  # Scale intensity for the RAW image
        RandCropByPosNegLabeld(
            keys=paired_keys,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
        ),
        *(RandFlipd(keys=paired_keys, prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)),
        RandScaleIntensityd(keys=["image"], factors=0.1, prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandAffined(
            keys=paired_keys,
            prob=0.5,
            rotate_range=(0.05, 0.05, 0.05),
            scale_range=(0.1, 0.1, 0.1),
            mode=interp_modes,
        ),
        Rand3DElasticd(
            keys=paired_keys,
            prob=0.3,
            sigma_range=(5, 8),
            magnitude_range=(50, 100),
            mode=interp_modes,
        ),
        RandAdjustContrastd(keys=["image"], prob=0.3, gamma=(0.7, 1.5)),
        RandGaussianNoised(keys=["image"], prob=0.3, mean=0, std=0.05),
        ToTensord(keys=paired_keys),
    ]
    train_transforms = Compose(shared_preprocessing() + training_tail)

    # Fixed validation grid so volumes can be batched together.
    reg_size = (197, 233, 189)
    validation_tail = [
        ResizeWithPadOrCropd(keys=paired_keys, spatial_size=reg_size),
        ScaleIntensityd(keys=["image"]),
        ToTensord(keys=paired_keys),
    ]
    val_transforms = Compose(shared_preprocessing() + validation_tail)

    return train_transforms, val_transforms

In [5]:
# Create datasets and dataloaders
def get_dataloaders(train_files, val_files, train_transforms, val_transforms, cache=False):
    """Wrap the file lists in MONAI datasets and PyTorch DataLoaders.

    Args:
        train_files: List of ``{"image", "label"}`` dicts for training.
        val_files: List of ``{"image", "label"}`` dicts for validation.
        train_transforms: Pipeline applied to training samples.
        val_transforms: Pipeline applied to validation samples.
        cache: If truthy, use ``CacheDataset`` with ``cache_rate=1.0``; otherwise a
            plain ``Dataset``.

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles with batch
        size 1, and the validation loader keeps file order with batch size 4. Both use
        4 workers and ``list_data_collate``, and pin memory when CUDA is available.
    """
    def build_dataset(files, transform):
        if not cache:
            return Dataset(data=files, transform=transform)
        return CacheDataset(data=files, transform=transform, cache_rate=1.0)

    train_ds = build_dataset(train_files, train_transforms)
    val_ds = build_dataset(val_files, val_transforms)

    # Settings shared by both loaders.
    common_loader_kwargs = dict(
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        collate_fn=list_data_collate,
    )
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, **common_loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, **common_loader_kwargs)

    return train_loader, val_loader

In [6]:
# Model definition
def get_model(in_channels=1, out_channels=4):
    """Create a SegResNet for single-modality MRI tissue segmentation.

    Args:
        in_channels: Number of input image channels (1 for a single RAW MRI volume).
        out_channels: Number of output channels. The default of 4 is one channel each
            for background, CSF, GM and WM, so channel 0 is background and the tissue
            classes follow in that order.

    Returns:
        An untrained ``SegResNet``; the caller moves it to a device.
    """
    model = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=in_channels,
        out_channels=out_channels,
        dropout_prob=0.2,
    )

    return model

In [7]:
# Train function
def _save_validation_figure(val_inputs, val_labels, val_outputs, epoch, class_names=("CSF", "GM", "WM")):
    """Save input / per-class ground-truth / prediction panels for the first batch item.

    Uses the middle slice along the last spatial axis. Labels are read as a single-channel
    class-index map, which is what ``DiceLoss(to_onehot_y=True)`` in ``train`` expects.

    NOTE(review): assumes label value ``i + 1`` and output channel ``i + 1`` both
    correspond to ``class_names[i]``, with 0 = background, matching get_model's
    background/CSF/GM/WM layout. Confirm against the dataset's label encoding.
    """
    middle_slice = val_inputs.shape[4] // 2
    image = val_inputs[0, 0, :, :, middle_slice].cpu().numpy()
    label_map = val_labels[0, 0, :, :, middle_slice].cpu().numpy()
    # The loss uses a per-channel sigmoid, so threshold probabilities rather than logits.
    probs = torch.sigmoid(val_outputs[0, :, :, :, middle_slice]).cpu().numpy()
    pred_map = np.zeros(label_map.shape)
    for class_index in range(1, len(class_names) + 1):
        pred_map[probs[class_index] > 0.5] = class_index

    fig, axes = plt.subplots(1, len(class_names) + 2, figsize=(18, 6))
    axes[0].set_title("Input MRI")
    axes[0].imshow(image, cmap="gray")
    for class_index, name in enumerate(class_names, start=1):
        axes[class_index].set_title(f"{name} (Ground Truth)")
        axes[class_index].imshow((label_map == class_index).astype(np.float32), cmap="jet")
    axes[-1].set_title("Prediction (RGB)")
    axes[-1].imshow(pred_map, cmap="viridis")
    fig.tight_layout()
    fig.savefig(os.path.join(train_output_dir, f"visualization_epoch_{epoch}.png"))
    plt.close(fig)


def _validate_epoch(model, val_loader, loss_function, dice_metric, figure_epoch=None):
    """Run one sliding-window validation pass and return ``(avg_val_loss, mean_dice)``.

    ``dice_metric`` is aggregated and reset before returning. When ``figure_epoch`` is
    not None, a diagnostic figure for the first validation batch is saved under that
    epoch number.
    """
    roi_size = (96, 96, 96)
    sw_batch_size = 4
    # Same activation as the loss (sigmoid), then binarise for DiceMetric.
    post_sigmoid = Activations(sigmoid=True)
    post_threshold = AsDiscrete(threshold=0.5)

    model.eval()
    val_loss = 0.0
    val_steps = 0
    with torch.no_grad():
        for val_data in tqdm(val_loader, desc="Validation"):
            val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)

            # Sliding window inference for larger volumes
            val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
            val_loss += loss_function(val_outputs, val_labels).item()
            val_steps += 1

            val_preds = [post_threshold(post_sigmoid(x)) for x in decollate_batch(val_outputs)]
            dice_metric(y_pred=val_preds, y=decollate_batch(val_labels))

            if figure_epoch is not None and val_steps == 1:
                _save_validation_figure(val_inputs, val_labels, val_outputs, figure_epoch)

    avg_val_loss = val_loss / max(val_steps, 1)
    metric = dice_metric.aggregate().item()
    dice_metric.reset()
    return avg_val_loss, metric


def _plot_training_curves(epoch_loss_values, metric_values, lr_values, val_interval):
    """Save the loss / Dice curves and the per-epoch learning rate to ``train_output_dir``."""
    epochs = [i + 1 for i in range(len(epoch_loss_values))]
    val_epochs = [val_interval * (i + 1) for i in range(len(metric_values))]

    plt.figure("Train", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Dice Loss")
    plt.xlabel("Epoch")
    plt.plot(epochs, epoch_loss_values, label="Training Loss")
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.title("Mean Dice Score")
    plt.xlabel("Epoch")
    plt.plot(val_epochs, metric_values, label="Validation Dice")
    plt.legend()
    plt.savefig(os.path.join(train_output_dir, "training_curves.png"))

    plt.figure("Learning Rate", (6, 6))
    plt.title("Learning Rate Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.plot(epochs, lr_values)
    plt.savefig(os.path.join(train_output_dir, "learning_rate.png"))


def train(model, train_loader, val_loader, max_epochs=300, val_interval=5):
    """Train ``model`` with Dice loss, Adam and cosine LR annealing.

    Every ``val_interval`` epochs the model is validated with sliding-window inference.
    Whenever the mean foreground Dice improves, the weights are saved to
    ``best_model.pth`` in ``train_output_dir``. On the last epoch (only if it is also a
    validation epoch) a diagnostic slice figure is saved. Loss/Dice/LR curves are
    written at the end.

    Args:
        model: Network already on ``device``.
        train_loader: Loader yielding dicts with ``"image"`` and ``"label"`` batches.
        val_loader: Validation loader with the same dict layout.
        max_epochs: Number of training epochs (also the cosine schedule length).
        val_interval: Validate every this many epochs.

    Returns:
        tuple: ``(model, best_metric, best_metric_epoch)``; the metric/epoch stay at -1
        if validation never ran.
    """
    # to_onehot_y=True expands single-channel class-index labels to the model's output
    # channels before the per-channel (sigmoid) Dice is computed.
    loss_function = DiceLoss(to_onehot_y=True, sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

    # Dice metric for validation (background channel excluded)
    dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    lr_values = []

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        # Record the LR actually used in this epoch, before the scheduler advances it.
        lr_values.append(optimizer.param_groups[0]["lr"])

        train_bar = tqdm(train_loader)
        for batch_data in train_bar:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            train_bar.set_description(f"Train loss: {loss.item():.4f}")

        scheduler.step()

        # Guard against an empty loader rather than dividing by zero.
        epoch_loss /= max(step, 1)
        epoch_loss_values.append(epoch_loss)
        print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            figure_epoch = epoch + 1 if epoch + 1 == max_epochs else None
            avg_val_loss, metric = _validate_epoch(
                model, val_loader, loss_function, dice_metric, figure_epoch=figure_epoch
            )
            metric_values.append(metric)
            print(f"Validation loss: {avg_val_loss:.4f}, Dice score: {metric:.4f}")

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(train_output_dir, "best_model.pth"))
                print(f"Saved new best model (Dice: {best_metric:.4f})")

    print(f"\nTraining completed. Best Dice: {best_metric:.4f} at epoch {best_metric_epoch}")

    _plot_training_curves(epoch_loss_values, metric_values, lr_values, val_interval)

    return model, best_metric, best_metric_epoch

In [8]:
# Function to visualize class-wise Dice scores
def visualize_class_dice(model, val_loader, class_names=("CSF", "GM", "WM")):
    """Compute and plot a Dice score per brain structure over the validation set.

    Predictions use the same sigmoid + 0.5 threshold as validation in ``train``. For
    ``class_names[i]``, output channel ``i + 1`` is compared against the binary mask
    ``label == i + 1``. Channel/label 0 is background.

    NOTE(review): assumes labels are single-channel class-index maps (as required by
    ``DiceLoss(to_onehot_y=True)`` in ``train``), with values 1..len(class_names)
    in ``class_names`` order. Confirm against the dataset.

    Args:
        model: Trained segmentation network; it is switched to eval mode.
        val_loader: Loader yielding dicts with ``"image"`` and ``"label"`` batches.
        class_names: Foreground structure names in label-index order.

    Returns:
        list[float]: Mean Dice per class in ``class_names`` order. A bar chart is also
        saved as ``class_dice_scores.png`` in ``train_output_dir``.
    """
    class_names = list(class_names)
    model.eval()

    # One binary metric per structure. Each receives a single-channel mask, so there is
    # no background channel to exclude.
    dice_metrics = {name: DiceMetric(include_background=True, reduction="mean")
                    for name in class_names}
    post_sigmoid = Activations(sigmoid=True)
    post_threshold = AsDiscrete(threshold=0.5)
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    with torch.no_grad():
        for val_data in tqdm(val_loader, desc="Evaluating class-wise performance"):
            val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
            val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)

            # decollate_batch drops the batch axis, leaving channel-first (C, H, W, D)
            # items, so the class channel is axis 0. Slicing axis 1 would cut a spatial axis.
            pred_items = [post_threshold(post_sigmoid(x)) for x in decollate_batch(val_outputs)]
            label_items = decollate_batch(val_labels)

            for class_index, name in enumerate(class_names, start=1):
                dice_metrics[name](
                    y_pred=[p[class_index:class_index + 1] for p in pred_items],
                    y=[(y[0:1] == class_index).float() for y in label_items],
                )

    class_dice_scores = [dice_metrics[name].aggregate().item() for name in class_names]

    plt.figure("Class-wise Dice Scores", (10, 6))
    plt.bar(class_names, class_dice_scores)
    plt.title("Dice Score by Brain Structure")
    plt.xlabel("Structure")
    plt.ylabel("Dice Score")
    for i, v in enumerate(class_dice_scores):
        plt.text(i, v + 0.01, f"{v:.3f}", ha='center')
    plt.ylim(0, 1.0)
    plt.tight_layout()
    plt.savefig(os.path.join(train_output_dir, "class_dice_scores.png"))

    return class_dice_scores

In [None]:
# Main function to run the training pipeline
def _save_final_visualizations(best_model, val_loader, class_names=("CSF", "GM", "WM")):
    """Save per-class ground-truth/prediction panels and a colour overlay for one sample.

    Uses the first item of the first validation batch and the middle slice along the
    last spatial axis. Writes ``final_segmentation_results.png`` and
    ``segmentation_overlay.png`` to ``train_output_dir``.

    NOTE(review): assumes labels are single-channel class-index maps where value
    ``i + 1`` is ``class_names[i]`` (0 = background), and that model output channel
    ``i + 1`` scores the same class, per get_model's background/CSF/GM/WM layout.
    Confirm against the dataset.
    """
    best_model.eval()
    with torch.no_grad():
        val_data = next(iter(val_loader), None)
        if val_data is None:
            print("Validation loader is empty; skipping final visualization.")
            return
        val_inputs = val_data["image"].to(device)
        val_labels = val_data["label"].to(device)

        roi_size = (96, 96, 96)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, best_model)

    middle_slice = val_inputs.shape[4] // 2
    image = val_inputs[0, 0, :, :, middle_slice].cpu().numpy()
    label_map = val_labels[0, 0, :, :, middle_slice].cpu().numpy()
    # The network was trained with a sigmoid Dice loss: threshold probabilities, not logits.
    probs = torch.sigmoid(val_outputs[0, :, :, :, middle_slice]).cpu().numpy()
    pred_masks = [probs[class_index] > 0.5 for class_index in range(1, len(class_names) + 1)]

    # Grid: top row = input + ground truth per class, bottom row = combined + prediction per class.
    ncols = len(class_names) + 1
    fig, axes = plt.subplots(2, ncols, figsize=(5 * ncols, 10))
    axes[0, 0].set_title("Original MRI")
    axes[0, 0].imshow(image, cmap="gray")

    combined_pred = np.zeros(label_map.shape)
    for class_index, (name, pred_mask) in enumerate(zip(class_names, pred_masks), start=1):
        axes[0, class_index].set_title(f"{name} (Ground Truth)")
        axes[0, class_index].imshow((label_map == class_index).astype(np.float32), cmap="jet")
        axes[1, class_index].set_title(f"{name} (Predicted)")
        axes[1, class_index].imshow(pred_mask.astype(np.float32), cmap="jet")
        combined_pred[pred_mask] = class_index

    axes[1, 0].set_title("Combined Prediction")
    axes[1, 0].imshow(combined_pred, cmap="viridis")
    fig.tight_layout()
    fig.savefig(os.path.join(train_output_dir, "final_segmentation_results.png"))
    plt.close(fig)

    # Overlay: translucent RGBA colour per predicted class on top of the input slice.
    mask = np.zeros(label_map.shape + (4,))
    colors = [(1, 0, 0, 0.5), (0, 1, 0, 0.5), (0, 0, 1, 0.5)]  # RGBA colors for CSF, GM, WM
    for pred_mask, color in zip(pred_masks, colors):
        mask[pred_mask] = color

    plt.figure("Segmentation Overlay", figsize=(10, 10))
    plt.imshow(image, cmap="gray")
    plt.imshow(mask)
    plt.title("Segmentation Overlay (Red: CSF, Green: GM, Blue: WM)")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(train_output_dir, "segmentation_overlay.png"))


def main():
    """Run data preparation, training, class-wise evaluation and final visualization.

    Saves ``final_model.pth`` after training. The best checkpoint (``best_model.pth``,
    written by ``train``) is reloaded for the final figures. If it does not exist
    (e.g. validation never ran), the final weights are used instead.
    """
    class_names = ["CSF", "GM", "WM"]

    train_files, val_files = get_data_dicts(data_dir)
    train_transforms, val_transforms = get_transforms()
    train_loader, val_loader = get_dataloaders(
        train_files, val_files, train_transforms, val_transforms, cache=True
    )

    # 1 input channel (RAW); 4 output channels (background, CSF, GM, WM)
    model = get_model().to(device)
    model, best_metric, best_metric_epoch = train(
        model, train_loader, val_loader, max_epochs=200, val_interval=10
    )

    final_model_path = os.path.join(train_output_dir, "final_model.pth")
    torch.save(model.state_dict(), final_model_path)

    print("Evaluating class-wise performance...")
    class_dice_scores = visualize_class_dice(model, val_loader, class_names=class_names)
    for name, score in zip(class_names, class_dice_scores):
        print(f"  {name} Dice: {score:.4f}")

    print("Generating training progress plots...")

    best_model_path = os.path.join(train_output_dir, "best_model.pth")
    if not os.path.exists(best_model_path):
        print("No best_model.pth found; visualizing the final model instead.")
        best_model_path = final_model_path

    best_model = get_model().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only session.
    best_model.load_state_dict(torch.load(best_model_path, map_location=device))
    _save_final_visualizations(best_model, val_loader, class_names=class_names)

    print(f"Training and evaluation completed. All results saved to {train_output_dir}")

if __name__ == "__main__":
    main()



Training samples: 108, Validation samples: 27


Loading dataset: 100%|██████████| 108/108 [00:19<00:00,  5.67it/s]
Loading dataset: 100%|██████████| 27/27 [00:06<00:00,  4.11it/s]



Epoch 1/200


Train loss: 0.6132: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 1 average loss: 0.6121

Epoch 2/200


Train loss: 0.5590: 100%|██████████| 108/108 [02:03<00:00,  1.15s/it]


Epoch 2 average loss: 0.5775

Epoch 3/200


Train loss: 0.5439: 100%|██████████| 108/108 [02:03<00:00,  1.15s/it]


Epoch 3 average loss: 0.5591

Epoch 4/200


Train loss: 0.5228: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 4 average loss: 0.5471

Epoch 5/200


Train loss: 0.5328: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 5 average loss: 0.5341

Epoch 6/200


Train loss: 0.4977: 100%|██████████| 108/108 [02:04<00:00,  1.16s/it]


Epoch 6 average loss: 0.5229

Epoch 7/200


Train loss: 0.4898: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 7 average loss: 0.5113

Epoch 8/200


Train loss: 0.4978: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 8 average loss: 0.4984

Epoch 9/200


Train loss: 0.6744: 100%|██████████| 108/108 [02:04<00:00,  1.16s/it]


Epoch 9 average loss: 0.4849

Epoch 10/200


Train loss: 0.5256: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 10 average loss: 0.4709


Validation: 100%|██████████| 7/7 [00:31<00:00,  4.52s/it]


Validation loss: 0.4903, Dice score: 0.5608
Saved new best model (Dice: 0.5608)

Epoch 11/200


Train loss: 0.4371: 100%|██████████| 108/108 [02:03<00:00,  1.14s/it]


Epoch 11 average loss: 0.4611

Epoch 12/200


Train loss: 0.4958: 100%|██████████| 108/108 [02:03<00:00,  1.14s/it]


Epoch 12 average loss: 0.4432

Epoch 13/200


Train loss: 0.3832: 100%|██████████| 108/108 [02:04<00:00,  1.15s/it]


Epoch 13 average loss: 0.4288

Epoch 14/200


Train loss: 0.3714: 100%|██████████| 108/108 [02:02<00:00,  1.14s/it]


Epoch 14 average loss: 0.4115

Epoch 15/200


Train loss: 0.3416: 100%|██████████| 108/108 [02:06<00:00,  1.17s/it]


Epoch 15 average loss: 0.4039

Epoch 16/200


Train loss: 0.3933: 100%|██████████| 108/108 [02:08<00:00,  1.19s/it]


Epoch 16 average loss: 0.3836

Epoch 17/200


Train loss: 0.3691: 100%|██████████| 108/108 [02:17<00:00,  1.28s/it]


Epoch 17 average loss: 0.3697

Epoch 18/200


Train loss: 0.3458: 100%|██████████| 108/108 [02:15<00:00,  1.25s/it]


Epoch 18 average loss: 0.3540

Epoch 19/200


Train loss: 0.3788: 100%|██████████| 108/108 [02:11<00:00,  1.22s/it]


Epoch 19 average loss: 0.3370

Epoch 20/200


Train loss: 0.3219: 100%|██████████| 108/108 [02:09<00:00,  1.20s/it]


Epoch 20 average loss: 0.3282


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.25s/it]


Validation loss: 0.3309, Dice score: 0.8003
Saved new best model (Dice: 0.8003)

Epoch 21/200


Train loss: 0.2608: 100%|██████████| 108/108 [02:08<00:00,  1.19s/it]


Epoch 21 average loss: 0.3116

Epoch 22/200


Train loss: 0.2476: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 22 average loss: 0.3017

Epoch 23/200


Train loss: 0.2305: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 23 average loss: 0.2933

Epoch 24/200


Train loss: 0.2767: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 24 average loss: 0.2819

Epoch 25/200


Train loss: 0.2972: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 25 average loss: 0.2732

Epoch 26/200


Train loss: 0.2069: 100%|██████████| 108/108 [02:24<00:00,  1.34s/it]


Epoch 26 average loss: 0.2617

Epoch 27/200


Train loss: 0.2288: 100%|██████████| 108/108 [02:24<00:00,  1.34s/it]


Epoch 27 average loss: 0.2485

Epoch 28/200


Train loss: 0.2332: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 28 average loss: 0.2495

Epoch 29/200


Train loss: 0.2127: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 29 average loss: 0.2409

Epoch 30/200


Train loss: 0.1771: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 30 average loss: 0.2328


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.33s/it]


Validation loss: 0.2394, Dice score: 0.8109
Saved new best model (Dice: 0.8109)

Epoch 31/200


Train loss: 0.2621: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 31 average loss: 0.2260

Epoch 32/200


Train loss: 0.1931: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 32 average loss: 0.2232

Epoch 33/200


Train loss: 0.1780: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 33 average loss: 0.2144

Epoch 34/200


Train loss: 0.1942: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 34 average loss: 0.2067

Epoch 35/200


Train loss: 0.1626: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 35 average loss: 0.2056

Epoch 36/200


Train loss: 0.1383: 100%|██████████| 108/108 [02:13<00:00,  1.23s/it]


Epoch 36 average loss: 0.2006

Epoch 37/200


Train loss: 0.1328: 100%|██████████| 108/108 [02:12<00:00,  1.22s/it]


Epoch 37 average loss: 0.1949

Epoch 38/200


Train loss: 0.1932: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 38 average loss: 0.1959

Epoch 39/200


Train loss: 0.1338: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 39 average loss: 0.1881

Epoch 40/200


Train loss: 0.1347: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 40 average loss: 0.1897


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.28s/it]


Validation loss: 0.1842, Dice score: 0.8284
Saved new best model (Dice: 0.8284)

Epoch 41/200


Train loss: 0.1659: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 41 average loss: 0.1885

Epoch 42/200


Train loss: 0.1339: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 42 average loss: 0.1885

Epoch 43/200


Train loss: 0.1595: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 43 average loss: 0.1767

Epoch 44/200


Train loss: 0.1554: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 44 average loss: 0.1767

Epoch 45/200


Train loss: 0.1618: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 45 average loss: 0.1751

Epoch 46/200


Train loss: 0.1525: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 46 average loss: 0.1685

Epoch 47/200


Train loss: 0.1303: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 47 average loss: 0.1707

Epoch 48/200


Train loss: 0.1173: 100%|██████████| 108/108 [02:24<00:00,  1.34s/it]


Epoch 48 average loss: 0.1698

Epoch 49/200


Train loss: 0.1284: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 49 average loss: 0.1668

Epoch 50/200


Train loss: 0.1829: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 50 average loss: 0.1614


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.32s/it]


Validation loss: 0.1576, Dice score: 0.8383
Saved new best model (Dice: 0.8383)

Epoch 51/200


Train loss: 0.2258: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 51 average loss: 0.1591

Epoch 52/200


Train loss: 0.2067: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 52 average loss: 0.1662

Epoch 53/200


Train loss: 0.1490: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 53 average loss: 0.1630

Epoch 54/200


Train loss: 0.1219: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 54 average loss: 0.1626

Epoch 55/200


Train loss: 0.1249: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 55 average loss: 0.1569

Epoch 56/200


Train loss: 0.0895: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 56 average loss: 0.1608

Epoch 57/200


Train loss: 0.1233: 100%|██████████| 108/108 [02:17<00:00,  1.28s/it]


Epoch 57 average loss: 0.1622

Epoch 58/200


Train loss: 0.1384: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 58 average loss: 0.1538

Epoch 59/200


Train loss: 0.1656: 100%|██████████| 108/108 [02:14<00:00,  1.25s/it]


Epoch 59 average loss: 0.1519

Epoch 60/200


Train loss: 0.0935: 100%|██████████| 108/108 [02:14<00:00,  1.24s/it]


Epoch 60 average loss: 0.1472


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.28s/it]


Validation loss: 0.1361, Dice score: 0.8531
Saved new best model (Dice: 0.8531)

Epoch 61/200


Train loss: 0.1577: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 61 average loss: 0.1507

Epoch 62/200


Train loss: 0.2764: 100%|██████████| 108/108 [02:11<00:00,  1.22s/it]


Epoch 62 average loss: 0.1501

Epoch 63/200


Train loss: 0.0901: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 63 average loss: 0.1468

Epoch 64/200


Train loss: 0.0961: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 64 average loss: 0.1503

Epoch 65/200


Train loss: 0.1212: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 65 average loss: 0.1481

Epoch 66/200


Train loss: 0.1519: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 66 average loss: 0.1479

Epoch 67/200


Train loss: 0.1573: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 67 average loss: 0.1483

Epoch 68/200


Train loss: 0.1341: 100%|██████████| 108/108 [02:15<00:00,  1.25s/it]


Epoch 68 average loss: 0.1440

Epoch 69/200


Train loss: 0.0915: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 69 average loss: 0.1436

Epoch 70/200


Train loss: 0.1577: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 70 average loss: 0.1418


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.28s/it]


Validation loss: 0.1525, Dice score: 0.8222

Epoch 71/200


Train loss: 0.1230: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 71 average loss: 0.1486

Epoch 72/200


Train loss: 0.1408: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 72 average loss: 0.1441

Epoch 73/200


Train loss: 0.1058: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 73 average loss: 0.1457

Epoch 74/200


Train loss: 0.1691: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 74 average loss: 0.1390

Epoch 75/200


Train loss: 0.1406: 100%|██████████| 108/108 [02:14<00:00,  1.25s/it]


Epoch 75 average loss: 0.1372

Epoch 76/200


Train loss: 0.1066: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 76 average loss: 0.1403

Epoch 77/200


Train loss: 0.1479: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 77 average loss: 0.1368

Epoch 78/200


Train loss: 0.0853: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 78 average loss: 0.1426

Epoch 79/200


Train loss: 0.0941: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 79 average loss: 0.1365

Epoch 80/200


Train loss: 0.1052: 100%|██████████| 108/108 [02:10<00:00,  1.21s/it]


Epoch 80 average loss: 0.1339


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.32s/it]


Validation loss: 0.1308, Dice score: 0.8466

Epoch 81/200


Train loss: 0.0893: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 81 average loss: 0.1398

Epoch 82/200


Train loss: 0.1231: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 82 average loss: 0.1367

Epoch 83/200


Train loss: 0.1033: 100%|██████████| 108/108 [02:09<00:00,  1.20s/it]


Epoch 83 average loss: 0.1389

Epoch 84/200


Train loss: 0.2927: 100%|██████████| 108/108 [02:10<00:00,  1.20s/it]


Epoch 84 average loss: 0.1374

Epoch 85/200


Train loss: 0.0842: 100%|██████████| 108/108 [02:08<00:00,  1.19s/it]


Epoch 85 average loss: 0.1391

Epoch 86/200


Train loss: 0.1578: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 86 average loss: 0.1325

Epoch 87/200


Train loss: 0.1273: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 87 average loss: 0.1324

Epoch 88/200


Train loss: 0.2317: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 88 average loss: 0.1334

Epoch 89/200


Train loss: 0.1975: 100%|██████████| 108/108 [02:23<00:00,  1.32s/it]


Epoch 89 average loss: 0.1346

Epoch 90/200


Train loss: 0.0943: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 90 average loss: 0.1320


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.28s/it]


Validation loss: 0.1235, Dice score: 0.8537
Saved new best model (Dice: 0.8537)

Epoch 91/200


Train loss: 0.2710: 100%|██████████| 108/108 [02:09<00:00,  1.20s/it]


Epoch 91 average loss: 0.1336

Epoch 92/200


Train loss: 0.1139: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 92 average loss: 0.1304

Epoch 93/200


Train loss: 0.1168: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 93 average loss: 0.1294

Epoch 94/200


Train loss: 0.1300: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 94 average loss: 0.1347

Epoch 95/200


Train loss: 0.1104: 100%|██████████| 108/108 [02:17<00:00,  1.28s/it]


Epoch 95 average loss: 0.1298

Epoch 96/200


Train loss: 0.0807: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 96 average loss: 0.1344

Epoch 97/200


Train loss: 0.0946: 100%|██████████| 108/108 [02:17<00:00,  1.28s/it]


Epoch 97 average loss: 0.1303

Epoch 98/200


Train loss: 0.0853: 100%|██████████| 108/108 [02:15<00:00,  1.26s/it]


Epoch 98 average loss: 0.1294

Epoch 99/200


Train loss: 0.1185: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 99 average loss: 0.1294

Epoch 100/200


Train loss: 0.0725: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 100 average loss: 0.1286


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.26s/it]


Validation loss: 0.1302, Dice score: 0.8428

Epoch 101/200


Train loss: 0.0856: 100%|██████████| 108/108 [02:13<00:00,  1.23s/it]


Epoch 101 average loss: 0.1281

Epoch 102/200


Train loss: 0.1796: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 102 average loss: 0.1287

Epoch 103/200


Train loss: 0.1264: 100%|██████████| 108/108 [02:15<00:00,  1.26s/it]


Epoch 103 average loss: 0.1272

Epoch 104/200


Train loss: 0.0762: 100%|██████████| 108/108 [02:14<00:00,  1.25s/it]


Epoch 104 average loss: 0.1237

Epoch 105/200


Train loss: 0.1739: 100%|██████████| 108/108 [02:15<00:00,  1.25s/it]


Epoch 105 average loss: 0.1285

Epoch 106/200


Train loss: 0.0938: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 106 average loss: 0.1278

Epoch 107/200


Train loss: 0.0718: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 107 average loss: 0.1270

Epoch 108/200


Train loss: 0.0815: 100%|██████████| 108/108 [02:23<00:00,  1.32s/it]


Epoch 108 average loss: 0.1276

Epoch 109/200


Train loss: 0.1829: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 109 average loss: 0.1273

Epoch 110/200


Train loss: 0.0838: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 110 average loss: 0.1233


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.28s/it]


Validation loss: 0.1250, Dice score: 0.8488

Epoch 111/200


Train loss: 0.1389: 100%|██████████| 108/108 [02:25<00:00,  1.35s/it]


Epoch 111 average loss: 0.1257

Epoch 112/200


Train loss: 0.1204: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 112 average loss: 0.1256

Epoch 113/200


Train loss: 0.1171: 100%|██████████| 108/108 [02:24<00:00,  1.34s/it]


Epoch 113 average loss: 0.1225

Epoch 114/200


Train loss: 0.1163: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 114 average loss: 0.1246

Epoch 115/200


Train loss: 0.0846: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 115 average loss: 0.1226

Epoch 116/200


Train loss: 0.0753: 100%|██████████| 108/108 [02:14<00:00,  1.25s/it]


Epoch 116 average loss: 0.1237

Epoch 117/200


Train loss: 0.0859: 100%|██████████| 108/108 [02:09<00:00,  1.20s/it]


Epoch 117 average loss: 0.1247

Epoch 118/200


Train loss: 0.1502: 100%|██████████| 108/108 [02:09<00:00,  1.20s/it]


Epoch 118 average loss: 0.1258

Epoch 119/200


Train loss: 0.1080: 100%|██████████| 108/108 [02:07<00:00,  1.18s/it]


Epoch 119 average loss: 0.1243

Epoch 120/200


Train loss: 0.0683: 100%|██████████| 108/108 [02:06<00:00,  1.17s/it]


Epoch 120 average loss: 0.1210


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.29s/it]


Validation loss: 0.1172, Dice score: 0.8573
Saved new best model (Dice: 0.8573)

Epoch 121/200


Train loss: 0.0920: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 121 average loss: 0.1221

Epoch 122/200


Train loss: 0.2441: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 122 average loss: 0.1237

Epoch 123/200


Train loss: 0.0824: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 123 average loss: 0.1230

Epoch 124/200


Train loss: 0.1268: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 124 average loss: 0.1229

Epoch 125/200


Train loss: 0.0930: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 125 average loss: 0.1228

Epoch 126/200


Train loss: 0.1084: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 126 average loss: 0.1209

Epoch 127/200


Train loss: 0.1393: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 127 average loss: 0.1256

Epoch 128/200


Train loss: 0.1092: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 128 average loss: 0.1206

Epoch 129/200


Train loss: 0.1435: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 129 average loss: 0.1193

Epoch 130/200


Train loss: 0.0924: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 130 average loss: 0.1209


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.33s/it]


Validation loss: 0.1166, Dice score: 0.8582
Saved new best model (Dice: 0.8582)

Epoch 131/200


Train loss: 0.1320: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 131 average loss: 0.1210

Epoch 132/200


Train loss: 0.0789: 100%|██████████| 108/108 [02:23<00:00,  1.33s/it]


Epoch 132 average loss: 0.1227

Epoch 133/200


Train loss: 0.1436: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 133 average loss: 0.1194

Epoch 134/200


Train loss: 0.0617: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 134 average loss: 0.1197

Epoch 135/200


Train loss: 0.0697: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 135 average loss: 0.1169

Epoch 136/200


Train loss: 0.1401: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 136 average loss: 0.1229

Epoch 137/200


Train loss: 0.0840: 100%|██████████| 108/108 [02:13<00:00,  1.23s/it]


Epoch 137 average loss: 0.1179

Epoch 138/200


Train loss: 0.0802: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 138 average loss: 0.1173

Epoch 139/200


Train loss: 0.0966: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 139 average loss: 0.1181

Epoch 140/200


Train loss: 0.0772: 100%|██████████| 108/108 [02:14<00:00,  1.24s/it]


Epoch 140 average loss: 0.1172


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.30s/it]


Validation loss: 0.1246, Dice score: 0.8469

Epoch 141/200


Train loss: 0.0823: 100%|██████████| 108/108 [02:15<00:00,  1.25s/it]


Epoch 141 average loss: 0.1179

Epoch 142/200


Train loss: 0.1514: 100%|██████████| 108/108 [02:15<00:00,  1.25s/it]


Epoch 142 average loss: 0.1174

Epoch 143/200


Train loss: 0.0791: 100%|██████████| 108/108 [02:12<00:00,  1.23s/it]


Epoch 143 average loss: 0.1207

Epoch 144/200


Train loss: 0.2438: 100%|██████████| 108/108 [02:15<00:00,  1.26s/it]


Epoch 144 average loss: 0.1170

Epoch 145/200


Train loss: 0.2255: 100%|██████████| 108/108 [02:21<00:00,  1.31s/it]


Epoch 145 average loss: 0.1179

Epoch 146/200


Train loss: 0.0822: 100%|██████████| 108/108 [02:20<00:00,  1.30s/it]


Epoch 146 average loss: 0.1157

Epoch 147/200


Train loss: 0.1058: 100%|██████████| 108/108 [02:22<00:00,  1.32s/it]


Epoch 147 average loss: 0.1153

Epoch 148/200


Train loss: 0.0955: 100%|██████████| 108/108 [02:17<00:00,  1.27s/it]


Epoch 148 average loss: 0.1187

Epoch 149/200


Train loss: 0.0830: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 149 average loss: 0.1178

Epoch 150/200


Train loss: 0.0849: 100%|██████████| 108/108 [02:19<00:00,  1.30s/it]


Epoch 150 average loss: 0.1143


Validation: 100%|██████████| 7/7 [00:30<00:00,  4.32s/it]


Validation loss: 0.1153, Dice score: 0.8589
Saved new best model (Dice: 0.8589)

Epoch 151/200


Train loss: 0.1142: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 151 average loss: 0.1200

Epoch 152/200


Train loss: 0.1050: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 152 average loss: 0.1150

Epoch 153/200


Train loss: 0.1289: 100%|██████████| 108/108 [02:15<00:00,  1.26s/it]


Epoch 153 average loss: 0.1141

Epoch 154/200


Train loss: 0.1859: 100%|██████████| 108/108 [02:17<00:00,  1.28s/it]


Epoch 154 average loss: 0.1172

Epoch 155/200


Train loss: 0.1221: 100%|██████████| 108/108 [02:18<00:00,  1.28s/it]


Epoch 155 average loss: 0.1164

Epoch 156/200


Train loss: 0.1112: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 156 average loss: 0.1152

Epoch 157/200


Train loss: 0.1762: 100%|██████████| 108/108 [02:18<00:00,  1.29s/it]


Epoch 157 average loss: 0.1140

Epoch 158/200


Train loss: 0.0660: 100%|██████████| 108/108 [02:19<00:00,  1.29s/it]


Epoch 158 average loss: 0.1144

Epoch 159/200


Train loss: 0.0774: 100%|██████████| 108/108 [02:16<00:00,  1.26s/it]


Epoch 159 average loss: 0.1130

Epoch 160/200


Train loss: 0.1093: 100%|██████████| 108/108 [02:16<00:00,  1.27s/it]


Epoch 160 average loss: 0.1146


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.23s/it]


Validation loss: 0.1161, Dice score: 0.8576

Epoch 161/200


Train loss: 0.0697: 100%|██████████| 108/108 [02:15<00:00,  1.26s/it]


Epoch 161 average loss: 0.1137

Epoch 162/200


Train loss: 0.1037: 100%|██████████| 108/108 [02:14<00:00,  1.25s/it]


Epoch 162 average loss: 0.1130

Epoch 163/200


Train loss: 0.0818: 100%|██████████| 108/108 [02:14<00:00,  1.24s/it]


Epoch 163 average loss: 0.1165

Epoch 164/200


Train loss: 0.1184: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 164 average loss: 0.1105

Epoch 165/200


Train loss: 0.1812: 100%|██████████| 108/108 [02:10<00:00,  1.20s/it]


Epoch 165 average loss: 0.1155

Epoch 166/200


Train loss: 0.0765: 100%|██████████| 108/108 [02:08<00:00,  1.19s/it]


Epoch 166 average loss: 0.1164

Epoch 167/200


Train loss: 0.0744: 100%|██████████| 108/108 [02:11<00:00,  1.22s/it]


Epoch 167 average loss: 0.1152

Epoch 168/200


Train loss: 0.1946: 100%|██████████| 108/108 [02:10<00:00,  1.21s/it]


Epoch 168 average loss: 0.1122

Epoch 169/200


Train loss: 0.0649: 100%|██████████| 108/108 [02:07<00:00,  1.18s/it]


Epoch 169 average loss: 0.1154

Epoch 170/200


Train loss: 0.0727: 100%|██████████| 108/108 [02:05<00:00,  1.16s/it]


Epoch 170 average loss: 0.1114


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.24s/it]


Validation loss: 0.1169, Dice score: 0.8567

Epoch 171/200


Train loss: 0.1047: 100%|██████████| 108/108 [02:13<00:00,  1.24s/it]


Epoch 171 average loss: 0.1119

Epoch 172/200


Train loss: 0.1213:  74%|███████▍  | 80/108 [01:35<00:28,  1.03s/it]

In [None]:
# Rebuild the network architecture and restore trained weights for inference.
# NOTE(review): the training log above stops part-way through epoch 172/200 and only
# reports "Saved new best model" checkpoints; confirm that final_model.pth was actually
# written by this run, or point CHECKPOINT_PATH at the best-model checkpoint instead.
CHECKPOINT_PATH = "/kaggle/working/final_model.pth"

model = get_model()  # Initialize model (architecture only, untrained weights)
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids executing arbitrary pickled code from the file.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)  # Load weights
model = model.to(device)  # Move to CUDA if available
model.eval()  # Inference mode: layers such as dropout/batch-norm (if present) stop training behaviour

In [None]:
# Run the (elsewhere-defined) per-class Dice visualisation on the restored model and the
# validation loader.
# NOTE(review): `val_loader` is rebuilt by a *later* cell in this section; on a fresh
# kernel this call relies on it having been defined earlier in the notebook — confirm
# the cell order survives Restart & Run All.
visualize_class_dice(model, val_loader)

In [None]:
# A DataLoader is an iterable of batches, not a tensor container, so it has no `.to()`
# method (`val_loader.to(device)` raised AttributeError). Tensors must be moved to the
# device batch by batch while iterating, e.g.:
#     for batch in val_loader:
#         batch = move_batch_to_device(batch, device)
def move_batch_to_device(batch, device):
    """Recursively move every tensor in a (possibly nested) batch to ``device``.

    Args:
        batch: A tensor, or a dict / list / tuple that may contain tensors at any depth.
            Non-tensor leaves (strings, ints, metadata) are returned unchanged.
        device: Target device, e.g. ``"cuda"``, ``"cpu"`` or a ``torch.device``.

    Returns:
        A structure of the same container types with all tensors on ``device``.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_batch_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, list):
        return [move_batch_to_device(item, device) for item in batch]
    if isinstance(batch, tuple):
        return tuple(move_batch_to_device(item, device) for item in batch)
    return batch

In [None]:
# Rebuild the train/validation file lists from `data_dir`.
train_files, val_files = get_data_dicts(data_dir)

# Build the train/validation transform pipelines.
train_transforms, val_transforms = get_transforms()

# Wrap files + transforms in dataloaders (cache=True is forwarded to get_dataloaders).
train_loader, val_loader = get_dataloaders(
    train_files,
    val_files,
    train_transforms,
    val_transforms,
    cache=True,
)


In [None]:
# Sanity-check the tensor shapes yielded by each validation batch.
for sample_idx, val_batch in enumerate(val_loader):
    image_shape = val_batch['image'].shape
    label_shape = val_batch['label'].shape
    print(f"Sample {sample_idx}, image shape: {image_shape}, label shape: {label_shape}")

In [None]:
# Print image/label shapes for the first crop of a few training batches.
# Iterating the whole train_loader re-runs every random augmentation (the training
# progress bars above show 108 iterations per epoch), so only a few batches are inspected.
from itertools import islice

N_INSPECT_BATCHES = 3

for batch_idx, train_batch in enumerate(islice(train_loader, N_INSPECT_BATCHES)):
    # NOTE(review): the original indexed batch[0], i.e. assumed each batch is a list of
    # per-crop dicts; a dict batch (e.g. as produced by list_data_collate) is used as-is.
    first_crop = train_batch[0] if isinstance(train_batch, (list, tuple)) else train_batch
    print(f"Sample {batch_idx}, image shape: {first_crop['image'].shape}, label shape: {first_crop['label'].shape}")

In [None]:
# A DataLoader is an iterable of batches and has no `.shape` attribute (the original
# `train_loader.shape` raised AttributeError). Report its size instead; per-batch tensor
# shapes are printed by the shape-inspection cells.
{
    "batches_per_epoch": len(train_loader),
    "num_samples": len(train_loader.dataset),
    "batch_size": train_loader.batch_size,
}

In [None]:
# Peek at the image tensor of the second crop in ONE training batch.
# The original looped over every batch of train_loader, re-running all augmentations and
# printing one tensor per batch; a single batch is enough for a sanity check.
first_train_batch = next(iter(train_loader))
# NOTE(review): indexing [1] assumes each batch is a list of at least two per-crop dicts
# (as the original code did); a dict batch is used as-is.
if isinstance(first_train_batch, (list, tuple)):
    crop = first_train_batch[1]
else:
    crop = first_train_batch
crop_image = crop['image']
print(
    f"shape={tuple(crop_image.shape)}, dtype={crop_image.dtype}, "
    f"min={crop_image.min().item():.4f}, max={crop_image.max().item():.4f}"
)
print(crop_image)

In [None]:
# Display the model's module repr (layer-by-layer summary) as the cell output.
model