In [None]:
# Fine-tune with Cityscapes

In [26]:
import numpy as np
import torch
import cv2
import os
from torch.utils.data import Dataset, DataLoader
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import time

# Dataset class for Cityscapes
class CityscapesDataset(Dataset):
    """Cityscapes images paired with per-class binary masks and point prompts.

    The constructor walks ``image_dir/<city>/*_leftImg8bit.png`` and keeps every
    image for which ``label_dir/<city>/*_gtFine_labelIds.png`` exists.

    Each item is a dict with the keys:
        'image':        RGB image, rescaled so that both sides fit in 1024 px.
        'masks':        one binary (uint8) mask per label id present in the
                        resized label map, label id 0 excluded.
        'input_points': one randomly chosen (x, y) pixel inside each mask,
                        stored as one single-point prompt per mask.
        'input_labels': ones, one per prompt.

    Args:
        image_dir: Directory containing one sub-directory per city with images.
        label_dir: Directory containing one sub-directory per city with labels.
        transform: Optional callable applied to the sample dict before return.
    """

    IMAGE_SUFFIX = '_leftImg8bit.png'
    LABEL_SUFFIX = '_gtFine_labelIds.png'

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

        # Image and label paths are always appended together, so the two lists
        # stay index-aligned by construction.
        self.image_files = []
        self.label_files = []

        # Sorted listings make the index -> sample mapping reproducible
        # regardless of filesystem enumeration order.
        for city in sorted(os.listdir(image_dir)):
            city_image_dir = os.path.join(image_dir, city)
            if not os.path.isdir(city_image_dir):
                # Stray files next to the city folders (README, .DS_Store, ...)
                # would otherwise make os.listdir() raise.
                continue
            city_label_dir = os.path.join(label_dir, city)

            for file_name in sorted(os.listdir(city_image_dir)):
                if not file_name.endswith(self.IMAGE_SUFFIX):
                    continue

                # Swap only the trailing suffix; str.replace would also rewrite
                # an occurrence elsewhere in the name.
                stem = file_name[:-len(self.IMAGE_SUFFIX)]
                label_path = os.path.join(city_label_dir, stem + self.LABEL_SUFFIX)

                # Images without an annotation file are silently skipped.
                if os.path.exists(label_path):
                    self.image_files.append(os.path.join(city_image_dir, file_name))
                    self.label_files.append(label_path)

    def __len__(self):
        """Return the number of (image, label) pairs found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and prompt-annotate the sample at position ``idx``.

        Raises:
            OSError: If the image or the label file cannot be decoded.
        """
        # Load image; cv2.imread signals failure by returning None.
        img_path = self.image_files[idx]
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise OSError(f"Failed to read image: {img_path}")
        image = bgr[..., ::-1]  # Convert BGR to RGB

        # Load label as a single-channel image of label ids
        label_path = self.label_files[idx]
        label = cv2.imread(label_path, 0)
        if label is None:
            raise OSError(f"Failed to read label: {label_path}")

        # Resize image and label with one common scale factor so that both
        # sides fit within 1024 px (aspect ratio preserved).
        r = min(1024 / image.shape[1], 1024 / image.shape[0])
        new_size = (int(image.shape[1] * r), int(image.shape[0] * r))
        image = cv2.resize(image, new_size)
        # Nearest-neighbour keeps label ids intact (no blended values).
        label = cv2.resize(label, new_size, interpolation=cv2.INTER_NEAREST)

        # Generate one mask and one point prompt per class present
        unique_classes = np.unique(label)
        unique_classes = unique_classes[unique_classes != 0]  # Exclude background
        input_points = []
        masks = []
        for cls in unique_classes:
            mask = (label == cls).astype(np.uint8)
            # cls comes from np.unique(label), so the mask has at least one
            # foreground pixel; mask and point are appended together to keep
            # them index-aligned.
            coords = np.argwhere(mask > 0)
            yx = coords[np.random.randint(len(coords))]
            masks.append(mask)
            input_points.append([[yx[1], yx[0]]])  # x, y format

        input_points = np.array(input_points)
        input_labels = np.ones((len(input_points), 1))

        sample = {
            'image': image,
            'masks': np.array(masks),
            'input_points': input_points,
            'input_labels': input_labels
        }

        if self.transform:
            sample = self.transform(sample)

        return sample

# Define paths to Cityscapes dataset
image_dir = "segment-anything-2/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train"
label_dir = "segment-anything-2/cityscapes/gtFine_trainvaltest/gtFine/train"

# Create dataset and dataloader.
# batch_size=1 because every image yields a different number of masks/prompts,
# which the default collate function could not stack across samples.
dataset = CityscapesDataset(image_dir, label_dir)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Load SAM2 model
sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_small.pt"  # Path to SAM2 model checkpoint
model_cfg = "sam2_hiera_s.yaml"          # Path to SAM2 model config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
# Manually move the model to the device
sam2_model.to(device)
predictor = SAM2ImagePredictor(sam2_model)

