# In [None]:
import math # For PositionalEncoding2D if it's in the same file, otherwise import from model file
import os

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# --- Assuming HybridAttentionTransformerUNet is in a file named model.py ---
# Or paste the model definition (HybridAttentionTransformerUNet and its sub-modules) directly here
# For this example, let's assume it's defined above or in an imported file.

# --- Model Definition (Paste HybridAttentionTransformerUNet and its components here if not importing) ---
# --- 1. Convolutional Block ---
class ConvBlock(nn.Module):
    """
    Two stacked 3x3 convolution stages, each followed by BatchNorm2d and ReLU.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``. The convolutions have no
    bias because the BatchNorm that follows supplies its own shift.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in_channels in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in_channels, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

# --- 2. Attention Gate ---
class AttentionGate(nn.Module):
    """
    Additive attention gate applied to a U-Net skip connection.

    The gating signal ``g`` (``F_g`` channels) and the skip features ``x``
    (``F_l`` channels) are each projected to ``F_int`` channels with a 1x1
    conv + BatchNorm. The projections are summed and passed through ReLU, then
    reduced to a single-channel map by 1x1 conv + BatchNorm + sigmoid. ``x`` is
    multiplied element-wise by that map, which lies in (0, 1). ``g`` and ``x``
    must have the same spatial size.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def projection(in_ch, out_ch):
            # 1x1 conv (with bias) followed by BatchNorm over the output channels.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            ]

        self.W_g = nn.Sequential(*projection(F_g, F_int))
        self.W_x = nn.Sequential(*projection(F_l, F_int))
        self.psi = nn.Sequential(*projection(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate_input = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(gate_input)
        return x * attention

# --- 3. Transformer Encoder Block ---
class TransformerEncoderBlock(nn.Module):
    """
    A single post-norm Transformer encoder layer.

    Two sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    multi-head self-attention, then a two-layer ReLU feed-forward network.
    Tensors are batch-first, i.e. ``(batch, seq_len, embed_dim)``.

    Args:
        embed_dim: token feature size (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        ff_dim: hidden size of the feed-forward network.
        dropout_rate: dropout probability, used inside attention, inside the
            FFN, and on each residual branch.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)
        # The trailing Dropout acts as the FFN branch's residual dropout, so
        # forward() does not apply another one to the FFN output.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout_rate)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout_rate)  # residual dropout for the attention branch

    def forward(self, x):
        # The attention weights are discarded, so don't ask MHA to compute them.
        attn_output, _ = self.attention(x, x, x, need_weights=False)
        x = self.norm1(x + self.dropout(attn_output))
        # Previously the FFN output was dropped out twice (inside self.ffn and
        # again here), which silently raised the effective dropout rate.
        x = self.norm2(x + self.ffn(x))
        return x

# --- 4. 2D Positional Encoding ---
class PositionalEncoding2D(nn.Module):
    """
    Adds fixed 2D sin/cos positional encodings to a (N, C, H, W) feature map.

    The first half of the ``d_model`` channels encodes the column (width)
    position and the second half encodes the row (height) position. Within each
    half, even channels hold ``sin`` and odd channels hold ``cos`` at
    geometrically spaced frequencies. The table is built once for
    ``height x width`` and stored as a non-trainable buffer.

    Args:
        d_model: number of channels; must be divisible by 4.
        height: maximum supported feature-map height.
        width: maximum supported feature-map width.

    Raises:
        ValueError: if ``d_model`` is not divisible by 4.
    """
    def __init__(self, d_model, height, width):
        super(PositionalEncoding2D, self).__init__()
        if d_model % 4 != 0:
            raise ValueError(
                "d_model must be divisible by 4 for 2D sin/cos positional encoding "
                "(got dim={:d})".format(d_model)
            )
        pe = torch.zeros(d_model, height, width)
        d_model_half = d_model // 2
        div_term = torch.exp(torch.arange(0., d_model_half, 2) * -(math.log(10000.0) / d_model_half))
        pos_w = torch.arange(0., width).unsqueeze(1)   # (W, 1)
        pos_h = torch.arange(0., height).unsqueeze(1)  # (H, 1)

        # Width encodings: constant along H, varying along W.
        pe[0:d_model_half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model_half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        # Height encodings: constant along W, varying along H.
        pe[d_model_half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model_half+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding table cropped to ``x``'s channels and size."""
        _, channels, height, width = x.shape
        _, max_channels, max_height, max_width = self.pe.shape
        if channels > max_channels or height > max_height or width > max_width:
            raise ValueError(
                f"Input of shape {tuple(x.shape)} exceeds the positional encoding table "
                f"(channels={max_channels}, height={max_height}, width={max_width})"
            )
        return x + self.pe[:, :channels, :height, :width]

# --- 5. Hybrid Attention-Transformer U-Net ---
class HybridAttentionTransformerUNet(nn.Module):
    """
    U-Net with attention-gated skip connections and a Transformer bottleneck.

    Encoder: one ConvBlock per entry in ``features``, each followed by 2x2 max
    pooling. An (H, W) input therefore reaches the bottleneck at
    (H // 2**len(features), W // 2**len(features)) with ``features[-1]``
    channels. If that size differs from ``bottleneck_img_size``, the map is
    resized to it with adaptive average pooling.

    Bottleneck: a 2D sin/cos positional encoding is added. The map is then
    flattened into H*W tokens of size ``transformer_embed_dim``, passed through
    ``transformer_num_layers`` TransformerEncoderBlocks, and reshaped back.

    Decoder: for each encoder stage, deepest first:
      - a 2x2 transposed conv upsamples the map;
      - an AttentionGate re-weights the matching skip connection;
      - the two are concatenated and fused by a ConvBlock.
    A final 1x1 conv maps to ``out_channels``.

    ``forward`` returns raw logits of shape (N, out_channels, H, W), where
    (H, W) is the input size. Apply ``torch.sigmoid`` (or a softmax) yourself,
    or use a logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output logits.
        features: encoder widths, shallow to deep.
        transformer_embed_dim: transformer token size; must equal ``features[-1]``.
        transformer_num_heads, transformer_ff_dim, transformer_num_layers,
        transformer_dropout: TransformerEncoderBlock settings.
        bottleneck_img_size: (height, width) of the bottleneck feature map,
            which is also the positional-encoding table size.

    Raises:
        ValueError: if ``features[-1] != transformer_embed_dim``.
    """
    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512),
                 transformer_embed_dim=512, transformer_num_heads=8, transformer_ff_dim=2048,
                 transformer_num_layers=4, transformer_dropout=0.1, bottleneck_img_size=(16, 16)):
        super(HybridAttentionTransformerUNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck_img_h, self.bottleneck_img_w = bottleneck_img_size

        if features[-1] != transformer_embed_dim:
            raise ValueError(f"Last feature size ({features[-1]}) must match transformer_embed_dim ({transformer_embed_dim})")

        current_in_channels = in_channels
        for feature in features:
            self.downs.append(ConvBlock(current_in_channels, feature))
            current_in_channels = feature

        self.pos_encoder = PositionalEncoding2D(transformer_embed_dim, self.bottleneck_img_h, self.bottleneck_img_w)
        self.transformer_encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(transformer_embed_dim, transformer_num_heads, transformer_ff_dim, transformer_dropout)
            for _ in range(transformer_num_layers)
        ])

        # Decoder. `self.ups` interleaves [ConvTranspose2d, ConvBlock] per stage.
        # The upsampler's input width is whatever the previous stage produced.
        # That is transformer_embed_dim for the first stage, not features[-1] * 2:
        # the bottleneck does not widen the channels.
        self.attentions = nn.ModuleList()
        decoder_channels = transformer_embed_dim
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(decoder_channels, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(ConvBlock(feature * 2, feature))
            decoder_channels = feature

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down_block in self.downs:
            x = down_block(x)
            skip_connections.append(x)
            x = self.pool(x)  # every stage downsamples, so the deepest skip sits above the bottleneck

        bottleneck_size = (self.bottleneck_img_h, self.bottleneck_img_w)
        if x.shape[2:] != bottleneck_size:
            x = F.adaptive_avg_pool2d(x, bottleneck_size)

        x = self.pos_encoder(x)
        batch_size, channels, height, width = x.shape
        x = x.flatten(2).transpose(1, 2)  # (N, H*W, C) token sequence
        for transformer_layer in self.transformer_encoder_layers:
            x = transformer_layer(x)
        x = x.transpose(1, 2).reshape(batch_size, channels, height, width)

        for stage, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)
            # Odd input sizes (or a resized bottleneck) can leave the upsampled map
            # off by a few pixels from the skip connection; align before fusing.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:], mode='bilinear', align_corners=False)
            attention_map = self.attentions[stage](g=x, x=skip_connection)
            x = self.ups[2 * stage + 1](torch.cat((attention_map, x), dim=1))

        # Raw logits: the losses and metrics in this script apply the sigmoid themselves.
        return self.final_conv(x)
