In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import tifffile
import matplotlib.pyplot as plt

import random

In [2]:
# One seed shared by every random number generator used in this notebook.
DEFAULT_RANDOM_SEED = 2021
seed = DEFAULT_RANDOM_SEED

# Seed, in order: Python's `random`, NumPy's global generator, PyTorch's
# default generator, and PyTorch's CUDA generator.
# os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

In [3]:
class SWEDDataset(Dataset):
    """Paired image/label dataset read from ``<root_dir>/<mode>/{images,labels}``.

    For ``mode`` in ``('train', 'val')`` the files are ``.npy`` arrays and a
    label file is paired with the image whose name equals the label name with
    ``_chip_`` replaced by ``_image_``.  For any other mode the files are
    ``.tif`` and the label naming uses ``_label_`` instead.  Label files
    without a matching image are ignored.

    Each item is ``(image, label)`` where ``image`` is a float32 tensor divided
    by ``2**15`` and ``label`` is a boolean tensor (``label > 0``).

    Args:
        root_dir: Directory containing one sub-directory per mode.
        mode: Name of the split sub-directory; also selects the file format.
        transform: Optional callable applied to the image AND to the label.
        target_transform: Optional callable applied to the label afterwards.
    """

    def __init__(self, root_dir, mode='train', transform=None, target_transform=None):
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.target_transform = target_transform

        self.data_dir = os.path.join(root_dir, mode)
        self.image_dir = os.path.join(self.data_dir, 'images')
        self.label_dir = os.path.join(self.data_dir, 'labels')

        # File format and label naming both depend on the split.
        is_npy_split = mode in ['train', 'val']
        extension = '.npy' if is_npy_split else '.tif'
        label_suffix = '_chip_' if is_npy_split else '_label_'

        image_files = sorted(f for f in os.listdir(self.image_dir) if f.endswith(extension))
        label_files = sorted(f for f in os.listdir(self.label_dir) if f.endswith(extension))

        # Map "expected label file name" -> image file name, then keep the
        # labels for which an image exists.
        image_dict = {f.replace('_image_', label_suffix): f for f in image_files}
        self.pairs = [
            (image_dict[label_file], label_file)
            for label_file in label_files
            if label_file in image_dict
        ]

    def __len__(self):
        """Return the number of matched (image, label) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load, convert and normalise the ``idx``-th (image, label) pair."""
        img_file, label_file = self.pairs[idx]
        img_path = os.path.join(self.image_dir, img_file)
        label_path = os.path.join(self.label_dir, label_file)

        if self.mode in ['train', 'val']:
            image = np.load(img_path)
            label = np.load(label_path)
        else:
            image = tifffile.imread(img_path)
            label = tifffile.imread(label_path)

        # Convert to native-endian float32 on the NumPy side first:
        # torch.from_numpy only accepts a fixed set of dtypes / byte orders, so
        # converting afterwards with `.float()` can fail before it is reached.
        image = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32))
        label = torch.from_numpy(np.ascontiguousarray(label, dtype=np.float32))

        if self.mode in ['train', 'val']:
            # Move the last axis to the front (presumably HWC -> CHW; confirm
            # against the stored array layout).
            image = image.permute(2, 0, 1)
        elif self.mode == 'test':
            # Swap the last two axes of the image ...
            image = image.permute(0, 2, 1)
            # ... and give the label a leading axis, then rot90 + flip, whose
            # net effect is also a swap of the last two axes.
            label = label.unsqueeze(0)
            label = torch.rot90(label, 1, [1, 2])
            label = torch.flip(label, [1])

        image = image / 2.0**15     # jp2 images are 8 to 16 bit
        label = label > 0.0         # binary label

        if self.transform:
            # NOTE(review): the transform is called separately on image and
            # label; a random transform would not be synchronised between them.
            image = self.transform(image)
            label = self.transform(label)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [4]:
