In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from PIL import Image
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# This cell avoids calling numpy directly; tensors are built with torch ops.
# NOTE(review): cv2 and matplotlib (imported above) and torchvision's
# ToTensor may still load NumPy transitively, so the process is probably not
# truly NumPy-free -- verify if the goal is to work around a broken NumPy.
print("NumPy-free training script - avoiding all numpy operations")
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")


class iSAIDDataset(Dataset):
    """iSAID instance-segmentation dataset for Mask R-CNN training.

    Targets are built from the instance-id and semantic-colour mask PNGs with
    torch operations (no direct numpy calls).  Each item is
    ``(image_tensor, target)`` where ``target`` has the keys torchvision's
    Mask R-CNN consumes: ``boxes``, ``labels``, ``masks``, ``image_id``,
    ``area`` and ``iscrowd``.
    """

    # iSAID semantic palette, looked up as rgb_to_class[R][G][B] -> class index
    # (indices follow ``self.classes``).  Shared by all instances: read-only.
    rgb_to_class = {
        0: {
            0: {0: 0, 63: 1, 127: 9, 191: 10, 255: 11},
            63: {0: 3, 63: 2, 127: 4, 191: 5, 255: 6},
            100: {155: 15},
            127: {63: 7, 127: 8, 191: 13, 255: 14},
            191: {127: 12},
        },
    }

    # Instances with fewer pixels than MIN_INSTANCE_PIXELS, or whose box has a
    # side shorter than MIN_BOX_SIDE pixels, are dropped.
    MIN_INSTANCE_PIXELS = 100
    MIN_BOX_SIDE = 5

    def __init__(self, root_dir, transforms=None, images_dir=None):
        """
        Args:
            root_dir: Split directory containing ``Instance_masks/images`` and
                ``Semantic_masks/images``.
            transforms: Optional callable applied to the PIL RGB image; when
                None, ``transforms.ToTensor()`` is used.
            images_dir: Optional directory with the real RGB images.  When set,
                ``<images_dir>/<base>.png`` is loaded, where ``<base>`` is the
                mask filename minus ``_instance_id_RGB.png``
                (NOTE(review): confirm this matches your iSAID layout).  When
                None or the file is missing, a flat gray placeholder is used,
                which carries no visual signal for training.

        Raises:
            ValueError: If either mask directory does not exist.
        """
        self.root_dir = Path(root_dir)
        self.transforms = transforms
        self.images_dir = Path(images_dir) if images_dir is not None else None

        # Paths to the mask directories
        self.instance_masks_dir = self.root_dir / "Instance_masks" / "images"
        self.semantic_masks_dir = self.root_dir / "Semantic_masks" / "images"

        if not self.instance_masks_dir.exists():
            raise ValueError(f"Instance masks directory not found: {self.instance_masks_dir}")
        if not self.semantic_masks_dir.exists():
            raise ValueError(f"Semantic masks directory not found: {self.semantic_masks_dir}")

        # One sample per instance-mask PNG
        self.image_files = sorted(
            f for f in os.listdir(self.instance_masks_dir) if f.endswith('.png')
        )
        print(f"Found {len(self.image_files)} images")

        # iSAID class names (15 classes + background); list index == label value
        self.classes = [
            'background', 'ship', 'storage_tank', 'baseball_diamond',
            'tennis_court', 'basketball_court', 'ground_track_field',
            'bridge', 'large_vehicle', 'small_vehicle', 'helicopter',
            'swimming_pool', 'roundabout', 'soccer_ball_field',
            'plane', 'harbor'
        ]

        # Indices whose loading failed once; they are served as dummy samples.
        # NOTE(review): with DataLoader num_workers > 0 each worker keeps its own copy.
        self.failed_samples = set()

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for sample ``idx``.

        Any error while loading is printed, ``idx`` is remembered as failed,
        and an empty 800x800 dummy sample is returned instead.
        """
        # Indices that failed before are not retried.
        if idx in self.failed_samples:
            return self.get_dummy_sample(idx)

        try:
            img_filename = self.image_files[idx]

            instance_mask_path = self.instance_masks_dir / img_filename
            if not instance_mask_path.exists():
                raise FileNotFoundError(f"Instance mask not found: {instance_mask_path}")
            # Keep all three channels: the instance id is the RGB triple, and a
            # grayscale ('L') conversion merges distinct instances.
            with Image.open(instance_mask_path) as pil:
                instance_pil = pil.convert('RGB')

            semantic_filename = img_filename.replace('_instance_id_RGB.png', '_instance_color_RGB.png')
            semantic_mask_path = self.semantic_masks_dir / semantic_filename
            if not semantic_mask_path.exists():
                raise FileNotFoundError(f"Semantic mask not found: {semantic_mask_path}")
            with Image.open(semantic_mask_path) as pil:
                semantic_pil = pil.convert('RGB')

            # The input image must match the mask dimensions
            mask_width, mask_height = instance_pil.size
            img = self._load_image(img_filename, (mask_width, mask_height))

            to_tensor = transforms.ToTensor()
            instance_tensor = to_tensor(instance_pil)
            semantic_tensor = to_tensor(semantic_pil)

            boxes, labels, masks = self.extract_instances_torch(instance_tensor, semantic_tensor)

            if boxes:
                boxes = torch.stack(boxes)
                labels = torch.tensor(labels, dtype=torch.int64)
                masks = torch.stack(masks)
                area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
                iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
            else:
                # Empty-but-well-shaped targets for images without instances
                boxes = torch.zeros((0, 4), dtype=torch.float32)
                labels = torch.zeros((0,), dtype=torch.int64)
                masks = torch.zeros((0, mask_height, mask_width), dtype=torch.uint8)
                area = torch.zeros((0,), dtype=torch.float32)
                iscrowd = torch.zeros((0,), dtype=torch.int64)

            target = {
                "boxes": boxes,
                "labels": labels,
                "masks": masks,
                "image_id": torch.tensor([idx]),
                "area": area,
                "iscrowd": iscrowd
            }

            if self.transforms:
                img = self.transforms(img)
            else:
                img = transforms.ToTensor()(img)

            return img, target

        except Exception as e:
            print(f"Error loading sample {idx}: {str(e)}")
            self.failed_samples.add(idx)
            return self.get_dummy_sample(idx)

    def _load_image(self, img_filename, size):
        """Return the RGB PIL image for ``img_filename`` at ``size`` (W, H).

        Uses ``<images_dir>/<base>.png`` when ``images_dir`` was given and the
        file exists; otherwise a flat (128, 128, 128) gray placeholder.
        """
        if self.images_dir is not None:
            base_name = img_filename.replace('_instance_id_RGB.png', '')
            image_path = self.images_dir / f"{base_name}.png"
            if image_path.exists():
                with Image.open(image_path) as pil:
                    img = pil.convert('RGB')
                return img if img.size == size else img.resize(size)
        return Image.new('RGB', size, color=(128, 128, 128))

    def get_dummy_sample(self, idx):
        """Return a valid dummy sample"""
        dummy_img = torch.zeros((3, 800, 800), dtype=torch.float32)
        dummy_target = {
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
            "masks": torch.zeros((0, 800, 800), dtype=torch.uint8),
            "image_id": torch.tensor([idx]),
            "area": torch.zeros((0,), dtype=torch.float32),
            "iscrowd": torch.zeros((0,), dtype=torch.int64)
        }
        return dummy_img, dummy_target

    @staticmethod
    def _to_levels(tensor):
        """Return ``tensor`` as int64 0..255 levels.

        Floating tensors are assumed to come from ``transforms.ToTensor()``
        (uint8 values divided by 255) and are scaled back; integer tensors are
        only cast.
        """
        if tensor.is_floating_point():
            return (tensor * 255).round().to(torch.int64)
        return tensor.to(torch.int64)

    def extract_instances_torch(self, instance_tensor, semantic_tensor):
        """Split an instance-id mask into per-instance boxes, labels and masks.

        Args:
            instance_tensor: (3, H, W) RGB-encoded instance-id mask, or
                (1, H, W) / (H, W) single-channel id mask.  Value 0 (black)
                is background.
            semantic_tensor: (3, H, W) RGB semantic-colour mask.

        Returns:
            ``(boxes, labels, masks)`` lists: float32 ``[xmin, ymin, xmax, ymax]``
            tensors, int class indices, and (H, W) uint8 masks holding 0/1.
            Three empty lists if extraction fails.
        """
        try:
            levels = self._to_levels(instance_tensor)
            if levels.dim() == 3 and levels.size(0) >= 3:
                # Pack (R, G, B) into one integer so every colour is its own id.
                instance_ids = levels[0] * 65536 + levels[1] * 256 + levels[2]
            elif levels.dim() == 3:
                instance_ids = levels[0]
            else:
                instance_ids = levels
            # Convert once; get_class_torch re-casting int64 is a no-op.
            semantic_levels = self._to_levels(semantic_tensor)

            boxes, labels, masks = [], [], []
            for instance_id in torch.unique(instance_ids):
                if instance_id <= 0:
                    continue  # background

                binary_mask = instance_ids == instance_id
                if int(binary_mask.sum()) < self.MIN_INSTANCE_PIXELS:
                    continue

                ys, xs = torch.nonzero(binary_mask, as_tuple=True)
                xmin, xmax = xs.min().item(), xs.max().item()
                ymin, ymax = ys.min().item(), ys.max().item()
                # Also rejects degenerate boxes (xmax <= xmin or ymax <= ymin).
                if (xmax - xmin) < self.MIN_BOX_SIDE or (ymax - ymin) < self.MIN_BOX_SIDE:
                    continue

                boxes.append(torch.tensor([xmin, ymin, xmax, ymax], dtype=torch.float32))
                labels.append(self.get_class_torch(semantic_levels, binary_mask))
                # Mask R-CNN's mask loss is BCE-with-logits and needs 0/1
                # targets; a 0/255 encoding drives loss_mask to huge (even
                # negative) values.
                masks.append(binary_mask.to(torch.uint8))

            return boxes, labels, masks

        except Exception as e:
            print(f"Error extracting instances with torch: {str(e)}")
            return [], [], []

    def get_class_torch(self, semantic_tensor, instance_mask):
        """Return the iSAID class index for one instance.

        Takes a majority vote over the semantic-mask colours under
        ``instance_mask`` and returns the class of the most frequent colour
        that maps (via ``rgb_to_class``) to a non-background class.  Falls
        back to 1 when nothing matches or on error.

        Args:
            semantic_tensor: (3, H, W) RGB semantic mask, either ToTensor
                floats or integer 0..255 levels.
            instance_mask: (H, W) mask; pixels > 0 belong to the instance.
        """
        try:
            selected = instance_mask > 0
            if not selected.any():
                return 1

            pixels = self._to_levels(semantic_tensor)[:3, selected]  # (3, N)
            codes = pixels[0] * 65536 + pixels[1] * 256 + pixels[2]
            colours, counts = torch.unique(codes, return_counts=True)

            for i in torch.argsort(counts, descending=True).tolist():
                code = int(colours[i])
                r, g, b = code >> 16, (code >> 8) & 0xFF, code & 0xFF
                label = self.rgb_to_class.get(r, {}).get(g, {}).get(b, 0)
                if label > 0:
                    return label
            return 1

        except Exception:
            return 1


def get_model(num_classes):
    """Build a COCO-pretrained Mask R-CNN whose heads predict ``num_classes``.

    Both the box predictor and the mask predictor are swapped for fresh ones
    sized to ``num_classes`` (background included); the rest of the network
    keeps its pretrained weights.
    """
    model = maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    mask_in = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden = 256  # width of the new mask head's hidden layer
    heads.mask_predictor = MaskRCNNPredictor(mask_in, mask_hidden, num_classes)

    return model


def collate_fn(batch):
    """Transpose a batch of ``(image, target)`` pairs into parallel tuples.

    ``None`` entries are dropped; if nothing remains, ``([], [])`` is returned.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return [], []
    return tuple(zip(*kept))


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    """Run one training epoch and return the mean loss of the applied steps.

    Empty batches, batches whose summed loss is NaN/Inf, and batches that
    raise are skipped (errors are printed).  Gradients are clipped to a
    global norm of 1.0 before every optimizer step.  Every ``print_freq``-th
    batch index prints the total and per-component losses.
    """
    model.train()
    loss_sum = 0.0
    steps = 0

    progress = tqdm(data_loader, desc=f"Epoch {epoch}")
    for batch_idx, (images, targets) in enumerate(progress):
        try:
            if len(images) == 0:
                continue

            images = [img.to(device) for img in images]
            targets = [{key: val.to(device) for key, val in tgt.items()} for tgt in targets]

            loss_dict = model(images, targets)
            total = sum(loss for loss in loss_dict.values())

            if torch.isnan(total) or torch.isinf(total):
                print(f"Invalid loss detected, skipping batch {batch_idx}")
                continue

            optimizer.zero_grad()
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = total.item()
            loss_sum += batch_loss
            steps += 1

            if batch_idx % print_freq == 0:
                print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {batch_loss:.4f}")
                for name, value in loss_dict.items():
                    print(f"  {name}: {value.item():.4f}")

        except Exception as err:
            print(f"Error in batch {batch_idx}: {str(err)}")

    return loss_sum / max(steps, 1)


