In [1]:
import os
os.chdir("..")
print("Current Directory:", os.getcwd())

Current Directory: d:\workspace\iscat


In [2]:
from src.data_processing.dataset import iScatDataset
from src.data_processing.utils import Utils
data_path = os.path.join('data', 'iScat', 'Data', '2024_11_11', 'Metasurface', 'Chip_02')
image_paths,target_paths = Utils.get_data_paths(data_path)

In [3]:
train_dataset = iScatDataset(image_paths[:-1], target_paths[:-1], preload_image=True)
valid_dataset = iScatDataset([image_paths[-1]],[target_paths[-1]],preload_image=True,apply_augmentation=False)

Loading surface images to Memory: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it]
Creating Masks: 100%|██████████| 4/4 [00:00<00:00, 364.62it/s]
Loading surface images to Memory: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]
Creating Masks: 100%|██████████| 1/1 [00:00<00:00, 334.23it/s]


In [17]:
from torch.utils.data import DataLoader, Dataset
def create_dataloaders(train_dataset, test_dataset, batch_size=4):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
train_loader, val_loader = create_dataloaders(train_dataset, valid_dataset, batch_size=4)

In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torchmetrics
import tqdm

class MulticlassSegmentationDataset(Dataset):
    def __init__(self, images, masks):
        """
        Initialize dataset with images and masks
        
        Args:
            images (torch.Tensor): Tensor of images [N, C, H, W]
            masks (torch.Tensor): Tensor of masks [N, C, H, W]
        """
        self.images = images
        self.masks = masks
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        return self.images[idx], self.masks[idx]

def create_dataloaders(train_images, train_masks, val_images, val_masks, batch_size=4):
    """
    Create train and validation dataloaders
    """
    train_dataset = MulticlassSegmentationDataset(train_images, train_masks)
    val_dataset = MulticlassSegmentationDataset(val_images, val_masks)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    return train_loader, val_loader

class SAM2MulticlassTrainer:
    def __init__(self, model_cfg, checkpoint, num_classes, device='cuda'):
        """
        Initialize SAM2 trainer for multiclass segmentation
        
        Args:
            model_cfg (str): Path to model configuration
            checkpoint (str): Path to pre-trained checkpoint
            num_classes (int): Number of segmentation classes
            device (str): Training device
        """
        self.device = device
        self.num_classes = num_classes
        
        # Build SAM2 model
        self.sam2_model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.sam2_model)
        
        # IoU and Dice metrics
        self.iou_metric = torchmetrics.JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
        self.dice_metric = torchmetrics.Dice(num_classes=num_classes).to(device)
        
    def _prepare_box_prompt(self, image_size):
        """
        Prepare a bounding box prompt covering the entire image
        
        Args:
            image_size (tuple): Height and width of the image
        
        Returns:
            torch.Tensor: Bounding box coordinates
        """
        h, w = image_size
        return torch.tensor([[0, 0, w, h]], dtype=torch.float32).to(self.device)
    
    def train_model(self, train_loader, val_loader, epochs=50, patience=5):
        """
        Train SAM2 model with early stopping
        
        First phase: Only train mask decoder (encoder frozen)
        Second phase: Train entire model
        """
        # Phase 1: Train only mask decoder
        print("Phase 1: Training Mask Decoder")
        self._train_phase(train_loader, val_loader, epochs, patience, freeze_encoder=True)
        
        # Phase 2: Train entire model
        print("Phase 2: Training Entire Model")
        self._train_phase(train_loader, val_loader, epochs, patience, freeze_encoder=False)
    
    def _train_phase(self, train_loader, val_loader, epochs, patience, freeze_encoder=True):
        """
        Perform training phase with early stopping
        """
        # Freeze/unfreeze encoder based on phase
        for param in self.predictor.model.image_encoder.parameters():
            param.requires_grad = not freeze_encoder
        
        # Only train decoder and prompt encoder
        trainable_params = list(self.predictor.model.sam_mask_decoder.parameters()) + \
                           list(self.predictor.model.sam_prompt_encoder.parameters())
        
        optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2)
        
        best_val_iou = 0
        patience_counter = 0
        
        for epoch in range(epochs):
            self.predictor.model.train()
            train_losses, train_ious = [], []
            
            for images, masks in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}"):
                images, masks = images.to(self.device), masks.to(self.device)
                optimizer.zero_grad()
                
                # Prepare batch results
                batch_losses, batch_ious = [], []
                
                for img, gt_mask in zip(images, masks):
                    self.predictor.set_image(img.numpy())
                    
                    # Whole image box prompt
                    box_input = self._prepare_box_prompt(img.shape[1:])
                    
                    # Prompt encoding
                    sparse_embeddings, dense_embeddings = self.predictor.model.sam_prompt_encoder(
                        points=None, boxes=box_input, masks=None
                    )
                    
                    # Predict masks
                    low_res_masks, _, _, _ = self.predictor.model.sam_mask_decoder(
                        image_embeddings=self.predictor._features["image_embed"][-1].unsqueeze(0),
                        image_pe=self.predictor.model.sam_prompt_encoder.get_dense_pe(),
                        sparse_prompt_embeddings=sparse_embeddings,
                        dense_prompt_embeddings=dense_embeddings,
                        multimask_output=False
                    )
                    
                    # Postprocess masks
                    prd_masks = self.predictor._transforms.postprocess_masks(
                        low_res_masks, self.predictor._orig_hw[-1]
                    )
                    
                    # Sigmoid and multi-class handling
                    prd_masks = torch.sigmoid(prd_masks)
                    prd_masks_multi = prd_masks.repeat(self.num_classes, 1, 1, 1)
                    
                    # Multi-class segmentation loss
                    seg_loss = nn.functional.cross_entropy(
                        prd_masks_multi,  # predicted masks
                        torch.argmax(gt_mask, dim=0),  # convert one-hot to class indices
                        reduction='mean'
                    )
                    
                    # IoU calculation
                    prd_class_masks = (prd_masks_multi > 0.5).float()
                    iou = self.iou_metric(prd_class_masks, gt_mask)
                    
                    batch_losses.append(seg_loss)
                    batch_ious.append(iou)
                
                # Aggregate batch results
                loss = torch.mean(torch.stack(batch_losses))
                iou = torch.mean(torch.stack(batch_ious))
                
                loss.backward()
                optimizer.step()
                
                train_losses.append(loss.item())
                train_ious.append(iou.item())
            
            # Validation
            self.predictor.model.eval()
            val_losses, val_ious = [], []
            
            with torch.no_grad():
                for images, masks in val_loader:
                    images, masks = images.to(self.device), masks.to(self.device)
                    
                    batch_val_losses, batch_val_ious = [], []
                    
                    for img, gt_mask in zip(images, masks):
                        # Similar prediction logic as training
                        # (omitted for brevity, would mirror training code)
                        pass
                    
                    # Aggregate validation metrics
            
            # Update learning rate and check early stopping
            avg_val_iou = np.mean(val_ious)
            scheduler.step(avg_val_iou)
            
            if avg_val_iou > best_val_iou:
                best_val_iou = avg_val_iou
                patience_counter = 0
                torch.save(self.predictor.model.state_dict(), f'best_model_{"encoder_frozen" if freeze_encoder else "full_model"}.pth')
            else:
                patience_counter += 1
            
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