def get_dataloaders(root_dir, batch_size=32, num_workers=4, train_transform=None, test_transform=None, device='cuda'):
    """Build the train, validation and test ``DataLoader`` objects.

    The ``train`` split found under ``root_dir`` is divided 80 % / 20 % with
    ``random_split`` into training and validation subsets (both therefore use
    ``train_transform``); the ``test`` split is loaded as is.  Only the
    training loader shuffles.

    Args:
        root_dir: Dataset root passed to ``SWEDDataset``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker process count used by all three loaders.
        train_transform: Transform for the train/validation data.
        test_transform: Transform for the test data.
        device: Target device; batches are placed in pinned memory only when
            this is a CUDA device and CUDA is available.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_dataset = SWEDDataset(root_dir, mode='train', transform=train_transform)
    test_dataset = SWEDDataset(root_dir, mode='test', transform=test_transform)

    # 80/20 split; the remainder goes to validation so the sizes always add up.
    n_train = int(0.8 * len(train_dataset))
    n_val = len(train_dataset) - n_train
    train_dataset, val_dataset = random_split(train_dataset, [n_train, n_val])

    print(f'Train size: {len(train_dataset)}')
    print(f'Validation size: {len(val_dataset)}')
    print(f'Test size: {len(test_dataset)}')

    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available() and torch.device(device).type == 'cuda'

    def _make_loader(dataset, shuffle):
        # All loaders share every setting except shuffling.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

In [5]:
def display_samples(dataloader, num_samples=5):
    """Plot images (top row) and their masks (bottom row) from one batch.

    Args:
        dataloader: Iterable yielding ``(images, masks)`` batches; only the
            first batch is used.
        num_samples: Maximum number of samples to show; capped at the batch
            size.  Nothing is drawn when this ends up being zero or less.
    """
    # Get a batch
    images, masks = next(iter(dataloader))

    # Move to CPU for visualization
    images = images.cpu()
    masks = masks.cpu()

    # Only display up to the requested number of samples
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single column, so
    # the axes[row, col] indexing below always works.
    fig, axes = plt.subplots(2, num_samples, figsize=(4*num_samples, 8), squeeze=False)

    for idx in range(num_samples):
        # Display RGB channels (assuming bands 3,2,1 are RGB)
        rgb_img = images[idx][[3,2,1]].permute(1,2,0)
        # Min-max normalize for visualization; a constant image has zero range
        # and would otherwise turn into NaNs.
        value_range = rgb_img.max() - rgb_img.min()
        if value_range > 0:
            rgb_img = (rgb_img - rgb_img.min()) / value_range
        else:
            rgb_img = torch.zeros_like(rgb_img)

        axes[0, idx].imshow(rgb_img)
        axes[0, idx].axis('off')
        axes[0, idx].set_title(f'Image {idx+1}')

        # First entry along the mask's leading axis (assumes a channel-first
        # mask — TODO confirm against the dataset's label layout).
        axes[1, idx].imshow(masks[idx][0], cmap='gray')
        axes[1, idx].axis('off')
        axes[1, idx].set_title(f'Mask {idx+1}')

    plt.tight_layout()
    plt.show()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print(f"Using device: {device}")

# Dataset location and the three loaders used by the rest of the notebook.
root_dir = "/kaggle/input/sentinel-2-water-edges-dataset/SWED"
train_loader, val_loader, test_loader = get_dataloaders(root_dir, batch_size=32, num_workers=4)

# Display 5 samples from training set
display_samples(train_loader, num_samples=5)
# display_samples(test_loader, num_samples=5)


# The block below is a bare string literal, not a comment: it records
# channel-wise mean/std values for reference and is not used by any code here.
'''
# in case we need standardization

channel-wise mean:  tensor([ 532.5187,  636.4246,  892.5240, 1049.9366, 1307.1577, 1738.9155,
        1915.7476, 1995.0083, 2055.7939, 2086.2705, 2001.6875, 1491.3577])
channel-wise std:  tensor([ 679.3956,  750.0253,  923.6580, 1273.5732, 1366.0400, 1500.5621,
        1623.3806, 1687.1169, 1720.2144, 1827.5625, 1932.8875, 1631.7715])
'''

In [None]:
!pip install -q --upgrade torchmetrics
!pip install --upgrade timm

In [8]:
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_score, recall_score,
    cohen_kappa_score, f1_score, jaccard_score, matthews_corrcoef
)
from torchmetrics.classification import (
    BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryJaccardIndex, 
    BinaryCohenKappa, BinaryMatthewsCorrCoef, Accuracy
)
import pandas as pd

class Trainer:
    """Training / evaluation harness for a binary segmentation model.

    The model's outputs are thresholded at 0.5 when metrics are computed.
    After every epoch a checkpoint is written to ``checkpoint.pt`` and,
    whenever the validation loss improves by more than ``min_delta``, to
    ``best_model.pt`` (both in the current working directory).

    Args:
        model: Network to train; called as ``model(images)``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        train_loader, val_loader, test_loader: Loaders yielding
            ``(images, labels)`` batches.
        device: Device the batches are moved to and checkpoints are mapped to.
        scheduler: Optional learning-rate scheduler stepped once per epoch.
        early_stopping_patience: Number of consecutive epochs without
            improvement after which training stops.
        min_delta: Minimum decrease of the validation loss that counts as an
            improvement.
    """

    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader, device, 
                 scheduler=None, early_stopping_patience=10, min_delta=0.001):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.scheduler = scheduler
        self.predictions = None

        # Loss history; defined here so that checkpoints can be saved or
        # loaded before train() has been called.
        self.train_losses = []
        self.val_losses = []

        # Early stopping parameters
        self.early_stopping_patience = early_stopping_patience
        self.min_delta = min_delta
        self.best_val_loss = float('inf')
        self.early_stopping_counter = 0
        self.early_stopped = False

    def save_checkpoint(self, epoch, train_loss, val_loss, best_model=False):
        """Write model/optimizer/scheduler state and the loss history to disk.

        The file is ``best_model.pt`` when ``best_model`` is true, otherwise
        ``checkpoint.pt``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses
        }
        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        save_path = 'best_model.pt' if best_model else 'checkpoint.pt'
        torch.save(checkpoint, save_path)

    def load_checkpoint(self, checkpoint_path="best_model.pt", restore_history=True):
        """Restore state from ``checkpoint_path``.

        Args:
            checkpoint_path: File written by ``save_checkpoint``.
            restore_history: When true (default) the stored loss history
                replaces ``train_losses`` / ``val_losses``; pass ``False`` to
                reload weights only and keep the current history.

        Returns:
            The epoch stored in the checkpoint, or 0 if the file is missing.
        """
        if not os.path.exists(checkpoint_path):
            return 0  # Start from scratch if no checkpoint exists

        # map_location allows a checkpoint saved on one device (e.g. GPU) to
        # be restored on the device this trainer runs on.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if restore_history:
            self.train_losses = checkpoint.get('train_losses', [])
            self.val_losses = checkpoint.get('val_losses', [])

        return checkpoint['epoch']

    def train_epoch(self):
        """Run one optimisation pass over the training loader.

        Returns:
            The sample-weighted mean training loss.
        """
        self.model.train()
        running_loss = 0.0
        for images, labels in tqdm(self.train_loader, desc="train"):
            images, labels = images.to(self.device), labels.to(self.device).float()

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs.squeeze(), labels.squeeze())
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item() * images.size(0)

        return running_loss / len(self.train_loader.dataset)

    @staticmethod
    def _squeeze_keep_batch(tensor):
        """Drop singleton dimensions except the leading (batch) dimension."""
        kept = [size for size in tensor.shape[1:] if size != 1]
        return tensor.reshape(tensor.shape[0], *kept)

    def _predict(self, loader, desc, with_loss=True):
        """Run the model in eval mode over ``loader`` and collect results.

        Returns:
            ``(y_true, y_pred, mean_loss)`` where the arrays stack every batch
            along axis 0, ``y_pred`` holds outputs thresholded at 0.5 and
            ``mean_loss`` is the sample-weighted mean loss (0.0 when
            ``with_loss`` is false).

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        self.model.eval()
        running_loss = 0.0
        true_batches = []
        pred_batches = []

        with torch.no_grad():
            for images, labels in tqdm(loader, desc=desc):
                images, labels = images.to(self.device), labels.to(self.device).float()
                outputs = self.model(images)

                # A plain .squeeze() would also drop the batch axis of a final
                # batch of size one and break the concatenation below.
                true_batches.append(self._squeeze_keep_batch(labels).cpu().numpy())
                pred_batches.append(
                    (self._squeeze_keep_batch(outputs).cpu().numpy() > 0.5).astype(int)
                )

                if with_loss:
                    loss = self.criterion(outputs.squeeze(), labels.squeeze())
                    running_loss += loss.item() * images.size(0)

        if not true_batches:
            raise ValueError(f"{desc}: the data loader yielded no batches")

        # Concatenate once at the end instead of re-copying after each batch.
        y_true = np.concatenate(true_batches)
        y_pred = np.concatenate(pred_batches)
        return y_true, y_pred, running_loss / len(loader.dataset)

    def val_epoch(self):
        """Evaluate on the validation loader.

        Returns:
            ``(val_loss, metrics)`` where ``metrics`` is the DataFrame from
            ``evaluate_torchmetrics`` with an extra ``loss`` row.
        """
        y_true, y_pred, val_loss = self._predict(self.val_loader, "validation")

        metrics = self.evaluate_torchmetrics(y_pred, y_true)
        metrics.loc[len(metrics)] = ['loss', val_loss]

        return val_loss, metrics

    def plot_losses(self, train_losses, val_losses):
        """Plot training and validation loss against the epoch number."""
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
        plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

    def train(self, num_epochs, do_plot=True, plot_interval=2, resume=False):
        """Train with per-epoch validation, checkpointing and early stopping.

        Args:
            num_epochs: Total number of epochs (including already completed
                ones when resuming).
            do_plot: Whether to plot the loss curves during training.
            plot_interval: Plot every ``plot_interval`` epochs and at the end.
            resume: Continue from ``checkpoint.pt`` if it exists.

        Returns:
            ``True`` if training was stopped early, else ``False``.  The
            weights of ``best_model.pt`` are loaded before returning.
        """
        # Initialize or restore from checkpoint
        self.train_losses = []
        self.val_losses = []
        start_epoch = self.load_checkpoint('checkpoint.pt') if resume else 0
        # When resuming, start from the best validation loss already recorded
        # so a worse resumed epoch cannot overwrite best_model.pt.
        self.best_val_loss = min(self.val_losses, default=float('inf'))
        self.early_stopping_counter = 0
        self.early_stopped = False

        if start_epoch > 0:
            print("Training resumed from epoch ", start_epoch)

        for epoch in range(start_epoch, num_epochs):
            train_loss = self.train_epoch()
            val_loss, val_metrics = self.val_epoch()

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            if self.scheduler:
                # Only ReduceLROnPlateau takes the monitored metric.
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"Training Loss: {train_loss:.4f}")
            print(f"Validation Metrics: {val_metrics}")

            self.save_checkpoint(epoch + 1, train_loss, val_loss)

            # Early stopping logic
            if val_loss < self.best_val_loss - self.min_delta:
                self.best_val_loss = val_loss
                self.early_stopping_counter = 0
                # Save the best model
                self.save_checkpoint(epoch + 1, train_loss, val_loss, best_model=True)
                print(f"New best model saved at epoch {epoch + 1}")
            else:
                self.early_stopping_counter += 1
                print(f"No improvement. Early stopping counter: {self.early_stopping_counter}")

                if self.early_stopping_counter >= self.early_stopping_patience:
                    self.early_stopped = True
                    print("Early stopping triggered.")

            if do_plot and (epoch % plot_interval == 0 or epoch == num_epochs - 1):
                self.plot_losses(self.train_losses, self.val_losses)

            if self.early_stopped:
                break

        # Load the best weights at the end of training, keeping the complete
        # loss history of this run.
        if os.path.exists('best_model.pt'):
            self.load_checkpoint('best_model.pt', restore_history=False)

        return self.early_stopped

    def test(self):
        """Evaluate ``best_model.pt`` (if present) on the test loader.

        Returns:
            The metrics DataFrame with an extra ``loss`` row.
        """
        self.load_checkpoint('best_model.pt', restore_history=False)
        y_true, y_pred, test_loss = self._predict(self.test_loader, "Testing")

        metrics = self.evaluate_torchmetrics(y_pred, y_true)
        metrics.loc[len(metrics)] = ['loss', test_loss]

        return metrics

    def evaluate_torchmetrics(self, y_pred, y_true):
        """Compute binary classification metrics with torchmetrics.

        Args:
            y_pred: Array of 0/1 predictions.
            y_true: Array of 0/1 ground-truth values, same shape as ``y_pred``.

        Returns:
            DataFrame with columns ``Metric`` and ``Value``, one row per
            metric.
        """
        metrics = {
            "accuracy": BinaryAccuracy(),
            "bal_accuracy": Accuracy(num_classes=2, task="multiclass", average="macro"),
            "precision": BinaryPrecision(),
            "recall": BinaryRecall(),
            "f1_score": BinaryF1Score(),
            "jaccard_index": BinaryJaccardIndex(),
            "cohen_kappa": BinaryCohenKappa(),
            "mcc": BinaryMatthewsCorrCoef()
        }

        y_pred = torch.tensor(y_pred).float()
        y_true = torch.tensor(y_true).float()

        # result dataframe
        results = pd.DataFrame(columns=["Metric", "Value"])
        for metric_name, metric in metrics.items():
            metric_value = metric(y_pred, y_true)
            results.loc[len(results)] = [metric_name, metric_value.item()]

        return results

    def test_visualize(self, n_samples=5):
        """Show ``n_samples`` randomly chosen test masks next to predictions.

        Uses the model's current weights; the masks are drawn as 256x256
        images.
        """
        y_true, y_pred, _ = self._predict(self.test_loader, "Testing", with_loss=False)

        for _ in range(n_samples):
            # randrange excludes the upper bound; randint(0, len) could return
            # len itself and index past the end.
            idx = random.randrange(len(y_true))
            fig, axes = plt.subplots(1, 2, figsize=(10, 5))
            axes[0].imshow(y_true[idx].reshape(256, 256), cmap='gray')
            axes[0].set_title('True')
            axes[1].imshow(y_pred[idx].reshape(256, 256), cmap='gray')
            axes[1].set_title('Predicted')
            plt.show()


In [9]:
# # import torch
# # import torch.nn as nn
# from torchvision import models

# class DeepLabV3(nn.Module):
#     def __init__(self, num_classes=1):
#         super(DeepLabV3, self).__init__()
#         # Load pretrained DeepLabV3 model with ResNet50 backbone
#         self.model = models.segmentation.deeplabv3_resnet50(pretrained=False)
        
#         # Modify the first layer to accept 12 channels instead of 3
#         self.model.backbone.conv1 = nn.Conv2d(
#             in_channels=12,  # Change input channels to 12
#             out_channels=self.model.backbone.conv1.out_channels,
#             kernel_size=self.model.backbone.conv1.kernel_size,
#             stride=self.model.backbone.conv1.stride,
#             padding=self.model.backbone.conv1.padding,
#             bias=self.model.backbone.conv1.bias is not None
#         )
        
#         # Modify the classifier to output the desired number of classes
#         self.model.classifier[4] = nn.Conv2d(
#             in_channels=256,
#             out_channels=num_classes,
#             kernel_size=(1, 1),
#             stride=(1, 1)
#         )
    
#     def forward(self, x):
#         return self.model(x)['out']

# model = DeepLabV3(num_classes=1).to(device)

In [10]:
# import torch


# def Conv2dSame(in_channels, out_channels, kernel_size, use_bias=True, padding_layer=torch.nn.ReflectionPad2d):
#     ka = kernel_size // 2
#     kb = ka - 1 if kernel_size % 2 == 0 else ka
#     return [
#         padding_layer((ka, kb, ka, kb)),
#         torch.nn.Conv2d(in_channels, out_channels, kernel_size, bias=use_bias)
#     ]


# def conv2d_bn(in_channels, filters, kernel_size, padding='same', activation='relu'):
#     assert padding == 'same'
#     affine = False if activation == 'relu' or activation == 'sigmoid' else True
#     sequence = []
#     sequence += Conv2dSame(in_channels, filters, kernel_size, use_bias=False)
#     sequence += [torch.nn.BatchNorm2d(filters, affine=affine)]
#     if activation == "relu":
#         sequence += [torch.nn.ReLU()]
#     elif activation == "sigmoid":
#         sequence += [torch.nn.Sigmoid()]
#     elif activation == 'tanh':
#         sequence += [torch.nn.Tanh()]
#     return torch.nn.Sequential(*sequence)


# class MultiResBlock(torch.nn.Module):
#     def __init__(self, in_channels, u, alpha=1.67, use_dropout=False):
#         super().__init__()
#         w = alpha * u
#         self.out_channel = int(w * 0.167) + int(w * 0.333) + int(w * 0.5)
#         self.conv2d_bn = conv2d_bn(in_channels, self.out_channel, 1, activation=None)
#         self.conv3x3 = conv2d_bn(in_channels, int(w * 0.167), 3, activation='relu')
#         self.conv5x5 = conv2d_bn(int(w * 0.167), int(w * 0.333), 3, activation='relu')
#         self.conv7x7 = conv2d_bn(int(w * 0.333), int(w * 0.5), 3, activation='relu')
#         self.bn_1 = torch.nn.BatchNorm2d(self.out_channel)
#         self.relu = torch.nn.ReLU()
#         self.bn_2 = torch.nn.BatchNorm2d(self.out_channel)
#         self.use_dropout = use_dropout
#         if use_dropout:
#             self.dropout = torch.nn.Dropout(0.5)

#     def forward(self, inp):
#         if self.use_dropout:
#             x = self.dropout(inp)
#         else:
#             x = inp
#         shortcut = self.conv2d_bn(x)
#         conv3x3 = self.conv3x3(x)
#         conv5x5 = self.conv5x5(conv3x3)
#         conv7x7 = self.conv7x7(conv5x5)
#         out = torch.cat([conv3x3, conv5x5, conv7x7], dim=1)
#         out = self.bn_1(out)
#         out = torch.add(shortcut, out)
#         out = self.relu(out)
#         out = self.bn_2(out)
#         return out


# class ResPathBlock(torch.nn.Module):
#     def __init__(self, in_channels, filters):
#         super(ResPathBlock, self).__init__()
#         self.conv2d_bn1 = conv2d_bn(in_channels, filters, 1, activation=None)
#         self.conv2d_bn2 = conv2d_bn(in_channels, filters, 3, activation='relu')
#         self.relu = torch.nn.ReLU()
#         self.bn = torch.nn.BatchNorm2d(filters)

#     def forward(self, inp):
#         shortcut = self.conv2d_bn1(inp)
#         out = self.conv2d_bn2(inp)
#         out = torch.add(shortcut, out)
#         out = self.relu(out)
#         out = self.bn(out)
#         return out


# class ResPath(torch.nn.Module):
#     def __init__(self, in_channels, filters, length):
#         super(ResPath, self).__init__()
#         self.first_block = ResPathBlock(in_channels, filters)
#         self.blocks = torch.nn.Sequential(*[ResPathBlock(filters, filters) for i in range(length - 1)])

#     def forward(self, inp):
#         out = self.first_block(inp)
#         out = self.blocks(out)
#         return out


# class MultiResUnet(torch.nn.Module):
#     def __init__(self, in_channels, out_channels, nf=32, use_dropout=False):
#         super(MultiResUnet, self).__init__()
#         self.mres_block1 = MultiResBlock(in_channels, u=nf)
#         self.pool = torch.nn.MaxPool2d(kernel_size=2)
#         self.res_path1 = ResPath(self.mres_block1.out_channel, nf, 4)

#         self.mres_block2 = MultiResBlock(self.mres_block1.out_channel, u=nf * 2)
#         # self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
#         self.res_path2 = ResPath(self.mres_block2.out_channel, nf * 2, 3)

#         self.mres_block3 = MultiResBlock(self.mres_block2.out_channel, u=nf * 4)
#         # self.pool3 = torch.nn.MaxPool2d(kernel_size=2)
#         self.res_path3 = ResPath(self.mres_block3.out_channel, nf * 4, 2)

#         self.mres_block4 = MultiResBlock(self.mres_block3.out_channel, u=nf * 8)
#         # self.pool4 = torch.nn.MaxPool2d(kernel_size=2)
#         self.res_path4 = ResPath(self.mres_block4.out_channel, nf * 8, 1)

#         self.mres_block5 = MultiResBlock(self.mres_block4.out_channel, u=nf * 16)

#         self.deconv1 = torch.nn.ConvTranspose2d(self.mres_block5.out_channel, nf * 8, (2, 2), (2, 2))
#         self.mres_block6 = MultiResBlock(nf * 8 + nf * 8, u=nf * 8, use_dropout=use_dropout)
#         # MultiResBlock(nf * 8 + self.mres_block4.out_channel, u=nf * 8)

#         self.deconv2 = torch.nn.ConvTranspose2d(self.mres_block6.out_channel, nf * 4, (2, 2), (2, 2))
#         self.mres_block7 = MultiResBlock(nf * 4 + nf * 4, u=nf * 4, use_dropout=use_dropout)
#         # MultiResBlock(nf * 4 + self.mres_block3.out_channel, u=nf * 4)

#         self.deconv3 = torch.nn.ConvTranspose2d(self.mres_block7.out_channel, nf * 2, (2, 2), (2, 2))
#         self.mres_block8 = MultiResBlock(nf * 2 + nf * 2, u=nf * 2, use_dropout=use_dropout)
#         # MultiResBlock(nf * 2 + self.mres_block2.out_channel, u=nf * 2)

#         self.deconv4 = torch.nn.ConvTranspose2d(self.mres_block8.out_channel, nf, (2, 2), (2, 2))
#         self.mres_block9 = MultiResBlock(nf + nf, u=nf)
#         # MultiResBlock(nf + self.mres_block1.out_channel, u=nf)

#         self.conv10 = conv2d_bn(self.mres_block9.out_channel, out_channels, 1, padding='same', activation='tanh')

#     def forward(self, inp):
#         mresblock1 = self.mres_block1(inp)
#         pool = self.pool(mresblock1)
#         mresblock1 = self.res_path1(mresblock1)

#         mresblock2 = self.mres_block2(pool)
#         pool = self.pool(mresblock2)
#         mresblock2 = self.res_path2(mresblock2)

#         mresblock3 = self.mres_block3(pool)
#         pool = self.pool(mresblock3)
#         mresblock3 = self.res_path3(mresblock3)

#         mresblock4 = self.mres_block4(pool)
#         pool = self.pool(mresblock4)
#         mresblock4 = self.res_path4(mresblock4)

#         mresblock = self.mres_block5(pool)

#         up = torch.cat([self.deconv1(mresblock), mresblock4], dim=1)
#         mresblock = self.mres_block6(up)

#         up = torch.cat([self.deconv2(mresblock), mresblock3], dim=1)
#         mresblock = self.mres_block7(up)

#         up = torch.cat([self.deconv3(mresblock), mresblock2], dim=1)
#         mresblock = self.mres_block8(up)

#         up = torch.cat([self.deconv4(mresblock), mresblock1], dim=1)
#         mresblock = self.mres_block9(up)

#         conv10 = self.conv10(mresblock)
#         return conv10


# class MultiResUnetGenerator(torch.nn.Module):
#     def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, gpu_ids=[]):
#         super(MultiResUnetGenerator, self).__init__()
#         self.gpu_ids = gpu_ids

#         self.model = MultiResUnet(input_nc, output_nc, nf=ngf, use_dropout=use_dropout)

#     def forward(self, inp):
#         if self.gpu_ids and isinstance(inp.data, torch.cuda.FloatTensor):
#             return torch.nn.parallel.data_parallel(self.model, inp, self.gpu_ids)
#         else:
#             return self.model(inp)


# def weights_init_uniform_rule(m):
#     classname = m.__class__.__name__
#     # for every Linear layer in a model..
#     if classname == 'Conv2d':
#         pass
#     # print(classname)


# # a = ResPath(10, 100,3)
# # a.apply(weights_init_uniform_rule)
# # a = MultiResUnet(512, 512, 3)
# # x = torch.randn(2, 3, 512, 512)
# # print(a(x).shape
# model = MultiResUnet(in_channels=12, out_channels=1, nf=32)
# model = model.to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViTBinarySegmenter(nn.Module):
    """Patch-based transformer that predicts a per-pixel probability map.

    The input is split into ``patch_size`` x ``patch_size`` patches, each
    patch is embedded, the token sequence goes through ``nn.Transformer``
    (used with the tokens as both source and target), and a small
    convolutional head predicts ``patch_size**2`` values per patch that are
    rearranged into a full-resolution map.

    Args:
        input_channels: Number of channels of the input images.
        img_size: Side length the network operates on; other input sizes are
            resized to it and the output is resized back.
        patch_size: Side length of one patch.
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of transformer encoder layers.
        num_decoder_layers: Number of transformer decoder layers.
        dim_feedforward: Hidden size of the transformer feed-forward blocks.
        dropout: Dropout rate inside the transformer.

    ``forward`` maps ``(B, C, H, W)`` to ``(B, H, W)`` probabilities (a
    sigmoid is already applied).
    """

    def __init__(self, input_channels=3, img_size=256, patch_size=16, embed_dim=768, num_heads=4,
                 num_encoder_layers=4, num_decoder_layers=6, dim_feedforward=3072, dropout=0.1):
        super(ViTBinarySegmenter, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Learnable embedding for patches
        self.patch_embed = nn.Conv2d(
            input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False
        )

        # Positional embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, (img_size // patch_size) ** 2, embed_dim)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # batch_first=True because the tokens are laid out as (B, N, E);
        # with the default layout attention would run across the batch axis.
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

        # Head predicting patch_size**2 values (one per pixel) for every patch
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, self.patch_size*self.patch_size, kernel_size=1)
        )

    def forward(self, x):
        batch_size, _, h, w = x.shape

        # Ensure input is compatible with fixed image size
        if h != self.img_size or w != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size), mode='bilinear', align_corners=False)

        # Patch embedding
        patches = self.patch_embed(x)  # Shape: (B, E, H/P, W/P)
        grid_h, grid_w = patches.shape[-2:]
        tokens = patches.flatten(2).transpose(1, 2)  # Shape: (B, N, E)

        # Add positional embeddings
        tokens = tokens + self.pos_embed

        # Pass through transformer
        transformed = self.transformer(tokens, tokens)  # Shape: (B, N, E)

        # Reshape back to spatial dimensions
        features = transformed.transpose(1, 2).reshape(
            batch_size, self.embed_dim, grid_h, grid_w
        )

        # Decode the features: (B, P*P, H/P, W/P)
        seg_map = self.decoder(features)

        # Place each patch's P*P predictions inside that patch: (B, 1, H', W').
        # A flat reshape to (B, H, W) would scatter them across the image.
        seg_map = F.pixel_shuffle(seg_map, self.patch_size)

        # Resize output to match the input size
        if seg_map.shape[-2:] != (h, w):
            seg_map = F.interpolate(seg_map, size=(h, w), mode='bilinear', align_corners=False)

        return torch.sigmoid(seg_map.squeeze(1))


# Build the segmenter for 12-channel inputs at 256x256 and move it to `device`.
model = ViTBinarySegmenter(input_channels=12, img_size=256, patch_size=16).to(device)

In [12]:
# model = torch.hub.load(
#     'mateuszbuda/brain-segmentation-pytorch', 
#     'unet',
#     in_channels=12,  # Set input channels to 12
#     out_channels=1,  # Output channel remains 1 for the mask
#     init_features=32, 
#     pretrained=False  # Pretraining is not available for 12 channels
# )

# model = model.to(device)

In [None]:
class FocalLoss(nn.Module):
    """Alpha-balanced focal loss computed from logits.

    Args:
        alpha: Weight given to positions where ``target == 1``; all other
            positions are weighted ``1 - alpha``.
        gamma: Exponent of the focusing term.

    ``forward(pred, target)`` squeezes both inputs, treats ``pred`` as logits,
    sums the per-element loss over the last two dimensions and returns the
    mean of those sums.
    """

    def __init__(self, alpha=0.65, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Element-wise BCE; the reduction is done manually in forward().
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, pred, target):
        logits = pred.squeeze()
        labels = target.squeeze()
        is_positive = labels == 1

        per_element_bce = self.bce_loss(logits, labels)
        prob = torch.sigmoid(logits)

        # Class-balancing factor: alpha for positives, 1 - alpha otherwise.
        class_weight = torch.where(is_positive, self.alpha, 1 - self.alpha)
        # Probability assigned to the wrong class; large for hard examples.
        miss_prob = torch.where(is_positive, 1 - prob, prob)

        weighted = (class_weight * miss_prob**self.gamma) * per_element_bce
        return weighted.sum(dim=(-2, -1)).mean()


# Smoke test: evaluate the loss on random logits and random 0/1 targets of
# shape (16, 128, 128); the last expression is the resulting loss value.
pred = torch.randn(16, 128, 128).float()
target = torch.randint(0, 2, (16, 128, 128)).float()

focal_loss = FocalLoss(alpha=0.75, gamma=2.0)
focal_loss(pred, target)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd


class ProbabilityFocalLoss(nn.Module):
    """Adapter that feeds probabilities to the logit-based ``FocalLoss``.

    ``ViTBinarySegmenter.forward`` already returns ``sigmoid`` outputs, while
    ``FocalLoss`` applies ``BCEWithLogitsLoss`` and ``sigmoid`` to its input.
    Passing the model output straight in would apply the sigmoid twice, so the
    probabilities are converted back to logits first.

    Args:
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before the logit
            is taken, which keeps the result finite.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.focal_loss = FocalLoss()

    def forward(self, prob, target):
        return self.focal_loss(torch.logit(prob, eps=self.eps), target)


criterion = ProbabilityFocalLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# it is still accepted by the installed version.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, test_loader, device, scheduler)
trainer.train(num_epochs=20)

In [None]:
# Evaluate on the test loader and display the resulting metrics table.
metrics_df = trainer.test()
metrics_df

In [None]:
# Plot 15 randomly chosen test masks next to the model's predictions.
trainer.test_visualize(15)