# --- End of Model Definition ---


# --- Configuration & Hyperparameters ---
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4 # Learning rate handed to the AdamW optimizer in the main block
BATCH_SIZE = 4 # Batch size for both DataLoaders; adjust based on your GPU memory
NUM_EPOCHS = 50 # Number of training epochs; adjust as needed
IMAGE_HEIGHT = 256 # Height every image and mask is resized to by the transforms
IMAGE_WIDTH = 256  # Width every image and mask is resized to by the transforms
# NOTE(review): the main block integer-divides IMAGE_HEIGHT / IMAGE_WIDTH by
# 2**num_pooling_layers (currently 4) to size the transformer bottleneck, so these
# presumably should be multiples of 16 — confirm before changing them.
NUM_WORKERS = 2 # Worker processes per DataLoader

# --- IMPORTANT: Paths to your data ---
# Make sure these paths are correct and that the CSV has the columns named below.
CSV_FILE_PATH = "square_dataset/dataset.csv"  # Replace with your CSV file path
IMAGE_DIR_ROOT = "square_dataset" # Base directory prepended to the (relative) paths in the CSV
# Example: if the CSV has 'subject1/nir_01.png' and IMAGE_DIR_ROOT is '/data/veins/',
# the full path is '/data/veins/subject1/nir_01.png'.
# If the paths in the CSV are absolute, set IMAGE_DIR_ROOT to an empty string.

