In [24]:
import torch
import torch.nn as nn
import torch.quantization
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# ------------------------------
# Step 1: Define Calibration Transforms
# ------------------------------
# Spatial size every calibration image is resized to, so that images of
# different sizes can be stacked into one batch by the DataLoader.
INPUT_HEIGHT = 224  # Replace with your model's input height
INPUT_WIDTH = 224   # Replace with your model's input width

# ImageNet channel statistics. They stay defined for anything that still
# refers to them, but they are intentionally NOT applied in the pipeline
# below: the calibration tensors are later passed to
# SAM2ImagePredictor.set_image, which performs its own resize and mean/std
# normalization. Normalizing here as well would normalize the data twice and
# calibrate the observers on activations the deployed model never sees.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]  # Replace with your model's mean
NORMALIZE_STD = [0.229, 0.224, 0.225]   # Replace with your model's std

# Resize + convert to a float CHW tensor with values in [0, 1].
calibration_transforms = transforms.Compose([
    transforms.Resize((INPUT_HEIGHT, INPUT_WIDTH)),
    transforms.ToTensor(),
])

# ------------------------------
# Step 2: Define Custom Dataset (Image Only)
# ------------------------------
class ImageOnlyDataset(Dataset):
    """Label-free dataset over every image file below a directory tree.

    Intended for calibration, where only inputs are needed. Files are
    discovered recursively and stored in sorted order so that a given index
    always maps to the same file across runs and machines.

    Args:
        image_dir: Root directory that is searched recursively for images.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = self._get_all_image_files(image_dir)

    def _get_all_image_files(self, directory):
        """Return the sorted paths of all image files under *directory*.

        A file counts as an image when its lower-cased name ends with one of
        the known image extensions.
        """
        image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
        image_files = []
        for root, _, files in os.walk(directory):
            for fname in files:
                if fname.lower().endswith(image_extensions):
                    image_files.append(os.path.join(root, fname))
        # os.walk yields entries in arbitrary, filesystem-dependent order;
        # sorting makes the index -> file mapping reproducible.
        image_files.sort()
        return image_files

    def __len__(self):
        """Return the number of image files that were discovered."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position *idx*.

        Returns:
            The transformed image (or the RGB PIL image when no transform was
            given). If loading or transforming fails, the error is printed
            and an all-zero tensor of shape (3, INPUT_HEIGHT, INPUT_WIDTH) is
            returned instead, so one bad file does not abort calibration.
        """
        image_path = self.image_files[idx]
        try:
            # The context manager closes the underlying file handle right
            # away; convert() returns an independent in-memory RGB copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')  # Ensure 3-channel images
            if self.transform:
                image = self.transform(image)
            return image  # No label needed for calibration
        except Exception as e:
            # Deliberate best-effort fallback: report and substitute zeros.
            print(f"Error loading image {image_path}: {e}")
            return torch.zeros(3, INPUT_HEIGHT, INPUT_WIDTH)

# ------------------------------
# Step 3: Initialize Calibration DataLoader
# ------------------------------
# Root of the image tree used for calibration; it is searched recursively.
image_only_val_dir = 'segment-anything-2/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val'

if not os.path.isdir(image_only_val_dir):
    raise RuntimeError(f"Image directory not found: {image_only_val_dir}")

calibration_dataset = ImageOnlyDataset(
    image_dir=image_only_val_dir,
    transform=calibration_transforms
)

print(f"Found {len(calibration_dataset)} images for calibration.")

# Fail early with a clear message: with shuffle=True an empty dataset would
# otherwise only surface as an obscure sampler error.
if len(calibration_dataset) == 0:
    raise RuntimeError(f"No images found under: {image_only_val_dir}")

calibration_dataloader = DataLoader(
    calibration_dataset,
    batch_size=32,          # Adjust based on your memory constraints
    shuffle=True,           # Shuffle to ensure randomness
    num_workers=0,          # Start with 0 for debugging
    pin_memory=False        # Disable pin_memory for debugging
)

# ------------------------------
# Step 4: Build SAM2ImagePredictor and Fix Checkpoint Loading
# ------------------------------
sam2_checkpoint = "sam2_cityscapes_final.pth"  # Replace with your checkpoint path
model_cfg = "sam2_hiera_s.yaml"  # Replace with your model config

# Determine the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the SAM2 model without weights (checkpoint path is None) and move it
# to the correct device; the weights are loaded explicitly below.
sam2_model = build_sam2(model_cfg, None, device=device)
sam2_model = sam2_model.to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code while loading.
checkpoint = torch.load(sam2_checkpoint, map_location=device, weights_only=True)

# Accept both layouts: a bare state dict, or a dict that wraps the state dict
# under a 'model' key.
if isinstance(checkpoint, dict) and isinstance(checkpoint.get("model"), dict):
    state_dict = checkpoint["model"]
else:
    state_dict = checkpoint
sam2_model.load_state_dict(state_dict)

# Now create the image predictor (no need to move this to the device)
predictor = SAM2ImagePredictor(sam2_model)

# ------------------------------
# Step 5: Define the SAM2Wrapper Class
# ------------------------------
class SAM2Wrapper(nn.Module):
    """Run a point-prompted SAM2 predictor over a batch of image tensors.

    Each image is embedded with ``predictor.set_image`` and segmented with
    one foreground point prompt placed at the image center. The mask with
    the highest predicted score is kept for every image.

    Args:
        predictor: Object exposing ``set_image(image)`` and
            ``predict(point_coords, point_labels, multimask_output)``.
    """

    def __init__(self, predictor):
        super().__init__()
        # NOTE(review): the predictor appears to be a plain Python object
        # rather than an nn.Module. If so, this assignment does not register
        # the underlying SAM2 network as a submodule, and modules(),
        # state_dict() and quantization passes applied to this wrapper will
        # not reach it. Confirm before relying on the wrapper for that.
        self.predictor = predictor

    def forward(self, x):
        """Segment every image in *x* and return the best mask per image.

        Args:
            x: Batch of images; indexed as (batch, channel, height, width).

        Returns:
            Float tensor on ``x``'s device with one single-channel mask per
            input image, stacked along dim 0. An empty batch yields an empty
            tensor of shape (0, 1, height, width).
        """
        height, width = x.shape[-2:]
        outputs = []
        for img in x:
            # Set image and generate embedding (predictor expects HWC numpy)
            self.predictor.set_image(img.detach().permute(1, 2, 0).cpu().numpy())

            # Foreground point prompt in (x, y) pixel coordinates. It is
            # derived from the actual image size so that it always lies
            # inside the image, whatever the input resolution.
            input_point = np.array([[width // 2, height // 2]])
            input_label = np.array([1])  # Label for foreground point

            masks, scores, _ = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True
            )

            # Keep the candidate the predictor scored highest instead of
            # whichever happens to come first.
            best_index = int(np.argmax(scores))
            mask = torch.as_tensor(
                masks[best_index], dtype=torch.float32, device=x.device
            )
            outputs.append(mask.unsqueeze(0))

        if not outputs:
            # torch.stack rejects an empty list; return an empty batch.
            return torch.zeros(
                (0, 1, height, width), dtype=torch.float32, device=x.device
            )
        return torch.stack(outputs)

# ------------------------------
# Step 6: Wrap the Predictor for Calibration
# ------------------------------
model = SAM2Wrapper(predictor).to(device)

# Post-training static quantization is calibrated in inference mode.
model.eval()

# Sanity check: prepare()/convert() only act on registered submodules. If the
# wrapper exposes none, every step below runs without touching the network
# and the saved state dict will be empty, so make that visible up front.
if next(model.children(), None) is None:
    print(
        "Warning: the wrapped model exposes no registered submodules; "
        "quantization will not modify any layers."
    )

# ------------------------------
# Step 7: Custom Quantization Configuration for ConvTranspose2d Layers
# ------------------------------

# Define a per-tensor observer for ConvTranspose2d layers
per_tensor_weight_observer = torch.quantization.MinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
)

# Define a custom qconfig for ConvTranspose2d layers
convtranspose_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.default_observer,
    weight=per_tensor_weight_observer
)

# Apply the custom qconfig to ConvTranspose2d layers only
for module in model.modules():
    if isinstance(module, torch.nn.ConvTranspose2d):
        module.qconfig = convtranspose_qconfig

# Apply the default qconfig to other layers; modules that already carry
# their own qconfig (set above) keep it.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Prepare the model for quantization (inserts observers in place)
torch.quantization.prepare(model, inplace=True)

# ------------------------------
# Step 8: Calibration Loop
# ------------------------------
# Push calibration batches through the prepared model with gradients
# disabled. Batches with index 0 through 10 are processed (11 in total, or
# fewer if the loader runs out first); the progress line is printed only
# once the batch with index 10 has been processed.
with torch.no_grad():
    for i, images in enumerate(calibration_dataloader):
        images = images.to(device)
        model(images)  # Forward pass only; the outputs are discarded
        if i < 10:
            continue
        print(f"Calibrated on batch {i}")
        break

# ------------------------------
# Step 9: Convert to Quantized Model
# ------------------------------
# Replace observed modules with their quantized counterparts, in place.
torch.quantization.convert(model, inplace=True)

# ------------------------------
# Step 10: Save the Optimized Model
# ------------------------------
optimized_model_path = 'sam2_cityscapes_optimized.pth'
torch.save(model.state_dict(), optimized_model_path)

# Only quantization is performed in this script; no pruning step exists, so
# the message must not claim one.
print(f"Quantization completed successfully. Optimized model saved as '{optimized_model_path}'.")


Found 500 images for calibration.


  x = F.scaled_dot_product_attention(
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
  out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
Falling back to all available kernels for scaled_dot_product_attention (which may have a slower speed).
  return forward_call(*args, **kwargs)


Calibrated on batch 10
Quantization and pruning completed successfully. Optimized model saved as 'sam2_cityscapes_optimized.pth'.