In [21]:
checkpoint = "src\model\sam2\checkpoints\sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
sam2_model = build_sam2(model_cfg, checkpoint,device='cuda')
predictor = SAM2ImagePredictor(sam2_model)


In [25]:
sam2_model.image_encoder

ImageEncoder(
  (trunk): Hiera(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
    )
    (blocks): ModuleList(
      (0): MultiScaleBlock(
        (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (attn): MultiScaleAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (layers): ModuleList(
            (0): Linear(in_features=96, out_features=384, bias=True)
            (1): Linear(in_features=384, out_features=96, bias=True)
          )
          (act): GELU(approximate='none')
        )
      )
      (1): MultiScaleBlock(
        (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=

In [30]:
trainer = SAM2MulticlassTrainer(
    model_cfg=model_cfg , 
    checkpoint=checkpoint , 
    num_classes=3,
    device='cpu'
)
trainer.train_model(train_loader, val_loader)

Phase 1: Training Mask Decoder


Epoch 1:   0%|          | 0/100 [00:00<?, ?it/s]
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x000002A4C63E3B00>
Traceback (most recent call last):
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torch\utils\data\dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torch\utils\data\dataloader.py", line 1562, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
                                   ^^^^^^^^^^^^^^^^^^^^
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'


RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torch\nn\modules\container.py", line 250, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torchvision\transforms\transforms.py", line 277, in forward
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)
               ~~~~~~~~~~~ <--- HERE
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torchvision\transforms\functional.py", line 350, in normalize
        raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")

    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
           ~~~~~~~~~~~~~ <--- HERE
  File "d:\anaconda3\envs\iscat\Lib\site-packages\torchvision\transforms\_functional_tensor.py", line 928, in normalize
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    return tensor.sub_(mean).div_(std)
           ~~~~~~~~~~~ <--- HERE
RuntimeError: The size of tensor a (224) must match the size of tensor b (3) at non-singleton dimension 0