# CSV column holding the input image paths (e.g. 'nir_image' for raw NIR images).
# NOTE(review): the value below is 'preprocessed_images' (plural), while earlier notes
# mention 'preprocessed_image' — confirm it matches the actual CSV header.
INPUT_IMAGE_COLUMN = 'preprocessed_images' # or 'nir_image' for the raw images
MASK_COLUMN = 'mask' # CSV column holding the segmentation mask paths

# --- Dataset Class ---
class VeinDataset(Dataset):
    """
    Paired (image, mask) segmentation dataset whose file paths are listed in a CSV.

    Each CSV row names an input image (``image_column``) and its mask
    (``mask_column``). Both files are opened with PIL and converted to 8-bit
    grayscale ("L"). Mask pixels equal to 255 or 1 become vein (1.0); every
    other value becomes background (0.0).

    Samples are returned as ``(image, mask)`` tensors, with the mask carrying a
    leading channel dimension. Without a transform, the image is returned as a
    float32 ``(1, H, W)`` tensor scaled to [0, 1].
    """
    def __init__(self, csv_file, image_dir_root, image_column, mask_column, transform=None, is_train=True):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            image_dir_root (string): Directory with all the images if paths in CSV are relative.
            image_column (string): Column name in CSV for input image paths.
            mask_column (string): Column name in CSV for mask image paths.
            transform (callable, optional): Optional transform to be applied on a sample.
                Called as ``transform(image=..., mask=...)`` and expected to return a dict
                with 'image' and 'mask' keys (albumentations-style).
            is_train (bool): If true, dataset is for training (used for specific augmentations).

        Raises:
            FileNotFoundError: if ``csv_file`` does not exist.
        """
        try:
            self.data_frame = pd.read_csv(csv_file)
        except FileNotFoundError:
            print(f"Error: CSV file not found at {csv_file}")
            print("Please ensure the CSV_FILE_PATH is correct.")
            raise

        self.image_dir_root = image_dir_root
        self.image_column = image_column
        self.mask_column = mask_column
        self.transform = transform
        self.is_train = is_train  # stored for callers; not read by this class

    def __len__(self):
        return len(self.data_frame)

    def _resolve_path(self, path):
        """Join ``path`` onto ``image_dir_root`` when a root is configured, else return it unchanged."""
        return os.path.join(self.image_dir_root, path) if self.image_dir_root else path

    @staticmethod
    def _load_grayscale(path):
        """Open ``path`` with PIL and return it as a 2-D uint8 array; the file is closed afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("L"))

    @staticmethod
    def _binarize_mask(mask):
        """Return a float32 copy of ``mask`` with 255 and 1 mapped to 1.0 and everything else to 0.0."""
        return np.isin(mask, (1, 255)).astype(np.float32)

    @staticmethod
    def _ensure_tensors(image, mask):
        """
        Convert any numpy outputs to tensors and give 2-D ones a leading channel dimension.

        A uint8 numpy image (no transform applied) is scaled to [0, 1] float32.
        Tensors, e.g. from ``ToTensorV2``, pass through unchanged except for the
        mask's channel dimension.
        """
        if isinstance(image, np.ndarray):
            array = image.astype(np.float32)  # astype copies, so the array is contiguous
            if image.dtype == np.uint8:
                array /= 255.0
            image = torch.from_numpy(array)
            if image.ndim == 2:
                image = image.unsqueeze(0)
        if isinstance(mask, np.ndarray):
            # Flips can leave negative strides, which torch.from_numpy rejects.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        # Ensure mask is (1, H, W) and not (H, W) for BCEWithLogitsLoss
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return image, mask

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Positional lookup, so a non-default DataFrame index cannot break indexing.
        row = self.data_frame.iloc[idx]
        img_path = self._resolve_path(row[self.image_column])
        mask_path = self._resolve_path(row[self.mask_column])

        try:
            image = self._load_grayscale(img_path)
            mask = self._load_grayscale(mask_path)
        except FileNotFoundError as e:
            print(f"Error loading image or mask at index {idx}: {e}")
            print(f"Attempted image path: {img_path}")
            print(f"Attempted mask path: {mask_path}")
            raise

        # Binarize mask: veins are expected as 255 (or 1), background as 0.
        mask = self._binarize_mask(mask)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return self._ensure_tensors(image, mask)

