# 1. Setup and Imports

Import necessary libraries.
Check for GPU availability (important for Google Colab).

In [None]:
import torch
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 2. Data Preprocessing
Define functions to load and preprocess data


In [None]:
from tifffile import TiffFile
import numpy as np
import torch
from torchvision import transforms

def load_image(image_path):
    """Read the TIFF at *image_path* and return it as a scaled float32 array.

    The pixel data is cast to ``float32`` and divided by 65535.0, so
    unsigned 16-bit input ends up in the range [0, 1].
    """
    with TiffFile(image_path) as tif:
        raw = tif.asarray()
    # 65535 is the largest unsigned 16-bit value.
    # NOTE(review): assumes the source TIFF is uint16 -- confirm.
    return raw.astype(np.float32) / 65535.0

def load_mask(mask_path):
    """Read the TIFF at *mask_path* and return its array unchanged.

    No scaling or dtype conversion is applied.
    """
    # NOTE(review): the mask is presumed to already hold class ids 0-9 -- confirm.
    with TiffFile(mask_path) as tif:
        return tif.asarray()


Example usage

# 3. Dataset and DataLoader


*   Create a custom PyTorch Dataset class for your Sentinel-2 data.

*   Use PyTorch DataLoader to batch and shuffle the data.

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [None]:
class Sentinel2Dataset(Dataset):
    """Paired image / mask dataset read from two parallel directories.

    Every regular file in ``images_dir`` is one sample; its mask is the file
    with the same name inside ``masks_dir``.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing one mask per image, same file names.
        transform: Stored on the instance for callers that want it.
            NOTE: it is currently NOT applied in ``__getitem__``.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. notebook checkpoint folders). Sort so that
        # index -> sample is reproducible across runs and machines, and keep
        # regular files only so __getitem__ never tries to open a directory.
        self.images = sorted(
            name
            for name in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image_tensor, mask_tensor)``.

        The image is returned as a float tensor with its last axis moved to
        the front; the mask is returned as a long tensor.
        """
        img_name = os.path.join(self.images_dir, self.images[idx])
        # The mask shares the image's file name, just in the masks directory.
        mask_name = os.path.join(self.masks_dir, self.images[idx])

        image = load_image(img_name)  # Normalized image
        mask = load_mask(mask_name)   # Categorical mask

        # Convert to PyTorch tensors.
        # assumes the image array is (H, W, C); permute makes it (C, H, W) -- TODO confirm
        image_tensor = torch.from_numpy(image).float().permute(2, 0, 1)
        mask_tensor = torch.from_numpy(mask).long()  # Convert mask to long tensor

        return image_tensor, mask_tensor

In [None]:
# Locations of the training images and their masks
images_dir = '/kaggle/input/dataset-tar-gz/train/images/'
masks_dir = '/kaggle/input/dataset-tar-gz/train/masks/'

# Number of samples per batch (reused by the validation loader)
batch_size = 16