# Set model to training mode
predictor.model.train()

# Define optimizer and scaler.
# Mixed precision is only enabled when the run is actually on CUDA; on the CPU
# fallback both autocast and the scaler become pass-throughs.
use_amp = device.type == 'cuda'
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Timing the fine-tuning process
start_time = time.time()

# Training loop
num_epochs = 10  # Adjust as needed
# Exponential moving average of the per-step IoU (starts at 0, so early values
# under-report the true IoU).
mean_iou = 0

for epoch in range(num_epochs):
    for batch_idx, sample in enumerate(dataloader):
        # The dataloader adds a batch dimension of size 1; strip it and hand
        # plain numpy arrays to the predictor.
        image = sample['image'][0].numpy()
        masks = sample['masks'][0].numpy()
        input_points = sample['input_points'][0].numpy()
        input_labels = sample['input_labels'][0].numpy()

        if masks.shape[0] == 0:
            continue  # Skip empty batches

        # Only the forward pass and the loss run under autocast.
        with torch.cuda.amp.autocast(enabled=use_amp):
            # NOTE(review): upstream SAM 2 appears to run set_image under
            # torch.no_grad(); if so, the image encoder gets no gradient here
            # and only prompt encoder / mask decoder are tuned - confirm
            # against the installed version.
            predictor.set_image(image)

            # Prepare prompts
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_points, input_labels, box=None, mask_logits=None, normalize_coords=True)

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None)

            # Mask decoder: with more than one prompt the single image
            # embedding has to be repeated per prompt (repeat_image).
            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )
            # Upscale the low-resolution logits to the input image resolution.
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Compute segmentation loss: pixel-wise binary cross-entropy on the
            # first of the multimask outputs (the 1e-5 terms guard log(0)).
            gt_mask = torch.tensor(masks.astype(np.float32)).to(device)
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log(1 - prd_mask + 1e-5)).mean()

            # Compute IOU of the thresholded prediction and train the
            # predicted score to match it (L1), weighted by 0.05.
            inter = (gt_mask * (prd_mask > 0.5)).sum(dim=[1, 2])
            union = gt_mask.sum(dim=[1, 2]) + (prd_mask > 0.5).sum(dim=[1, 2]) - inter
            iou = inter / (union + 1e-5)
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss = seg_loss + score_loss * 0.05

        # Backpropagation, outside the autocast region as recommended for AMP.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update mean IOU
        mean_iou = mean_iou * 0.99 + 0.01 * iou.mean().item()

        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}], Loss: {loss.item():.4f}, Mean IOU: {mean_iou:.4f}')

    # Save model checkpoint after each epoch
    torch.save(predictor.model.state_dict(), f'sam2_cityscapes_epoch{epoch+1}.pth')

# Save final model
torch.save(predictor.model.state_dict(), 'sam2_cityscapes_final.pth')

# Calculate and print total time taken for fine-tuning
end_time = time.time()
elapsed_time = end_time - start_time
print(f'Total time taken for fine-tuning: {elapsed_time:.2f} seconds')


Epoch [1/10], Batch [0], Loss: 0.3764, Mean IOU: 0.0033
Epoch [1/10], Batch [10], Loss: 0.4723, Mean IOU: 0.0236
Epoch [1/10], Batch [20], Loss: 0.3620, Mean IOU: 0.0456
Epoch [1/10], Batch [30], Loss: 0.3803, Mean IOU: 0.0648
Epoch [1/10], Batch [40], Loss: 0.5344, Mean IOU: 0.0846
Epoch [1/10], Batch [50], Loss: 0.2230, Mean IOU: 0.1022
Epoch [1/10], Batch [60], Loss: 0.4223, Mean IOU: 0.1199
Epoch [1/10], Batch [70], Loss: 0.3065, Mean IOU: 0.1388
Epoch [1/10], Batch [80], Loss: 0.2776, Mean IOU: 0.1529
Epoch [1/10], Batch [90], Loss: 0.2373, Mean IOU: 0.1664
Epoch [1/10], Batch [100], Loss: 0.2437, Mean IOU: 0.1804
Epoch [1/10], Batch [110], Loss: 0.1825, Mean IOU: 0.1906
Epoch [1/10], Batch [120], Loss: 0.1596, Mean IOU: 0.2025
Epoch [1/10], Batch [130], Loss: 0.2200, Mean IOU: 0.2122
Epoch [1/10], Batch [140], Loss: 0.1588, Mean IOU: 0.2244
Epoch [1/10], Batch [150], Loss: 0.1520, Mean IOU: 0.2317
Epoch [1/10], Batch [160], Loss: 0.1573, Mean IOU: 0.2411
Epoch [1/10], Batch [170]