# --- Transforms/Augmentations ---
# Define different transforms for training and validation
# Both pipelines resize to the configured input size and normalize with
# mean 0 / std 1 / max_pixel_value 255, which divides 8-bit pixel values by 255.
# ToTensorV2 converts the image and mask to tensors. Only the training pipeline
# adds random geometric augmentation.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        # Random geometric augmentation, applied jointly to image and mask.
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # Single-channel statistics for grayscale input.
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

val_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)


# --- Loss Function ---
# BCEWithLogitsLoss is common for binary segmentation.
# Consider Dice Loss or a combination for better handling of class imbalance.
class DiceLoss(nn.Module):
    """
    Soft Dice loss for binary segmentation, computed from raw logits.

    The logits are passed through a sigmoid. Predictions and targets are then
    flattened across the whole batch, and the loss is
    ``1 - (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)``.
    ``smooth`` keeps the ratio defined (and equal to 1) when both are empty.

    Args:
        smooth: additive smoothing constant for numerator and denominator.
    """
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        # Work in float32 so large sums stay accurate under mixed precision.
        # reshape (unlike view) also accepts non-contiguous tensors.
        probs = torch.sigmoid(inputs.float()).reshape(-1)
        targets = targets.float().reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return 1 - dice

# --- Utility Functions ---
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """
    Announce and write a checkpoint to disk with ``torch.save``.

    Args:
        state: object to serialize. In this script it is a dict with 'epoch',
            'state_dict' and 'optimizer' entries.
        filename: destination path; an existing file is overwritten.
    """
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(state, filename)