def evaluate(model, data_loader, device):
    """Return the mean validation loss over the batches of ``data_loader``.

    torchvision detection models return a loss dict only in training mode; in
    eval mode they return a list of predictions.  Calling ``.values()`` on that
    list made every batch fail ("'list' object has no attribute 'values'"), so
    the function always returned 0.0.  The model is therefore run in training
    mode under ``torch.no_grad()``.  BatchNorm layers are switched to eval mode
    so their running statistics are not updated.  The model is left in eval
    mode afterwards, as before.

    Batches that are empty, give a NaN/Inf loss, or raise are skipped.
    """
    total_loss = 0.0
    num_batches = 0

    model.train()
    batch_norm_types = (
        torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm,
    )
    for module in model.modules():
        if isinstance(module, batch_norm_types):
            module.eval()

    try:
        with torch.no_grad():
            for i, (images, targets) in enumerate(tqdm(data_loader, desc="Evaluating")):
                try:
                    if len(images) == 0:
                        continue

                    images = [image.to(device) for image in images]
                    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                    loss_dict = model(images, targets)
                    losses = sum(loss for loss in loss_dict.values())

                    if not (torch.isnan(losses) or torch.isinf(losses)):
                        total_loss += losses.item()
                        num_batches += 1

                except Exception as e:
                    print(f"Error in evaluation batch {i}: {str(e)}")
    finally:
        model.eval()

    return total_loss / max(num_batches, 1)