# Training dataset, served in shuffled batches by two worker processes
sentinel_dataset = Sentinel2Dataset(images_dir, masks_dir)
data_loader = DataLoader(
    dataset=sentinel_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Define validation directories
val_images_dir = '/kaggle/input/dataset-tar-gz/test/images/'
val_masks_dir = '/kaggle/input/dataset-tar-gz/test/masks/'

# Instantiate the validation dataset
val_dataset = Sentinel2Dataset(images_dir=val_images_dir, masks_dir=val_masks_dir)

# Create validation DataLoader.
# Validation data is NOT shuffled: a fixed batch order keeps the saved
# validation samples and the per-batch-averaged metrics comparable from one
# epoch to the next.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


# 4. Model Architecture (U-Net)

* Define the U-Net architecture.
* Modify the final layer to output 10 channels (one for each category).

In [None]:
import torch
import torch.nn as nn

def conv_block(in_dim, out_dim, act_fn):
    """Build a 3x3 convolution -> batch norm -> activation stack.

    The convolution uses stride 1 and padding 1, so height and width are
    preserved. ``act_fn`` is inserted as given (the same module instance).
    """
    layers = [
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn,
    ]
    return nn.Sequential(*layers)


def conv_trans_block(in_dim, out_dim, act_fn):
    """Build a transposed convolution -> batch norm -> activation stack.

    With kernel 3, stride 2, padding 1 and output padding 1 the transposed
    convolution doubles the spatial height and width.
    """
    upsample = nn.ConvTranspose2d(
        in_dim,
        out_dim,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    return nn.Sequential(upsample, nn.BatchNorm2d(out_dim), act_fn)


def maxpool():
    """Return a 2x2 max-pooling layer with stride 2 and no padding."""
    return nn.MaxPool2d(kernel_size=2, stride=2, padding=0)


def conv_block_2(in_dim, out_dim, act_fn):
    """Build a two-convolution block.

    The first stage is a full ``conv_block`` (conv, batch norm, activation).
    The second stage is a 3x3 convolution plus batch norm with no activation
    after it.
    """
    first_stage = conv_block(in_dim, out_dim, act_fn)
    second_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
    second_norm = nn.BatchNorm2d(out_dim)
    return nn.Sequential(first_stage, second_conv, second_norm)

def conv_block_3(in_dim, out_dim, act_fn):
    """Build a three-convolution block.

    Two ``conv_block`` stages (conv, batch norm, activation) are followed by
    a final 3x3 convolution plus batch norm with no activation after it.
    """
    stages = [
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim),
    ]
    return nn.Sequential(*stages)

class UNet(nn.Module):
    """U-Net with four down-sampling steps, a bridge and four up-sampling steps.

    Args:
        in_dim: Number of channels of the input tensor.
        out_dim: Number of channels produced by the final convolution.
        num_filter: Channel count of the first encoder block; it doubles at
            every encoder level, up to ``16 * num_filter`` in the bridge.
    """

    def __init__(self, in_dim, out_dim, num_filter):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter

        nf = num_filter
        # A single activation module instance is shared by every block.
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        # Encoder: conv block, then 2x2 max-pool halves the resolution.
        self.down_1 = conv_block_2(in_dim, nf, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = conv_block_2(nf * 1, nf * 2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = conv_block_2(nf * 2, nf * 4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = conv_block_2(nf * 4, nf * 8, act_fn)
        self.pool_4 = maxpool()

        # Lowest-resolution stage between encoder and decoder.
        self.bridge = conv_block_2(nf * 8, nf * 16, act_fn)

        # Decoder: each trans_* halves the channel count; concatenating the
        # matching encoder output doubles it again, hence the up_* input
        # widths.
        self.trans_1 = conv_trans_block(nf * 16, nf * 8, act_fn)
        self.up_1 = conv_block_2(nf * 16, nf * 8, act_fn)
        self.trans_2 = conv_trans_block(nf * 8, nf * 4, act_fn)
        self.up_2 = conv_block_2(nf * 8, nf * 4, act_fn)
        self.trans_3 = conv_trans_block(nf * 4, nf * 2, act_fn)
        self.up_3 = conv_block_2(nf * 4, nf * 2, act_fn)
        self.trans_4 = conv_trans_block(nf * 2, nf * 1, act_fn)
        self.up_4 = conv_block_2(nf * 2, nf * 1, act_fn)

        # Final 3x3 convolution (stride 1, padding 1) to out_dim channels.
        self.out = nn.Conv2d(nf, out_dim, 3, 1, 1)

    def forward(self, input):
        """Run the network on *input* and return the ``out_dim``-channel map."""
        # Encoder pass; keep each pre-pooling output for the skip connections.
        skip_1 = self.down_1(input)
        skip_2 = self.down_2(self.pool_1(skip_1))
        skip_3 = self.down_3(self.pool_2(skip_2))
        skip_4 = self.down_4(self.pool_3(skip_3))

        features = self.bridge(self.pool_4(skip_4))

        # Decoder pass: upsample, concatenate the skip along channels, refine.
        features = self.up_1(torch.cat([self.trans_1(features), skip_4], dim=1))
        features = self.up_2(torch.cat([self.trans_2(features), skip_3], dim=1))
        features = self.up_3(torch.cat([self.trans_3(features), skip_2], dim=1))
        features = self.up_4(torch.cat([self.trans_4(features), skip_1], dim=1))

        return self.out(features)

## Explanation

* conv_block / conv_block_2: Helper functions. `conv_block` applies a 3x3 convolution, batch normalization, and the activation. `conv_block_2` stacks a `conv_block` with a second 3x3 convolution and batch normalization (no activation after the second one).
* UNet: The main U-Net architecture, which includes:
  * Downsampling Path (down_1, down_2, down_3, down_4): Each step is a `conv_block_2` followed by 2x2 max pooling (pool_1 to pool_4), doubling the number of channels at every level.
  * Bridge (bridge): A `conv_block_2` at the lowest resolution that connects the downsampling and upsampling paths.
  * Upsampling Path (trans_1 to trans_4, up_1 to up_4): Each step upsamples with a transposed convolution, concatenates the result with the feature maps from the corresponding downsampling step, and applies a `conv_block_2`.
  * Final Convolution (out): A 3x3 convolution to map the final feature maps to the number of classes.

In [None]:
# Hyper-parameters of the network
in_dim = 4        # input channels (RGB + NIR)
out_dim = 10      # one output channel per class
num_filter = 64   # channel count of the first encoder block

# Build the U-Net with those settings
model = UNet(in_dim=in_dim, out_dim=out_dim, num_filter=num_filter)

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Wrap the network in DataParallel; the original UNet is then reachable as model.module
model = nn.DataParallel(model)
model.to(device)



In [None]:
# Show the layer-by-layer structure of the (DataParallel-wrapped) model.
print(model)


# 5. Loss Function and Optimizer

* Implement or use an existing Dice Loss function.
* Choose an optimizer (e.g., Adam).

In [None]:
import torch
from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Return the mean Dice coefficient between *input* and *target*.

    Both tensors must have the same size. With ``reduce_batch_first=True``
    the input must be 3-D and the sums run over all three axes (one score for
    the whole batch). Otherwise the sums run over the last two axes and the
    resulting scores are averaged.

    ``epsilon`` is added to numerator and denominator so that two empty
    masks score 1 instead of dividing by zero.
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    if input.dim() == 2 or not reduce_batch_first:
        axes = (-1, -2)
    else:
        axes = (-1, -2, -3)

    overlap = 2 * (input * target).sum(dim=axes)
    total = input.sum(dim=axes) + target.sum(dim=axes)
    # Where both masks are empty, reuse the (zero) overlap as the denominator
    # so the ratio below becomes epsilon / epsilon.
    total = torch.where(total == 0, overlap, total)

    return ((overlap + epsilon) / (total + epsilon)).mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Return the Dice coefficient averaged over all classes.

    The first two axes of both tensors are flattened into one before
    delegating to ``dice_coeff``, so every (sample, class) plane is scored
    like an independent mask.
    """
    # assumes the first two axes are (batch, class) -- TODO confirm
    flat_input = input.flatten(0, 1)
    flat_target = target.flatten(0, 1)
    return dice_coeff(flat_input, flat_target, reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Return ``1 - Dice`` as a loss to minimise.

    ``multiclass`` selects ``multiclass_dice_coeff`` instead of
    ``dice_coeff``. The coefficient is always computed with
    ``reduce_batch_first=True``.
    """
    if multiclass:
        score = multiclass_dice_coeff(input, target, reduce_batch_first=True)
    else:
        score = dice_coeff(input, target, reduce_batch_first=True)
    return 1 - score

In [None]:
# Step size handed to Adam
learning_rate = 1e-4

# Adam over every model parameter; cross-entropy is the base segmentation loss
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


In [None]:
import torch.nn.functional as F
from tqdm import tqdm
from torch import optim
import torch.utils.data as data
import torchvision.utils as utils
from pathlib import Path
import tifffile
import numpy as np



In [None]:

def save_input_as_tiff(input_tensor, filename):
    """Write a normalized image tensor to *filename* as a 16-bit TIFF.

    Args:
        input_tensor: Tensor whose values were scaled by 1/65535; it is
            detached and moved to the CPU before conversion.
        filename: Destination path handed to ``tifffile.imwrite``.
    """
    # Revert the normalization (assuming original range was 0-65535).
    scaled = input_tensor.detach().cpu().numpy() * 65535

    # Round to the nearest integer instead of truncating: a float32
    # round-trip v / 65535 * 65535 can land just below v, which a plain
    # astype(uint16) would turn into v - 1. Clip first so values outside
    # [0, 65535] saturate instead of wrapping around.
    array = np.clip(np.rint(scaled), 0, 65535).astype(np.uint16)

    # If the tensor is a 3D tensor (C, H, W), convert it to (H, W, C)
    if array.ndim == 3:
        array = np.transpose(array, (1, 2, 0))

    # Save as TIFF in original format
    tifffile.imwrite(filename, array)

def save_mask_as_tiff(mask_tensor, filename):
    """Write a mask tensor to *filename* as a TIFF, values and dtype unchanged.

    The tensor is detached, moved to the CPU and converted to NumPy. No
    transpose or rescaling is applied.
    """
    # assumes the mask is already laid out as [H, W] -- TODO confirm
    mask_array = mask_tensor.detach().cpu().numpy()
    tifffile.imwrite(filename, mask_array)


In [None]:
      # Make sure every output folder exists before training writes into it
      from pathlib import Path
      output_dirs = (
          '/kaggle/working/result/predicted',
          '/kaggle/working/result/original',
          '/kaggle/working/result/label',
          '/kaggle/working/val/predicted',
          '/kaggle/working/val/original',
          '/kaggle/working/val/label',
      )
      for directory in output_dirs:
          # parents=True creates missing ancestors; exist_ok=True makes re-runs harmless
          Path(directory).mkdir(parents=True, exist_ok=True)

In [None]:
import pandas as pd

# Column order of the per-epoch metrics table (and of the CSV written below).
metric_columns = [
    'Epoch', 'Train Loss', 'Validation Loss', 'Validation Dice',
    'Validation Sensitivity', 'Validation Specificity',
    'Validation Precision', 'Validation IOU',
]
# One dict per finished epoch. The DataFrame is rebuilt from this list each
# epoch because DataFrame.append was removed in pandas 2.0.
metrics_rows = []
metrics_data = pd.DataFrame(columns=metric_columns)

num_epochs = 10  # Set the number of epochs

# `model` is wrapped in nn.DataParallel, so UNet attributes live on `.module`.
num_classes = model.module.out_dim


def _save_sample(out_dir, inputs, targets, outputs, batch_idx, epoch):
    """Save the first image, label and predicted mask of a batch as TIFFs."""
    save_input_as_tiff(inputs[0], f"{out_dir}/original/original_image_{batch_idx}_{epoch}.tif")
    save_mask_as_tiff(targets[0], f"{out_dir}/label/label_image_{batch_idx}_{epoch}.tif")
    predicted_mask = torch.argmax(outputs[0], dim=0)
    save_mask_as_tiff(predicted_mask, f"{out_dir}/predicted/predicted_image_{batch_idx}_{epoch}.tif")


for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    model.train()
    total_train_loss = 0.0
    last_train_idx = len(data_loader) - 1

    # Set up tqdm progress bar
    progress_bar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
    for batch_idx, (inputs, targets) in progress_bar:
        inputs = inputs.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
        targets = targets.to(device=device, dtype=torch.long)

        # Forward pass
        outputs = model(inputs)

        # Training loss = cross-entropy + multiclass Dice loss
        loss = criterion(outputs, targets) + dice_loss(
            F.softmax(outputs, dim=1).float(),
            F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float(),
            multiclass=True,
        )

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        progress_bar.set_postfix({
            'loss': f'{total_train_loss / (batch_idx + 1):.4f}'
        })

        # Every 100 batches (and on the last one) dump a sample and a checkpoint.
        if batch_idx % 100 == 0 or batch_idx == last_train_idx:
            _save_sample('/kaggle/working/result', inputs, targets, outputs, batch_idx, epoch)
            torch.save(model.module.state_dict(), Path(f'/kaggle/working/result/unet_{batch_idx}_{epoch}.pkl'))

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    sensitivities = []
    specificities = []
    precisions = []
    ious = []

    model.eval()
    total_val_loss = 0.0
    total_val_dice = 0.0
    last_val_idx = len(val_loader) - 1
    with torch.no_grad():
        # Validation batches get their own index: reusing the training
        # loop's stale batch_idx made every saved file overwrite the last.
        for val_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # NOTE: validation loss is cross-entropy only (no Dice term),
            # so it is not directly comparable to the training loss.
            total_val_loss += criterion(outputs, targets).item()

            # Convert outputs to probabilities and targets to one-hot format
            probabilities = torch.softmax(outputs, dim=1)
            target_one_hot = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2).float()

            # .item() keeps the running sum a plain float, so the DataFrame /
            # CSV holds a number rather than a (possibly CUDA) tensor.
            total_val_dice += multiclass_dice_coeff(probabilities, target_one_hot).item()

            if val_idx % 100 == 0 or val_idx == last_val_idx:
                _save_sample('/kaggle/working/val', inputs, targets, outputs, val_idx, epoch)

            # Per-class confusion counts for this batch
            preds = torch.argmax(outputs, dim=1)
            for class_idx in range(num_classes):
                pred_is_class = preds == class_idx
                target_is_class = targets == class_idx

                TP = (pred_is_class & target_is_class).sum().item()
                TN = (~pred_is_class & ~target_is_class).sum().item()
                FP = (pred_is_class & ~target_is_class).sum().item()
                FN = (~pred_is_class & target_is_class).sum().item()

                # A metric whose denominator is empty is recorded as 0.
                sensitivities.append(TP / (TP + FN) if (TP + FN) != 0 else 0)
                specificities.append(TN / (TN + FP) if (TN + FP) != 0 else 0)
                precisions.append(TP / (TP + FP) if (TP + FP) != 0 else 0)
                ious.append(TP / (TP + FP + FN) if (TP + FP + FN) != 0 else 0)

    # Calculate average metrics
    avg_train_loss = total_train_loss / len(data_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    avg_val_dice = total_val_dice / len(val_loader)
    avg_sensitivity = sum(sensitivities) / len(sensitivities)
    avg_specificity = sum(specificities) / len(specificities)
    avg_precision = sum(precisions) / len(precisions)
    avg_iou = sum(ious) / len(ious)

    # Record the current epoch's metrics
    current_epoch_data = {
        'Epoch': epoch + 1,
        'Train Loss': avg_train_loss,
        'Validation Loss': avg_val_loss,
        'Validation Dice': avg_val_dice,
        'Validation Sensitivity': avg_sensitivity,
        'Validation Specificity': avg_specificity,
        'Validation Precision': avg_precision,
        'Validation IOU': avg_iou,
    }
    metrics_rows.append(current_epoch_data)
    metrics_data = pd.DataFrame(metrics_rows, columns=metric_columns)
    # Save to CSV after each epoch so partial results survive an interruption
    metrics_data.to_csv('/kaggle/working/result/metrics_data.csv', index=False)