def load_checkpoint(checkpoint, model, optimizer=None):
    """
    Restore model (and optionally optimizer) state from a checkpoint dict.

    Args:
        checkpoint: dict with a 'state_dict' entry, plus optional 'optimizer'
            and 'epoch' entries.
        model: module that receives ``checkpoint['state_dict']``.
        optimizer: if given (and truthy), it receives ``checkpoint['optimizer']``
            when that key is present.

    Returns:
        The stored 'epoch' value, or 0 if the checkpoint has none.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint['state_dict'])

    restore_optimizer = bool(optimizer) and 'optimizer' in checkpoint
    if restore_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return checkpoint.get('epoch', 0)

def check_accuracy(loader, model, loss_fn, device="cuda"):
    """
    Evaluate ``model`` on ``loader``; print and return mean loss, pixel accuracy and Dice.

    The model output is treated as logits: it is passed unchanged to
    ``loss_fn`` and binarized as ``sigmoid(output) > 0.5`` for the metrics.
    Dice is computed per batch over all of that batch's pixels, then averaged
    over batches:
    ``(2 * sum(P * T) + 1e-6) / (sum(P) + sum(T) + 1e-6)``.
    With this smoothing, an empty prediction on an empty mask scores 1.0. The
    model runs in eval mode under ``torch.no_grad()`` and is switched back to
    train mode afterwards, even if evaluation raises.

    Args:
        loader: iterable of ``(inputs, targets)`` batches.
        model: module mapping inputs to logits shaped like the targets.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        device: device the batches are moved to.

    Returns:
        tuple of floats ``(avg_loss, pixel_accuracy_percent, avg_dice_score)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    smooth = 1e-6
    num_correct = 0
    num_pixels = 0
    dice_total = 0.0
    total_loss = 0.0
    num_batches = 0

    model.eval()  # Set model to evaluation mode
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device=device).float()  # loss and metric math expect float targets

                preds_logits = model(x)
                total_loss += loss_fn(preds_logits, y).item()

                preds_binary = (torch.sigmoid(preds_logits) > 0.5).float()
                # .item() keeps the counters as Python ints rather than device tensors.
                num_correct += int((preds_binary == y).sum().item())
                num_pixels += preds_binary.numel()

                intersection = (preds_binary * y).sum()
                dice_batch = (2. * intersection + smooth) / (preds_binary.sum() + y.sum() + smooth)
                dice_total += dice_batch.item()
                num_batches += 1
    finally:
        model.train()  # Set model back to train mode

    if num_batches == 0:
        raise ValueError("check_accuracy received an empty loader")

    avg_loss = total_loss / num_batches
    pixel_accuracy = num_correct / num_pixels * 100
    avg_dice_score = dice_total / num_batches

    print(f"Validation: Got {num_correct}/{num_pixels} with acc {pixel_accuracy:.2f}%")
    print(f"Validation Dice score: {avg_dice_score:.4f}")
    print(f"Validation Avg Loss: {avg_loss:.4f}")
    return avg_loss, pixel_accuracy, avg_dice_score

