## 1. Import the required libraries

In [1]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import logging
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry
from torchvision import transforms
from tqdm import tqdm
import torch.optim.lr_scheduler as lr_scheduler
from lora import LoRA_sam  

# Set random seed
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

## Fine-tune the image encoder of SAM
### 1. Define the SAM model

In [2]:
model_type = 'vit_h'
checkpoint = 'weights/sam_vit_h_4b8939.pth'
device = 'cuda:0'

# Load the SAM model and initialize LoRA
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model.to(device)

# Initialize the LoRA_sam model
r = 32  # Rank of LoRA
lora_sam_model = LoRA_sam(sam_model, rank=r)
lora_sam_model.to(device)


  state_dict = torch.load(f)


LoRA_sam(
  (sam): Sam(
    (image_encoder): ImageEncoderViT(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
      )
      (blocks): ModuleList(
        (0-31): 32 x Block(
          (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): LoRA_qkv(
              (qkv): Linear(in_features=1280, out_features=3840, bias=True)
              (linear_a_q): Linear(in_features=1280, out_features=32, bias=False)
              (linear_b_q): Linear(in_features=32, out_features=1280, bias=False)
              (linear_a_v): Linear(in_features=1280, out_features=32, bias=False)
              (linear_b_v): Linear(in_features=32, out_features=1280, bias=False)
            )
            (proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (lin1): Linear(in_fe

### 2. Data loading

In [3]:
# Read the training and validation set lists
def read_split_files(file_path):
    with open(file_path, 'r') as f:
        file_names = f.read().strip().split('\n')
    return file_names

# Dataset loading
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, sam_model, file_list, mask_size=(256, 256), device='cpu'):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.sam_model = sam_model
        self.mask_size = mask_size
        self.device = device
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.png') and f.replace('.png', '') in file_list]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Read image
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_NEAREST)

        # Read mask
        mask_file = image_file.replace('.png', '.png')
        mask_path = os.path.join(self.mask_dir, mask_file)
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        mask = cv2.resize(mask, self.mask_size, interpolation=cv2.INTER_NEAREST)

        # Convert to torch tensor
        input_image_torch = torch.as_tensor(image, dtype=torch.float32).to(self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()  # [C, H, W]

        # Preprocessing step for SAM model
        input_image = self.sam_model.preprocess(input_image_torch.to(self.device))

        # Convert mask to torch tensor
        mask = torch.as_tensor(mask, dtype=torch.long).to(self.device)  # Mask is single-channel

        return input_image, mask

# Create dataset instances for training and validation sets
# Set paths
image_dir = 'datasets/images'
mask_dir = 'datasets/masks'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read file name lists
train_files = read_split_files('datasets/train.txt')
val_files = read_split_files('datasets/val.txt')

# Create dataset and data loader for training and validation sets
train_dataset = SegmentationDataset(image_dir, mask_dir, sam_model, train_files, device=device)
val_dataset = SegmentationDataset(image_dir, mask_dir, sam_model, val_files, device=device)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Test the data loader
for i, (images, masks) in enumerate(train_loader):
    print(f'Train Batch {i}:')
    print(f'Images shape: {images.shape}')  # Should be [B, C, H, W]
    print(f'Masks shape: {masks.shape}')    # Should be [B, H, W]
    print(f'Mask unique values: {torch.unique(masks)}')  # Output unique values in the mask
    break

for i, (images, masks) in enumerate(val_loader):
    print(f'Val Batch {i}:')
    print(f'Images shape: {images.shape}')  # Should be [B, C, H, W]
    print(f'Masks shape: {masks.shape}')    # Should be [B, H, W]
    print(f'Mask unique values: {torch.unique(masks)}')  # Output unique values in the mask


Train Batch 0:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 0:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 1:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 2:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 3:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 4:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask unique values: tensor([0, 1], device='cuda:0')
Val Batch 5:
Images shape: torch.Size([1, 3, 1024, 1024])
Masks shape: torch.Size([1, 256, 256])
Mask un

### Definition of Contrastive Center Loss

In [4]:
class ContrastiveCenterLoss(nn.Module):
    """Contrastive Center Loss.
    
    This loss combines the concepts of Center Loss and Contrastive Loss,
    considering both intra-class compactness and inter-class separability.
    
    Parameters:
        num_classes (int): Number of classes.
        feat_dim (int): Dimension of features.
        use_gpu (bool): Whether to use GPU.
        lambda_c (float): Weight of the center loss part.
    """
    def __init__(self, num_classes=2, feat_dim=256, use_gpu=True, lambda_c=1.0):
        super(ContrastiveCenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        self.lambda_c = lambda_c

        # Initialize class centers
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Forward propagation function.
        
        Parameters:
            x: Feature matrix, shape (batch_size, feat_dim).
            labels: Ground truth labels, shape (batch_size).
        """
        batch_size = x.size(0)
        
        # Compute distance between each feature and all class centers
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
        
        # Create a mask for the class labels
        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # Compute intra-class distances (distance to the correct class center)
        intra_distances = distmat * mask.float()
        intra_distances = intra_distances.sum() / batch_size

        # Compute inter-class distances (distance to incorrect class centers)
        inter_distances = distmat * (~mask).float()
        inter_distances = inter_distances.sum() / (batch_size * (self.num_classes - 1))
        
        # Compute the contrastive center loss
        loss = (self.lambda_c / 2.0) * intra_distances / (inter_distances + 1e-6) / 0.1
        
        return loss

### 4. Model Training

In [5]:
# Define the function to compute loss
def compute_loss(class_logits, masks, upsampled_embedding, alpha, loss_fn, contrastive_center_loss, ce_weight=1.0, center_weight=1.0):
    """
    Compute cross-entropy loss and contrastive center loss, and combine them with given weights.

    Args:
        class_logits (Tensor): Classification results (B, num_classes, 256, 256)
        masks (Tensor): Masks (B, 256, 256)
        upsampled_embedding (Tensor): Upsampled embeddings (B, 256, 256, 256)
        alpha (float): Weight of contrastive center loss
        loss_fn (nn.Module): Cross-entropy loss function
        contrastive_center_loss (ContrastiveCenterLoss): Instance of contrastive center loss function
        ce_weight (float): Weight of cross-entropy loss
        center_weight (float): Weight of contrastive center loss

    Returns:
        Tensor: Total loss
        Tensor: Cross-entropy loss value
        Tensor: Contrastive center loss value
    """
    # Compute cross-entropy loss
    loss_ce = loss_fn(class_logits, masks.long())
    
    # Compute contrastive center loss
    loss_cent = contrastive_center_loss(upsampled_embedding.view(-1, 256), masks.view(-1)) * alpha
    
    # Total loss
    total_loss = ce_weight * loss_ce + center_weight * loss_cent
    
    return total_loss, loss_ce.item(), loss_cent.item()

# Configure logger
logging.basicConfig(filename='logs/best_model_ce_cocenter_lora.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Define cross-entropy loss function
loss_fn = nn.CrossEntropyLoss()

# Define a custom model with only one convolutional layer
class FeatureMapper(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FeatureMapper, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    
    def forward(self, x):
        x = self.conv1(x)
        return x
    
# Instantiate contrastive center loss and the custom model
contrastive_center_loss = ContrastiveCenterLoss(num_classes=2, feat_dim=256, use_gpu=torch.cuda.is_available())
model = FeatureMapper(in_channels=256, out_channels=2)
model.to(device)

# Freeze all SAM model parameters, unfreeze only LoRA layers and custom convolution layer
for param in lora_sam_model.sam.parameters():
    param.requires_grad = False

for layer in lora_sam_model.A_weights + lora_sam_model.B_weights:
    for param in layer.parameters():
        param.requires_grad = True

for param in model.parameters():
    param.requires_grad = True
    
# Optimizer
optimizer = torch.optim.Adam(
    list(filter(lambda p: p.requires_grad, lora_sam_model.parameters())) + list(model.parameters()),
    lr=1e-4,
    weight_decay=1e-4
)

# Define a separate optimizer for contrastive center loss
optimizer_centloss = torch.optim.Adam(contrastive_center_loss.parameters(), lr=0.5)

# Training parameters
num_epochs = 100
best_val_loss = float('inf')
best_epoch = 0
checkpoint_path = 'logs/best_model_ce_lora.pth'  # Replace with actual path
alpha = 0.5 # Weight for center loss

# Define Warmup + custom cosine annealing learning rate scheduler
warmup_epochs = 10  # Number of epochs for warmup
min_lr_factor = 0.01  # Minimum learning rate is 1% of the maximum

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return float(epoch / warmup_epochs)
    else:
        cosine_decay = 0.5 * (1 + torch.cos(torch.tensor(epoch - warmup_epochs) * torch.pi / (num_epochs - warmup_epochs)))
        return float(min_lr_factor + (1 - min_lr_factor) * cosine_decay)
    
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
scheduler_centloss = lr_scheduler.LambdaLR(optimizer_centloss, lr_lambda=lr_lambda)

# Training loop
for epoch in range(num_epochs):
    try:
        # Set model to training mode
        lora_sam_model.train()
        model.train()

        total_loss = 0  # Accumulate total loss for each batch
        total_loss_ce = 0  # Accumulate cross-entropy loss for each batch
        total_loss_cent = 0  # Accumulate contrastive center loss for each batch
        num_batches = 0

        # Training phase
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Train]"):
            images, masks = images.to(device), masks.to(device)

            # Forward pass: Get image embeddings
            image_embedding = lora_sam_model.sam.image_encoder(images)  # B, 256, 64, 64

            # Upsample to (B, 256, 256, 256)
            upsampled_embedding = F.interpolate(image_embedding, size=(256, 256), mode='bilinear', align_corners=False)

            # Process embeddings using the custom model
            class_logits = model(upsampled_embedding)  # B, num_classes, 256, 256

            # Compute total loss, including cross-entropy loss and contrastive center loss
            loss, loss_ce, loss_cent = compute_loss(
                class_logits, masks, upsampled_embedding, alpha, loss_fn, contrastive_center_loss
            )

            # Backpropagation and optimization
            optimizer.zero_grad()
            optimizer_centloss.zero_grad()
            loss.backward()

            # To eliminate alpha's influence on center point updates, multiply by (1./alpha)
            for param in contrastive_center_loss.parameters():
                if param.grad is not None:
                    param.grad.data *= (1. / alpha)

            optimizer.step()
            optimizer_centloss.step()

            # Accumulate losses
            total_loss += loss.item()
            total_loss_ce += loss_ce
            total_loss_cent += loss_cent
            num_batches += 1

        # Update learning rate scheduler
        scheduler.step()
        scheduler_centloss.step()

        # Calculate average loss
        avg_train_loss = total_loss / num_batches
        avg_train_loss_ce = total_loss_ce / num_batches
        avg_train_loss_cent = total_loss_cent / num_batches

        # Set model to evaluation mode
        lora_sam_model.eval()
        model.eval()

        val_loss = 0
        val_loss_ce = 0
        val_loss_cent = 0
        num_val_batches = 0

        with torch.no_grad():  # Disable gradient calculation
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} [Validation]"):
                images, masks = images.to(device), masks.to(device)

                # Forward pass: Get image embeddings
                image_embedding = lora_sam_model.sam.image_encoder(images)  # B, 256, 64, 64

                # Upsample to (B, 256, 256, 256)
                upsampled_embedding = F.interpolate(image_embedding, size=(256, 256), mode='bilinear', align_corners=False)

                # Process embeddings using the custom model
                class_logits = model(upsampled_embedding)  # B, num_classes, 256, 256

                # Compute total loss, including cross-entropy loss and contrastive center loss
                loss, loss_ce, loss_cent = compute_loss(
                    class_logits, masks, upsampled_embedding, alpha, loss_fn, contrastive_center_loss
                )

                # Accumulate losses
                val_loss += loss.item()
                val_loss_ce += loss_ce
                val_loss_cent += loss_cent
                num_val_batches += 1

        avg_val_loss = val_loss / num_val_batches
        avg_val_loss_ce = val_loss_ce / num_val_batches
        avg_val_loss_cent = val_loss_cent / num_val_batches

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        current_lr_centloss = optimizer_centloss.param_groups[0]['lr']

        # Output and log training/validation losses and current learning rates
        logging.info(f"Epoch [{epoch + 1}/{num_epochs}], Learning Rate: {current_lr:.6f}, Center Loss Learning Rate: {current_lr_centloss:.6f}, "
                     f"Average Train Loss: {avg_train_loss:.4f}, Average Val Loss: {avg_val_loss:.4f}, "
                     f"Train CE Loss: {avg_train_loss_ce:.4f}, Train Center Loss: {avg_train_loss_cent:.4f}, "
                     f"Val CE Loss: {avg_val_loss_ce:.4f}, Val Center Loss: {avg_val_loss_cent:.4f}")

        print(f"Epoch [{epoch + 1}/{num_epochs}], Learning Rate: {current_lr:.6f}, Center Loss Learning Rate: {current_lr_centloss:.6f}, "
              f"Average Train Loss: {avg_train_loss:.4f}, Average Val Loss: {avg_val_loss:.4f}, "
              f"Train CE Loss: {avg_train_loss_ce:.4f}, Train Center Loss: {avg_train_loss_cent:.4f}, "
              f"Val CE Loss: {avg_val_loss_ce:.4f}, Val Center Loss: {avg_val_loss_cent:.4f}")


        # Save the model with the best validation performance
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            # Save LoRA and classifier weights
            lora_sam_model.save_lora_parameters(f'logs/best_lora_cocenter_rank{r}.safetensors')
            torch.save({
                'model_state_dict': model.state_dict(),
            }, checkpoint_path)
            logging.info(f"Best model saved at epoch {best_epoch} with val loss {best_val_loss:.4f}")
            print(f"Best model saved at epoch {best_epoch} with val loss {best_val_loss:.4f}")

    except Exception as e:
        logging.error(f"Exception occurred during epoch {epoch + 1}: {str(e)}")
        print(f"Exception occurred during epoch {epoch + 1}: {str(e)}")
        lora_sam_model.save_lora_parameters(f'logs/error_lora_epoch_{epoch + 1}.safetensors')
        torch.save({
            'model_state_dict': model.state_dict(),
        }, f'logs/error_model_epoch_{epoch + 1}.pth')
        break

logging.info("Training completed")
print("Training completed")


Epoch 1/100 [Train]:   0%|          | 0/95 [00:00<?, ?it/s]

Exception occurred during epoch 1: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 23.63 GiB of which 119.94 MiB is free. Process 111447 has 17.12 GiB memory in use. Including non-PyTorch memory, this process has 6.39 GiB memory in use. Of the allocated memory 5.67 GiB is allocated by PyTorch, and 280.54 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Training completed