def main():
    """Train Mask R-CNN on the first (up to) 100 iSAID samples, 80/20 split.

    Each epoch saves a checkpoint to ``config['save_dir']``.  It holds the
    model, optimizer and LR-scheduler state plus that epoch's train and
    validation loss.  Any exception is printed with its traceback rather than
    re-raised.
    """
    print("Starting NumPy-free training...")

    # Configuration
    config = {
        # NOTE: machine-specific path; adjust before running elsewhere.
        'data_dir': '/Users/soumendusekharbhattacharjee/Downloads/iSAID_data/train',
        'batch_size': 1,
        'num_epochs': 3,
        'learning_rate': 0.001,
        'weight_decay': 0.0005,
        'num_workers': 0,
        'save_dir': './checkpoints',
        'num_classes': 16
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    try:
        print("Loading dataset...")
        dataset = iSAIDDataset(config['data_dir'], transforms=transform)

        if len(dataset) == 0:
            raise ValueError("No valid samples found in dataset")

        # Deterministic split of the first `total_size` samples (no shuffling)
        total_size = min(len(dataset), 100)
        train_size = int(0.8 * total_size)

        indices = list(range(total_size))
        train_indices = indices[:train_size]
        val_indices = indices[train_size:total_size]

        train_dataset = torch.utils.data.Subset(dataset, train_indices)
        val_dataset = torch.utils.data.Subset(dataset, val_indices)

        print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

        train_loader = DataLoader(
            train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            collate_fn=collate_fn
        )

        print("Loading model...")
        model = get_model(config['num_classes'])
        model.to(device)

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.SGD(params, lr=config['learning_rate'],
                              momentum=0.9, weight_decay=config['weight_decay'])

        # Halve the learning rate every 2 epochs
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

        print("Starting training...")

        train_losses = []
        val_losses = []

        for epoch in range(config['num_epochs']):
            print(f"\n=== Epoch {epoch+1}/{config['num_epochs']} ===")

            train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
            train_losses.append(train_loss)

            val_loss = evaluate(model, val_loader, device)
            val_losses.append(val_loss)

            lr_scheduler.step()

            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume training with the correct learning rate.
                'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
            }
            checkpoint_name = f"maskrcnn_epoch_{epoch}.pth"
            torch.save(checkpoint, os.path.join(config['save_dir'], checkpoint_name))
            print(f"Checkpoint saved: {checkpoint_name}")

        print("Training completed successfully!")
        print(f"Failed samples: {len(dataset.failed_samples)}")

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