# --- Training Function ---
def train_fn(loader, model, optimizer, loss_fn, scaler=None, epoch=None, device=None):
    """
    Run one training epoch over ``loader`` and return the mean batch loss.

    Args:
        loader: iterable of ``(data, targets)`` batches; ``len(loader)`` is used
            in progress messages.
        model: module being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        scaler: optional gradient scaler (e.g. ``torch.cuda.amp.GradScaler``).
            When given, the forward pass and loss run under CUDA autocast, and
            backward/step go through the scaler.
        epoch: optional 0-based epoch index, used only in progress messages
            (shown as ``epoch + 1`` out of the module-level ``NUM_EPOCHS``).
        device: device batches are moved to; defaults to the module-level ``DEVICE``.

    Returns:
        float: mean of the per-batch loss values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        device = DEVICE

    batch_losses = []
    for batch_idx, (data, targets) in enumerate(loader):  # wrap loader in tqdm for a progress bar
        data = data.to(device=device)
        targets = targets.to(device=device).float()  # Ensure targets are float

        # Forward
        if scaler is not None:  # Mixed precision
            with torch.autocast(device_type="cuda"):
                predictions = model(data)
                loss = loss_fn(predictions, targets)
        else:  # Standard precision
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        if batch_idx % 50 == 0:  # Print loss every 50 batches
            if epoch is None:
                print(f"Batch [{batch_idx}/{len(loader)}], Loss: {loss_value:.4f}")
            else:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx}/{len(loader)}], Loss: {loss_value:.4f}")

    if not batch_losses:
        raise ValueError("train_fn received an empty loader")

    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss for epoch: {mean_loss:.4f}")
    return mean_loss


# --- Main Function ---
if __name__ == "__main__":
    print(f"Using device: {DEVICE}")

    # U-Net encoder widths. The last entry is also the transformer embedding
    # size, which the model requires to match.
    features = [64, 128, 256, 512]

    # Spatial size of the transformer bottleneck, assuming one 2x2 pooling step
    # per encoder stage. IMAGE_HEIGHT / IMAGE_WIDTH should therefore be
    # multiples of 2**len(features).
    num_pooling_layers = len(features)
    bottleneck_h = IMAGE_HEIGHT // (2**num_pooling_layers)
    bottleneck_w = IMAGE_WIDTH // (2**num_pooling_layers)

    model = HybridAttentionTransformerUNet(
        in_channels=1,  # Grayscale input
        out_channels=1,  # Binary mask output
        features=features,
        transformer_embed_dim=features[-1],
        transformer_num_heads=8,
        transformer_ff_dim=2048,
        transformer_num_layers=4,  # Adjust as needed
        bottleneck_img_size=(bottleneck_h, bottleneck_w),
    ).to(DEVICE)

    # Choose loss function. Both options apply a sigmoid to the model output
    # themselves, i.e. they expect logits.
    # loss_fn = nn.BCEWithLogitsLoss()  # Good starting point
    loss_fn = DiceLoss()  # Often better for segmentation

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Gradient scaler for mixed-precision training. It is only created on CUDA,
    # and train_fn enables autocast only when a scaler is passed.
    scaler = None
    if DEVICE == "cuda":
        scaler = torch.cuda.amp.GradScaler()

    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory = DEVICE == "cuda"

    # --- Prepare DataLoaders ---
    # WARNING: train and validation both read the whole CSV, so the validation
    # metrics are measured on training data. Split the data first, e.g.:
    #   df = pd.read_csv(CSV_FILE_PATH)
    #   train_df = df.sample(frac=0.8, random_state=42)
    #   val_df = df.drop(train_df.index)
    #   train_df.to_csv("train_data.csv", index=False)
    #   val_df.to_csv("val_data.csv", index=False)
    # and point the two VeinDataset instances at those files.
    try:
        train_dataset = VeinDataset(
            csv_file=CSV_FILE_PATH,  # Replace with your training CSV if split
            image_dir_root=IMAGE_DIR_ROOT,
            image_column=INPUT_IMAGE_COLUMN,
            mask_column=MASK_COLUMN,
            transform=train_transform,
            is_train=True,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            pin_memory=pin_memory,
            shuffle=True,
        )

        val_dataset = VeinDataset(
            csv_file=CSV_FILE_PATH,  # Replace with your validation CSV if split
            image_dir_root=IMAGE_DIR_ROOT,
            image_column=INPUT_IMAGE_COLUMN,
            mask_column=MASK_COLUMN,
            transform=val_transform,
            is_train=False,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,  # Can often be larger for validation
            num_workers=NUM_WORKERS,
            pin_memory=pin_memory,
            shuffle=False,
        )
    except FileNotFoundError:
        print("Exiting due to CSV file not found. Please check paths in the script.")
        # Non-zero status so shells/schedulers see the failure (bare exit() reported success).
        raise SystemExit(1)
    except Exception as e:
        print(f"Error creating datasets/dataloaders: {e}")
        raise SystemExit(1) from e

    # --- Optional: Load Checkpoint ---
    # start_epoch = 0
    # if os.path.exists("my_checkpoint.pth.tar"):
    #     start_epoch = load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)
    #     print(f"Resuming training from epoch {start_epoch+1}")

    print("Starting training...")
    best_val_dice = 0.0  # Best validation Dice so far; drives best-model saving

    for epoch in range(NUM_EPOCHS):  # Use start_epoch if resuming
        print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Perform validation
        val_loss, val_pixel_acc, val_dice = check_accuracy(val_loader, model, loss_fn, device=DEVICE)

        # Save a checkpoint every epoch (one file per epoch).
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename=f"checkpoint_epoch_{epoch+1}.pth.tar")

        # Save the best model based on validation dice score
        if val_dice > best_val_dice:
            print(f"New best validation Dice score: {val_dice:.4f} (previous: {best_val_dice:.4f})")
            best_val_dice = val_dice
            save_checkpoint(checkpoint, filename="best_model_checkpoint.pth.tar")

    print("Training finished.")
    print(f"Best validation Dice score achieved: {best_val_dice:.4f}")

