# In [None]:  (Jupyter notebook cell marker left over from export)
import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
import json

# Read data
data_dir = r"MapillaryVistasV2//"  # Path to the Mapillary Vistas dataset
version = "v2.0"  # Change to "v1.2" if using the older version

# Load the JSON configuration that contains the label information.
# The encoding is given explicitly so parsing does not depend on the
# platform's default text encoding.
with open(f'config_{version}.json', encoding="utf-8") as config_file:
    config = json.load(config_file)
labels = config['labels']

# Build the dataset list: one {"image", "annotation"} path pair per training image.
data = []
image_dir = os.path.join(data_dir, "training/images/")
label_dir = os.path.join(data_dir, f"training/{version}/labels/")
# Sorted so the list order does not depend on the file system's listing order.
for image_name in sorted(os.listdir(image_dir)):
    stem, ext = os.path.splitext(image_name)
    # Skip anything that is not a JPEG (hidden files, thumbnails, sub-folders);
    # such entries would otherwise fail later when they are opened as images.
    if ext.lower() != ".jpg":
        continue
    image_path = os.path.join(image_dir, image_name)
    # Swap only the extension; a plain str.replace would also rewrite ".jpg"
    # occurring elsewhere in the file name and would miss ".JPG".
    label_path = os.path.join(label_dir, stem + ".png")
    data.append({"image": image_path, "annotation": label_path})

def read_batch(data, ignore_index=0, max_side=1024):
    """Load one random sample and build a point prompt for every label in it.

    A random entry of ``data`` is read, the image and its annotation map are
    scaled by a common factor so that the longer image side becomes (about)
    ``max_side`` pixels, and for every label id found in the resized
    annotation map a binary mask and one random pixel inside it are produced.

    Args:
        data: Sequence of dicts holding the file paths under the keys
            "image" and "annotation".
        ignore_index: Label id that is skipped (treated as background).
            Pass ``None`` to keep every id.
            NOTE(review): 0 is kept as the default from the original code;
            confirm against the dataset config that 0 really is the
            background/unlabeled id.
        max_side: Target length in pixels of the longer image side.

    Returns:
        Tuple ``(image, masks, points, point_labels)``:
        * image: resized RGB image array of shape (H, W, 3).
        * masks: uint8 array of shape (N, H, W), one binary mask per label.
        * points: array of shape (N, 1, 2) holding one (x, y) pixel per mask.
        * point_labels: array of ones with shape (N, 1).
        When no label remains after filtering, ``masks`` and ``points`` are
        empty arrays (``shape[0] == 0``).

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("read_batch: `data` is empty")

    # Select a random entry.
    ent = data[np.random.randint(len(data))]

    # convert("RGB") drops an alpha channel and also expands grayscale or
    # palette images to three channels; slicing with [..., :3] would cut
    # columns of a 2-D grayscale array instead.
    with Image.open(ent["image"]) as image_file:
        Img = np.array(image_file.convert("RGB"))
    with Image.open(ent["annotation"]) as annotation_file:
        ann_map = np.array(annotation_file)

    # Resize both arrays to the same target size so masks align with the image.
    r = min(max_side / Img.shape[1], max_side / Img.shape[0])  # Scaling factor
    new_size = (int(Img.shape[1] * r), int(Img.shape[0] * r))  # (width, height)
    Img = cv2.resize(Img, new_size)
    # Nearest-neighbour keeps label ids intact (no blended in-between values).
    ann_map = cv2.resize(ann_map, new_size, interpolation=cv2.INTER_NEAREST)

    # Collect label ids. Filter by value rather than dropping the first sorted
    # entry, which would discard a real class whenever the ignored id is absent.
    inds = np.unique(ann_map)
    if ignore_index is not None:
        inds = inds[inds != ignore_index]

    points = []
    masks = []
    for ind in inds:
        mask = (ann_map == ind).astype(np.uint8)  # Binary mask for label id `ind`
        masks.append(mask)
        coords = np.argwhere(mask > 0)  # (row, col) of every pixel in the mask
        yx = np.array(coords[np.random.randint(len(coords))])  # Random pixel in the mask
        points.append([[yx[1], yx[0]]])  # Stored as (x, y)
    return Img, np.array(masks), np.array(points), np.ones([len(masks), 1])

# Build the SAM2 model and wrap it in an image predictor.
# Weights can be downloaded from:
# https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt
sam2_checkpoint = "sam2_hiera_small.pt"  # Path to the model weights
model_cfg = "sam2_hiera_s.yaml"  # Model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Put the mask decoder and the prompt encoder into training mode.
for _submodule in (
    predictor.model.sam_mask_decoder,
    predictor.model.sam_prompt_encoder,
):
    _submodule.train()

# AdamW over all model parameters, plus a gradient scaler for mixed precision.
optimizer = torch.optim.AdamW(
    params=predictor.model.parameters(),
    lr=1e-5,
    weight_decay=4e-5,
)
scaler = torch.cuda.amp.GradScaler()

# Training loop
# One random image per step: encode the image, prompt the model with one point
# per mask, and train on a segmentation loss plus a score (IoU) loss.
mean_iou = 0  # Running IoU average; set before the loop so a skipped first batch cannot leave it undefined
for itr in range(100000):
    image, mask, input_point, input_label = read_batch(data)  # Load data batch
    if mask.shape[0] == 0:
        continue  # Ignore empty batches

    # Only the forward pass and the loss computation run under autocast;
    # the backward pass and optimizer step are done outside of it.
    with torch.cuda.amp.autocast():  # Cast to mixed precision
        predictor.set_image(image)  # Apply SAM image encoder to the image

        # Prompt encoding
        # `point_labels` (not `labels`) so the module-level `labels` loaded
        # from the config file is not overwritten.
        mask_input, unnorm_coords, point_labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, point_labels), boxes=None, masks=None)

        # Mask decoder
        batched_mode = unnorm_coords.shape[0] > 1  # Multi-object prediction
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        # Upscale the masks to the original image resolution
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # Segmentation loss: per-pixel binary cross entropy on the first predicted mask
        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
        # The small constant keeps log() away from zero.
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

        # Score loss: the predicted score should match the actual IoU of the
        # thresholded prediction against the ground-truth mask.
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05  # Mix losses

    # Apply backpropagation
    predictor.model.zero_grad()  # Empty gradient
    scaler.scale(loss).backward()  # Backpropagate the scaled loss
    scaler.step(optimizer)
    scaler.update()  # Update the loss scale for the next step

    if itr % 1000 == 0:
        torch.save(predictor.model.state_dict(), "model.torch")
        print("Saved model at step:", itr)

    # Display results (exponential moving average of the batch IoU)
    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
    print("Step:", itr, "Accuracy (IOU):", mean_iou)