NumPy-free training script - avoiding all numpy operations
PyTorch version: 2.0.1
Torchvision version: 0.15.2
Starting NumPy-free training...
Using device: cpu
Loading dataset...
Found 1411 images
Train samples: 80, Val samples: 20
Loading model...
Starting training...

=== Epoch 1/3 ===


Epoch 0:   1%|▍                                  | 1/80 [00:03<04:27,  3.38s/it]

Epoch: 0, Batch: 0, Loss: 297.9584
  loss_classifier: 3.2347
  loss_box_reg: 0.0282
  loss_mask: 293.4094
  loss_objectness: 1.2649
  loss_rpn_box_reg: 0.0213


Epoch 0:  14%|████▋                             | 11/80 [01:25<08:40,  7.55s/it]

Epoch: 0, Batch: 10, Loss: 459.3501
  loss_classifier: 3.1997
  loss_box_reg: 0.0469
  loss_mask: 455.7702
  loss_objectness: 0.3133
  loss_rpn_box_reg: 0.0199


Epoch 0:  26%|████████▉                         | 21/80 [02:32<06:24,  6.52s/it]

Epoch: 0, Batch: 20, Loss: 23.6007
  loss_classifier: 3.0722
  loss_box_reg: 0.5894
  loss_mask: 10.2417
  loss_objectness: 8.2558
  loss_rpn_box_reg: 1.4417


Epoch 0:  39%|█████████████▏                    | 31/80 [04:04<07:19,  8.97s/it]

Epoch: 0, Batch: 30, Loss: 8.7695
  loss_classifier: 2.8151
  loss_box_reg: 0.7151
  loss_mask: -2.4772
  loss_objectness: 6.5069
  loss_rpn_box_reg: 1.2096


Epoch 0:  51%|█████████████████▍                | 41/80 [05:07<04:05,  6.31s/it]

Epoch: 0, Batch: 40, Loss: -31.2400
  loss_classifier: 2.7692
  loss_box_reg: 0.4237
  loss_mask: -41.1809
  loss_objectness: 6.3051
  loss_rpn_box_reg: 0.4429


Epoch 0:  64%|█████████████████████▋            | 51/80 [05:50<01:52,  3.88s/it]

Epoch: 0, Batch: 50, Loss: -113.2122
  loss_classifier: 2.7614
  loss_box_reg: 0.0762
  loss_mask: -122.1666
  loss_objectness: 5.7960
  loss_rpn_box_reg: 0.3208


Epoch 0:  76%|█████████████████████████▉        | 61/80 [06:42<01:36,  5.07s/it]

Epoch: 0, Batch: 60, Loss: -125.5990
  loss_classifier: 2.7308
  loss_box_reg: 0.6818
  loss_mask: -134.7440
  loss_objectness: 4.9005
  loss_rpn_box_reg: 0.8319


Epoch 0:  89%|██████████████████████████████▏   | 71/80 [07:57<01:03,  7.08s/it]

Epoch: 0, Batch: 70, Loss: 14.8977
  loss_classifier: 2.6873
  loss_box_reg: 0.5076
  loss_mask: 6.8761
  loss_objectness: 4.1534
  loss_rpn_box_reg: 0.6734


Epoch 0: 100%|██████████████████████████████████| 80/80 [08:57<00:00,  6.72s/it]
Evaluating:   5%|█▌                              | 1/20 [00:12<03:49, 12.06s/it]

Error in evaluation batch 0: 'list' object has no attribute 'values'


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import warnings
warnings.filterwarnings("ignore")

# numpy is imported unconditionally at the top of this cell, so a missing
# install already fails there (install with: pip install numpy).  The former
# try/except ImportError around a second import could never trigger.
print(f"NumPy version: {np.__version__}")


class iSAIDDataset(Dataset):
    def __init__(self, root_dir, transforms=None):
        """
        iSAID Dataset for Mask R-CNN validation.

        Each item is ``(image_tensor, target, original_image, filename)``.
        ``original_image`` is the (H, W, 3) uint8 array of the input image
        before ``transforms`` are applied.

        Args:
            root_dir: Directory containing ``Instance_masks/images`` and
                ``Semantic_masks/images``.
            transforms: Optional callable applied to the PIL image; when None,
                ``transforms.ToTensor()`` is used.

        Raises:
            ValueError: If either mask directory does not exist.
        """
        self.root_dir = Path(root_dir)
        self.transforms = transforms

        # Paths to the mask directories
        self.instance_masks_dir = self.root_dir / "Instance_masks" / "images"
        self.semantic_masks_dir = self.root_dir / "Semantic_masks" / "images"

        if not self.instance_masks_dir.exists():
            raise ValueError(f"Instance masks directory not found: {self.instance_masks_dir}")
        if not self.semantic_masks_dir.exists():
            raise ValueError(f"Semantic masks directory not found: {self.semantic_masks_dir}")

        # One sample per instance-mask PNG
        self.image_files = sorted([f for f in os.listdir(self.instance_masks_dir)
                                   if f.endswith('.png')])

        # iSAID class names (15 classes + background); list index == label value
        self.classes = [
            'background', 'ship', 'storage_tank', 'baseball_diamond',
            'tennis_court', 'basketball_court', 'ground_track_field',
            'bridge', 'large_vehicle', 'small_vehicle', 'helicopter',
            'swimming_pool', 'roundabout', 'soccer_ball_field',
            'plane', 'harbor'
        ]

        # Class colors for visualization (index-aligned with self.classes)
        self.class_colors = [
            (0, 0, 0),        # background (black)
            (255, 0, 0),      # ship (red)
            (0, 255, 0),      # storage_tank (green)
            (0, 0, 255),      # baseball_diamond (blue)
            (255, 255, 0),    # tennis_court (yellow)
            (255, 0, 255),    # basketball_court (magenta)
            (0, 255, 255),    # ground_track_field (cyan)
            (128, 0, 0),      # bridge (dark red)
            (0, 128, 0),      # large_vehicle (dark green)
            (0, 0, 128),      # small_vehicle (dark blue)
            (128, 128, 0),    # helicopter (olive)
            (128, 0, 128),    # swimming_pool (purple)
            (0, 128, 128),    # roundabout (teal)
            (255, 128, 0),    # soccer_ball_field (orange)
            (255, 0, 128),    # plane (pink)
            (128, 255, 0),    # harbor (lime)
        ]

        # Semantic-mask colour -> class index lookup
        self.rgb_to_class = self.create_rgb_mapping()

    def create_rgb_mapping(self):
        """Return the iSAID semantic palette as ``{(R, G, B): class_index}``."""
        rgb_mapping = {
            (0, 0, 0): 0,        # background
            (0, 0, 63): 1,       # ship
            (0, 63, 63): 2,      # storage_tank
            (0, 63, 0): 3,       # baseball_diamond
            (0, 63, 127): 4,     # tennis_court
            (0, 63, 191): 5,     # basketball_court
            (0, 63, 255): 6,     # ground_track_field
            (0, 127, 63): 7,     # bridge
            (0, 127, 127): 8,    # large_vehicle
            (0, 0, 127): 9,      # small_vehicle
            (0, 0, 191): 10,     # helicopter
            (0, 0, 255): 11,     # swimming_pool
            (0, 191, 127): 12,   # roundabout
            (0, 127, 191): 13,   # soccer_ball_field
            (0, 127, 255): 14,   # plane
            (0, 100, 155): 15,   # harbor
        }
        return rgb_mapping

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _read_array(path):
        """Load an image file as a numpy array and close the file handle."""
        with Image.open(path) as pil:
            return np.array(pil)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target, original_image, filename)``.

        On any error a message is printed and an empty 800x800 dummy sample
        named ``dummy_<idx>.png`` is returned.
        """
        try:
            img_filename = self.image_files[idx]
            base_name = img_filename.replace('_instance_id_RGB.png', '')

            # Placeholder input image.  In practice, load your actual RGB
            # images here.
            img = self.create_dummy_image(base_name)

            instance_mask = self._read_array(self.instance_masks_dir / img_filename)

            semantic_filename = img_filename.replace('_instance_id_RGB.png', '_instance_color_RGB.png')
            semantic_mask = self._read_array(self.semantic_masks_dir / semantic_filename)

            # The input image must match the mask dimensions
            mask_height, mask_width = instance_mask.shape[:2]
            img = img.resize((mask_width, mask_height))

            boxes, labels, masks = self.extract_instances(instance_mask, semantic_mask)

            if len(boxes) == 0:
                boxes = torch.zeros((0, 4), dtype=torch.float32)
                labels = torch.zeros((0,), dtype=torch.int64)
                masks = torch.zeros((0, mask_height, mask_width), dtype=torch.uint8)
                area = torch.zeros((0,), dtype=torch.float32)
                iscrowd = torch.zeros((0,), dtype=torch.int64)
            else:
                boxes = torch.as_tensor(boxes, dtype=torch.float32)
                labels = torch.as_tensor(labels, dtype=torch.int64)
                # Stack first: building a tensor from a list of ndarrays is slow.
                masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
                area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
                iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

            target = {
                "boxes": boxes,
                "labels": labels,
                "masks": masks,
                "image_id": torch.tensor([idx]),
                "area": area,
                "iscrowd": iscrowd
            }

            # Keep the untransformed image for visualization
            original_img = np.array(img)

            if self.transforms:
                img = self.transforms(img)
            else:
                img = transforms.ToTensor()(img)

            return img, target, original_img, img_filename

        except Exception as e:
            print(f"Error loading sample {idx}: {str(e)}")
            dummy_img = torch.zeros((3, 800, 800), dtype=torch.float32)
            dummy_target = {
                "boxes": torch.zeros((0, 4), dtype=torch.float32),
                "labels": torch.zeros((0,), dtype=torch.int64),
                "masks": torch.zeros((0, 800, 800), dtype=torch.uint8),
                "image_id": torch.tensor([idx]),
                "area": torch.zeros((0,), dtype=torch.float32),
                "iscrowd": torch.zeros((0,), dtype=torch.int64)
            }
            dummy_original = np.zeros((800, 800, 3), dtype=np.uint8)
            return dummy_img, dummy_target, dummy_original, f"dummy_{idx}.png"

    def create_dummy_image(self, base_name):
        """Return an 800x800 RGB gradient placeholder image.

        ``base_name`` is unused, so every sample gets the same pattern.  Pixel
        (x, y) is ``(100 + 50 sin(0.01 x), 120 + 30 cos(0.01 y),
        140 + 20 sin(0.005 (x + y)))``, truncated to integers.  It is computed
        in one vectorized numpy pass instead of 640k per-pixel Python calls.
        """
        width, height = 800, 800
        xs = np.arange(width, dtype=np.float64)[np.newaxis, :]   # (1, W)
        ys = np.arange(height, dtype=np.float64)[:, np.newaxis]  # (H, 1)

        pixels = np.empty((height, width, 3), dtype=np.uint8)
        # All values are positive, so astype truncation matches int().
        pixels[..., 0] = (100 + 50 * np.sin(xs * 0.01)).astype(np.uint8)
        pixels[..., 1] = (120 + 30 * np.cos(ys * 0.01)).astype(np.uint8)
        pixels[..., 2] = (140 + 20 * np.sin((xs + ys) * 0.005)).astype(np.uint8)

        return Image.fromarray(pixels, 'RGB')

    def extract_instances(self, instance_mask, semantic_mask):
        """Split an instance-id mask into per-instance boxes, labels and masks.

        Args:
            instance_mask: (H, W) id array, or (H, W, C) colour-encoded id
                array (the first three channels form the id).  0 is background.
            semantic_mask: (H, W, C) colour or (H, W) index semantic mask.

        Returns:
            ``(boxes, labels, masks)`` lists: ``[xmin, ymin, xmax, ymax]`` int
            lists, int class indices and (H, W) uint8 masks holding 0/1.
            Three empty lists if extraction fails.
        """
        try:
            if instance_mask.ndim == 3:
                channels = instance_mask.astype(np.int64)
                if channels.shape[2] >= 3:
                    # Combine R, G and B into one id; using only R merged
                    # instances whose colours share a red value.
                    instance_ids = channels[:, :, 0] * 65536 + channels[:, :, 1] * 256 + channels[:, :, 2]
                else:
                    instance_ids = channels[:, :, 0]
            else:
                instance_ids = instance_mask

            unique_instances = np.unique(instance_ids)
            unique_instances = unique_instances[unique_instances > 0]

            boxes, labels, masks = [], [], []
            for instance_id in unique_instances:
                binary_mask = (instance_ids == instance_id).astype(np.uint8)
                if binary_mask.sum() < 100:
                    continue

                rows, cols = np.nonzero(binary_mask)
                xmin, xmax = int(cols.min()), int(cols.max())
                ymin, ymax = int(rows.min()), int(rows.max())
                # Also rejects degenerate boxes (xmax <= xmin or ymax <= ymin).
                if (xmax - xmin) < 5 or (ymax - ymin) < 5:
                    continue

                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(self.get_class_from_semantic_mask(semantic_mask, binary_mask))
                masks.append(binary_mask)

            return boxes, labels, masks

        except Exception as e:
            print(f"Error extracting instances: {str(e)}")
            return [], [], []

    def get_class_from_semantic_mask(self, semantic_mask, instance_mask):
        """Return the class index for one instance by majority vote.

        For a colour semantic mask, the most frequent colour under the
        instance that maps (via ``rgb_to_class``) to a non-background class
        wins.  For an index mask, the most frequent non-zero value wins,
        capped at 15.  Falls back to 1 if nothing matches or on error.
        """
        try:
            masked_semantic = semantic_mask[instance_mask > 0]
            if len(masked_semantic) == 0:
                return 1

            if semantic_mask.ndim == 3:
                # Drop alpha (if present) so each row is one (R, G, B) colour.
                colours, counts = np.unique(masked_semantic[:, :3], axis=0, return_counts=True)
                for i in np.argsort(-counts, kind='stable'):
                    label = self.rgb_to_class.get(tuple(int(c) for c in colours[i]), 0)
                    if label > 0:
                        return label
                return 1

            values, counts = np.unique(masked_semantic, return_counts=True)
            for i in np.argsort(-counts, kind='stable'):
                value = int(values[i])
                if value > 0:
                    return min(value, 15)
            return 1

        except Exception:
            return 1


def get_model(num_classes, checkpoint_path=None):
    """Build Mask R-CNN for ``num_classes`` and optionally load trained weights.

    The COCO-pretrained model is created and its box and mask predictors are
    replaced with ones sized to ``num_classes``.  Then, if ``checkpoint_path``
    is given and exists, the checkpoint's ``'model_state_dict'`` is loaded
    (mapped to CPU).
    """
    model = maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    heads.box_predictor = FastRCNNPredictor(
        heads.box_predictor.cls_score.in_features, num_classes
    )
    mask_hidden = 256  # width of the new mask head's hidden layer
    heads.mask_predictor = MaskRCNNPredictor(
        heads.mask_predictor.conv5_mask.in_channels, mask_hidden, num_classes
    )

    has_checkpoint = bool(checkpoint_path) and os.path.exists(checkpoint_path)
    if not has_checkpoint:
        print("No checkpoint provided, using pre-trained weights only")
        return model

    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    return model


def collate_fn(batch):
    """Group a batch of ``(image, target, original_image, filename)`` samples by field.

    ``None`` samples (ones the dataset could not load) are discarded first.
    Always returns four lists, all empty if nothing valid remains, so callers
    can unpack the result unconditionally.
    """
    valid_samples = [sample for sample in batch if sample is not None]
    if not valid_samples:
        return [], [], [], []

    img_col, target_col, raw_col, name_col = zip(*valid_samples)
    return (
        list(img_col),
        list(target_col),
        list(raw_col),
        list(name_col),
    )


def safe_tensor_to_numpy(tensor):
    """Convert a tensor, ndarray or array-like into a NumPy array.

    * ``torch.Tensor``: detached, moved to CPU and converted with ``.numpy()``.
    * ``np.ndarray``: returned unchanged (same object, no copy).
    * anything else: passed to ``np.array``.

    If that conversion raises (for example ``.numpy()`` rejects bfloat16
    tensors), the error is printed and the value is retried through
    ``.detach().cpu().tolist()``. If that also fails, an empty array is returned.
    """
    # Local import: the module header does not import numpy.
    import numpy as np

    try:
        if isinstance(tensor, torch.Tensor):
            return tensor.detach().cpu().numpy()
        if isinstance(tensor, np.ndarray):
            return tensor
        return np.array(tensor)
    except Exception as e:
        print(f"Error converting tensor to numpy: {e}")

    # Fallback: round-trip through Python lists (handles dtypes .numpy() rejects).
    try:
        return np.array(tensor.detach().cpu().tolist())
    except Exception:
        # Not a tensor (no .detach) or not representable: give up gracefully,
        # but without swallowing KeyboardInterrupt/SystemExit like a bare except.
        return np.array([])


def _draw_labelled_boxes(ax, boxes, labels, classes, color, scores=None, prefix=''):
    """Draw each non-background ``(x1, y1, x2, y2)`` box on ``ax`` with a caption.

    The caption is ``prefix + class name``, followed by ``': <score>'`` (two
    decimals) when ``scores`` is given.
    """
    from matplotlib.patches import Rectangle

    for idx, (box, label) in enumerate(zip(boxes, labels)):
        if not label > 0:  # label 0 is 'background'
            continue
        x1, y1, x2, y2 = box
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1,
                               linewidth=2, edgecolor=color, facecolor='none'))
        caption = f'{prefix}{classes[label]}'
        if scores is not None:
            caption = f'{caption}: {scores[idx]:.2f}'
        ax.text(x1, y1 - 5, caption, color=color, fontsize=8, fontweight='bold')


def _overlay_masks(image, masks, labels, class_colors, alpha=0.3, threshold=0.5):
    """Return a uint8 copy of ``image`` with each non-background mask tinted by its class colour.

    Only pixels inside a mask are blended (``(1 - alpha) * pixel + alpha * colour``);
    pixels outside every mask keep their original value.
    """
    import numpy as np

    overlay = np.asarray(image).astype(np.float32)  # astype copies; input untouched
    for mask, label in zip(masks, labels):
        if not label > 0:
            continue
        # Indexed as mask[0]: presumably (1, H, W) soft masks from Mask R-CNN.
        inside = mask[0] > threshold
        color = np.asarray(class_colors[label], dtype=np.float32)
        overlay[inside] = overlay[inside] * (1.0 - alpha) + color * alpha
    return np.clip(overlay, 0, 255).astype(np.uint8)


def visualize_predictions(model, data_loader, device, num_samples=5, confidence_threshold=0.5, save_dir='./validation_results'):
    """Save side-by-side ground-truth / prediction figures for up to ``num_samples`` images.

    For each sample a 2x2 figure is written to ``save_dir``:
    the original image, ground-truth boxes (red), predicted boxes with
    score > ``confidence_threshold`` (blue), and those predictions' masks
    blended over the image with a per-class legend. Summary statistics are
    printed for each sample.

    ``data_loader`` is expected to yield ``(images, targets, original_imgs,
    filenames)`` batches, as produced by ``collate_fn``.
    """
    import numpy as np

    model.eval()

    os.makedirs(save_dir, exist_ok=True)

    # Class names and colors
    classes = [
        'background', 'ship', 'storage_tank', 'baseball_diamond',
        'tennis_court', 'basketball_court', 'ground_track_field',
        'bridge', 'large_vehicle', 'small_vehicle', 'helicopter',
        'swimming_pool', 'roundabout', 'soccer_ball_field',
        'plane', 'harbor'
    ]

    class_colors = [
        (0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
        (255, 0, 255), (0, 255, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128),
        (128, 128, 0), (128, 0, 128), (0, 128, 128), (255, 128, 0), (255, 0, 128), (128, 255, 0)
    ]

    sample_count = 0

    with torch.no_grad():
        for images, targets, original_imgs, filenames in data_loader:
            if sample_count >= num_samples:
                break
            if len(images) == 0:  # every sample in the batch failed to load
                continue

            predictions = model([img.to(device) for img in images])

            for original_img, filename, target, pred in zip(original_imgs, filenames, targets, predictions):
                if sample_count >= num_samples:
                    break

                gt_boxes = safe_tensor_to_numpy(target['boxes'])
                gt_labels = safe_tensor_to_numpy(target['labels'])

                # Keep only predictions above the confidence threshold.
                pred_boxes = safe_tensor_to_numpy(pred['boxes'])
                pred_labels = safe_tensor_to_numpy(pred['labels'])
                pred_scores = safe_tensor_to_numpy(pred['scores'])
                pred_masks = safe_tensor_to_numpy(pred['masks'])
                keep = pred_scores > confidence_threshold
                pred_boxes = pred_boxes[keep]
                pred_labels = pred_labels[keep]
                pred_scores = pred_scores[keep]
                # Guard against a failed (empty) mask conversion misaligning with `keep`.
                pred_masks = pred_masks[keep] if len(pred_masks) == len(keep) else np.array([])

                fig, axes = plt.subplots(2, 2, figsize=(15, 15))
                try:
                    fig.suptitle(f'Sample: {filename}', fontsize=16)

                    # 1. Original image
                    axes[0, 0].imshow(original_img)
                    axes[0, 0].set_title('Original Image')
                    axes[0, 0].axis('off')

                    # 2. Ground truth boxes
                    axes[0, 1].imshow(original_img)
                    _draw_labelled_boxes(axes[0, 1], gt_boxes, gt_labels, classes, 'red', prefix='GT: ')
                    axes[0, 1].set_title(f'Ground Truth ({len(gt_boxes)} objects)')
                    axes[0, 1].axis('off')

                    # 3. Predicted boxes
                    axes[1, 0].imshow(original_img)
                    _draw_labelled_boxes(axes[1, 0], pred_boxes, pred_labels, classes, 'blue', scores=pred_scores)
                    axes[1, 0].set_title(f'Predictions ({len(pred_boxes)} objects, conf>{confidence_threshold})')
                    axes[1, 0].axis('off')

                    # 4. Predicted masks overlay
                    axes[1, 1].imshow(_overlay_masks(original_img, pred_masks, pred_labels, class_colors))
                    axes[1, 1].set_title('Predicted Masks Overlay')
                    axes[1, 1].axis('off')

                    legend_elements = [
                        plt.Line2D([0], [0], marker='o', color='w',
                                   markerfacecolor=np.array(class_colors[label]) / 255.0,
                                   markersize=10, label=classes[label])
                        for label in np.unique(pred_labels)
                        if label > 0
                    ]
                    if legend_elements:
                        axes[1, 1].legend(handles=legend_elements, loc='upper right', fontsize=8)

                    fig.tight_layout()

                    save_path = os.path.join(save_dir, f'validation_{sample_count:03d}_{filename}')
                    fig.savefig(save_path, dpi=150, bbox_inches='tight')
                    plt.show()
                finally:
                    # Under the Agg backend show() is a no-op, so without this
                    # every figure would stay open and memory would grow per sample.
                    plt.close(fig)

                # Print statistics
                print(f"\nSample {sample_count + 1}: {filename}")
                print(f"Ground Truth: {len(gt_boxes)} objects")
                print(f"Predictions: {len(pred_boxes)} objects (conf > {confidence_threshold})")

                if len(pred_boxes) > 0:
                    print("Predicted classes:")
                    for label, score in zip(pred_labels, pred_scores):
                        print(f"  - {classes[label]}: {score:.3f}")

                sample_count += 1


def main(config=None):
    """Run a visual sanity check of a trained Mask R-CNN on iSAID images.

    Steps:
    1. Print library versions.
    2. Resolve the checkpoint, prompting whether to continue without one if it
       is missing.
    3. Take the last (up to) 20 images of ``data_dir`` as a small evaluation
       subset and smoke-test the loader.
    4. Load the model and write prediction visualisations to ``save_dir``.

    Args:
        config: optional dict of overrides merged over the defaults below
            (e.g. ``{'data_dir': ..., 'checkpoint_path': None}``). The caller's
            dict is not modified.
    """
    # Check dependencies first
    try:
        print("Checking dependencies...")
        # numpy is imported here (the module header does not import it), so a
        # missing installation is reported by the handler below.
        import numpy as np
        print(f"PyTorch version: {torch.__version__}")
        print(f"Torchvision version: {torchvision.__version__}")
        print(f"NumPy version: {np.__version__}")
        print(f"PIL version: {Image.__version__}")
        print(f"Matplotlib version: {plt.matplotlib.__version__}")
    except Exception as e:
        print(f"Dependency check failed: {e}")
        print("Please ensure all required packages are installed:")
        print("pip install torch torchvision numpy pillow matplotlib opencv-python")
        return

    # Configuration (defaults; callers may override any key)
    defaults = {
        'data_dir': '/Users/soumendusekharbhattacharjee/Downloads/iSAID_data/train',
        'checkpoint_path': './checkpoints/maskrcnn_epoch_4.pth',  # Update this path
        'batch_size': 1,
        'num_workers': 0,
        'num_classes': 16,
        'num_samples': 5,  # Reduced for testing
        'confidence_threshold': 0.3,  # Lowered threshold
        'save_dir': './validation_results'
    }
    config = {**defaults, **(config or {})}

    # Check if checkpoint exists
    checkpoint_path = config['checkpoint_path']
    if checkpoint_path and not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        # A bare filename has an empty dirname; look in the current directory then.
        checkpoint_dir = os.path.dirname(checkpoint_path) or '.'
        if os.path.isdir(checkpoint_dir):
            for name in sorted(os.listdir(checkpoint_dir)):
                if name.endswith('.pth'):
                    print(f"  - {os.path.join(checkpoint_dir, name)}")
        else:
            print(f"Checkpoint directory {checkpoint_dir} does not exist")

        # Ask user if they want to continue without checkpoint
        response = input("Continue without loading checkpoint? (y/n): ")
        if response.strip().lower() != 'y':
            return
        config['checkpoint_path'] = None

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    try:
        # Load dataset
        print("Loading validation dataset...")
        transform = transforms.Compose([transforms.ToTensor()])
        dataset = iSAIDDataset(config['data_dir'], transforms=transform)

        # Use the last (up to) 20 samples. NOTE: with the default data_dir these
        # come from the training split, so they are not held-out data.
        total_size = min(len(dataset), 20)
        val_indices = list(range(len(dataset) - total_size, len(dataset)))
        val_dataset = torch.utils.data.Subset(dataset, val_indices)

        print(f"Validation samples: {len(val_dataset)}")

        val_loader = DataLoader(
            val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            collate_fn=collate_fn
        )

        # Smoke-test data loading on the first two batches.
        print("Testing data loading...")
        for i, (images, _targets, _original_imgs, _filenames) in enumerate(val_loader):
            print(f"Loaded batch {i}: {len(images)} images")
            if i >= 1:
                break

        # Load model
        print("Loading model...")
        model = get_model(config['num_classes'], config['checkpoint_path'])
        model.to(device)

        # Create save directory
        os.makedirs(config['save_dir'], exist_ok=True)

        # Visualize predictions
        print("Generating validation visualizations...")
        visualize_predictions(
            model=model,
            data_loader=val_loader,
            device=device,
            num_samples=config['num_samples'],
            confidence_threshold=config['confidence_threshold'],
            save_dir=config['save_dir']
        )

        print(f"Validation results saved to: {config['save_dir']}")

    except Exception as e:
        # Top-level boundary: report the full traceback plus hints, then exit.
        print(f"Error during execution: {e}")
        import traceback
        traceback.print_exc()
        print("\nTroubleshooting tips:")
        print("1. Make sure all dependencies are installed")
        print("2. Check if your data directory path is correct")
        print("3. Verify your checkpoint file exists")
        print("4. Try reducing num_samples if running out of memory")


if __name__ == "__main__":
    # Script entry point: run the validation/visualisation pipeline with the
    # default configuration defined inside main().
    main()