In [None]:
import os
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as T
from seg_head_instance_seg import efficientvit_instance_seg_custom
import pandas as pd
from data_load import get_transforms,CocoSegDataset,collate_fn,instance_loss_fn,save_and_plot_metrics,load_checkpoint
from torch.cuda.amp import autocast, GradScaler
from pycocotools import mask as mask_utils
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as mask_utils
from sklearn.cluster import DBSCAN




# Use the GPU when one is visible, otherwise fall back to the CPU.
# Note: nothing in this cell is moved to `device`; the smoke test below runs
# wherever the freshly built model and the random input are created.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shape parameters for the smoke test
num_classes = 2
img_size = 224
batch_size = 2

# Random batch standing in for real images: [batch_size, 3, img_size, img_size]
dummy_input = torch.randn(batch_size, 3, img_size, img_size)

# efficientvit_m0
model_config = {
    "n_classes": num_classes,
    "img_size": img_size,
    "patch_size": 16,
    "embed_dim": [64, 128, 192],
    "depth": [1, 2, 3],
    "num_heads": [4, 4, 4],
    "window_size": [7, 7, 7],
    "kernels": [5, 5, 5, 5],
    "down_ops": [['subsample', 2], ['subsample', 2], ['']],
    "distillation": False,
    "frozen_stages": 0,
}
model = efficientvit_instance_seg_custom(**model_config)

# Single forward pass in eval mode with gradient tracking disabled
model.eval()
with torch.no_grad():
    outputs = model(dummy_input)

# The model returns a dict; pull out the two heads read below
seg_logits, embedding = outputs["seg_logits"], outputs["embedding"]

total_params = sum(param.numel() for param in model.parameters())

print("Foreground segmentation logits shape:", seg_logits.shape)  #  [B, n_classes, H, W]
print("Instance embedding shape:", embedding.shape)               #  [B, emb_dim, H, W]
print("Total model parameters:", total_params)




Foreground segmentation logits shape: torch.Size([2, 2, 80, 80])
Instance embedding shape: torch.Size([2, 16, 80, 80])
Total model parameters: 3354652


In [None]:
from torchinfo import summary

# Move the model to the GPU when available (`device` is the torch.device chosen
# in the first cell from the same availability check), then display a
# per-layer summary for a 2 x 3 x 224 x 224 input.
model.to(device)
summary(model, input_size=(2, 3, 224, 224))


Layer (type:depth-idx)                                            Output Shape              Param #
EfficientViTSeg                                                   [2, 16, 80, 80]           --
├─EfficientViT: 1-1                                               [2, 192, 1, 1]            --
│    └─Sequential: 2-1                                            [2, 64, 14, 14]           --
│    │    └─Conv2d_BN: 3-1                                        [2, 8, 112, 112]          (232)
│    │    └─ReLU: 3-2                                             [2, 8, 112, 112]          --
│    │    └─Conv2d_BN: 3-3                                        [2, 16, 56, 56]           (1,184)
│    │    └─ReLU: 3-4                                             [2, 16, 56, 56]           --
│    │    └─Conv2d_BN: 3-5                                        [2, 32, 28, 28]           (4,672)
│    │    └─ReLU: 3-6                                             [2, 32, 28, 28]           --
│    │    └─Conv2d_BN: 3-7      

In [None]:
# ---- Optimisation hyperparameters ----
batch_size = 64
num_epochs = 100
learning_rate = 2e-4
weight_decay = 1e-4

# ---- Dataset roots and COCO annotation files ----
train_root = '/teamspace/uploads/pig_seg/annotations/train'
val_root = '/teamspace/uploads/pig_seg/annotations/val'
train_ann_file = '/teamspace/uploads/pig_seg/annotations/train/_annotations.coco.json'
val_ann_file = '/teamspace/uploads/pig_seg/annotations/val/_annotations.coco.json'

# ---- Transforms, datasets and loaders (train first, then val) ----
train_transform = get_transforms(is_train=True)
val_transform = get_transforms(is_train=False)

train_dataset = CocoSegDataset(root=train_root, ann_file=train_ann_file, transform=train_transform)
val_dataset = CocoSegDataset(root=val_root, ann_file=val_ann_file, transform=val_transform)

# Both loaders share every setting except shuffling.
shared_loader_args = {"batch_size": batch_size, "num_workers": 4, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

# Per-epoch and final checkpoints are written into this directory.
os.makedirs("checkpoints", exist_ok=True)

# Initialize model without pretrained weights
model = efficientvit_instance_seg_custom(emb_dim=16, n_classes=2).to(device)

# ---- Loss, optimiser and AMP gradient scaler ----
seg_loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scaler = GradScaler()

# Load the last trained checkpoint; the returned value is the epoch index the
# training loop starts from.
checkpoint_path = 'checkpoints/efficientvit_pig_seg_final.pth'
start_epoch = load_checkpoint(model, optimizer, checkpoint_path, device)
# Main loop: for every epoch from `start_epoch` to `num_epochs - 1`
#   1. train with mixed precision on `train_loader`,
#   2. turn the validation predictions into per-instance masks (DBSCAN on the
#      foreground pixel embeddings) and score them with COCO 'segm' evaluation,
#   3. log the metrics and write a per-epoch checkpoint.
for epoch in range(start_epoch, num_epochs):
    # ------------------------------ training ------------------------------
    model.train()
    total_loss = 0
    total_samples = 0
    for batch_idx, batch in enumerate(train_loader):
        imgs = batch['img'].to(device)
        actual_batch_size = len(imgs)
        # gt_bboxes / gt_labels are moved to the device but are not read by the
        # loss computation below; only gt_masks is.
        gt_bboxes = [bbox.to(device) for bbox in batch['gt_bboxes']]
        gt_labels = [label.to(device) for label in batch['gt_labels']]
        gt_masks = [mask.to(device) for mask in batch['gt_masks']]
        with autocast():
            outputs = model(imgs)
            seg_logits = outputs['seg_logits']
            embeddings = outputs['embedding']
            # Collapse each image's stack of instance masks into one binary
            # foreground/background target for the cross-entropy loss.
            gt_seg = torch.stack([mask.sum(dim=0).clamp(0, 1).long() for mask in gt_masks]).to(device)
            seg_loss = seg_loss_fn(seg_logits, gt_seg)
            inst_loss = instance_loss_fn(embeddings, gt_masks, actual_batch_size)
            # The instance term is weighted 2x relative to the segmentation term.
            loss = seg_loss + 2.0 * inst_loss
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, "
                      f"seg_loss={seg_loss.item():.4f}, inst_loss={inst_loss.item():.4f}, total_loss={loss.item():.4f}")
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        # Weight by batch size so the epoch average is per sample.
        total_loss += loss.item() * actual_batch_size
        total_samples += actual_batch_size
    avg_loss = total_loss / total_samples if total_samples > 0 else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # ----------------------------- validation -----------------------------
    model.eval()
    val_coco = COCO(val_ann_file)
    val_preds = []
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['img'].to(device)
            outputs = model(imgs)
            seg_logits = outputs['seg_logits'].argmax(dim=1)
            embeddings = outputs['embedding']
            # Read the embedding width from the tensor instead of hard-coding it.
            emb_dim = embeddings.shape[1]
            for b in range(imgs.size(0)):
                pred_class = seg_logits[b].cpu().numpy()
                mask = pred_class > 0
                if not mask.any():
                    print(f"Batch {b}: No foreground pixels detected, skipping DBSCAN")
                    continue
                # [emb_dim, H, W] -> [H*W, emb_dim]; rows follow the same
                # row-major pixel order as mask.flatten().
                emb = embeddings[b].permute(1, 2, 0).reshape(-1, emb_dim).cpu().numpy()
                emb_flat = emb[mask.flatten()]
                if len(emb_flat) < 5:
                    print(f"Batch {b}: Too few foreground pixels ({len(emb_flat)}), skipping DBSCAN")
                    continue
                clustering = DBSCAN(eps=0.5, min_samples=5).fit(emb_flat)
                labels = clustering.labels_
                if (labels == -1).all():
                    print(f"Batch {b}: All points classified as noise by DBSCAN")
                    continue
                # Build an INTEGER instance-id map. Writing `labels + 1` into the
                # boolean foreground mask (as was done before) casts every
                # cluster id to True/False, merging all clusters into a single
                # instance. Id 0 = background or DBSCAN noise, ids >= 1 = clusters.
                pred_mask = np.zeros(mask.shape, dtype=np.int32)
                pred_mask[mask] = labels + 1
                # Keep only real instance ids. At least one exists here because
                # the all-noise case was skipped above.
                unique_ids = np.unique(pred_mask)
                unique_ids = unique_ids[unique_ids > 0]
                # Tensor ids cannot be matched against the COCO image ids or
                # serialised, so convert them to plain Python scalars.
                # NOTE(review): the type produced by collate_fn is not visible here.
                image_id = batch['image_id'][b]
                if torch.is_tensor(image_id):
                    image_id = image_id.item()
                # NOTE(review): masks are encoded at the network output
                # resolution. Confirm this matches the height/width recorded
                # for each image in the annotation file; if not, resize the
                # mask before encoding so 'segm' evaluation can match it.
                for inst_id in unique_ids:
                    inst_mask = (pred_mask == inst_id).astype(np.uint8)
                    rle = mask_utils.encode(np.asfortranarray(inst_mask))
                    # RLE counts are bytes; decode so the result is JSON-friendly.
                    rle['counts'] = rle['counts'].decode('utf-8')
                    val_preds.append({
                        'image_id': image_id,
                        'category_id': 1,
                        'segmentation': rle,
                        # Constant confidence: every prediction gets the same score.
                        'score': 0.95
                    })
                print(f"Batch {b}: Added {len(unique_ids)} instance predictions")

    if val_preds:
        coco_dt = val_coco.loadRes(val_preds)
        coco_eval = COCOeval(val_coco, coco_dt, 'segm')
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        # stats[0] is the first summary metric reported by COCOeval.
        mAP = coco_eval.stats[0]
        print(f"Validation mAP: {mAP:.4f}")
    else:
        print("No predictions generated for validation set. Setting mAP to 0.")
        mAP = 0.0

    # ------------------------ logging / checkpoint ------------------------
    save_and_plot_metrics(epoch, avg_loss, mAP)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, f"checkpoints/model_epoch_{epoch+1}.pth")

# Write the end-of-training state under the fixed "final" file name, which is
# the same path the setup code passes to load_checkpoint when resuming.
final_checkpoint = {
    'epoch': num_epochs - 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(final_checkpoint, "checkpoints/efficientvit_pig_seg_final.pth")
print("Training completed and model saved!")

loading annotations into memory...
Done (t=2.83s)
creating index...
index created!
loading annotations into memory...
Done (t=0.12s)
creating index...
index created!
Resuming from checkpoint at epoch 82, starting at epoch 83
Epoch 84/100, Batch 0/85, seg_loss=0.4757, inst_loss=0.0001, total_loss=0.4758
Epoch 84/100, Batch 10/85, seg_loss=0.4964, inst_loss=0.0129, total_loss=0.5222
Epoch 84/100, Batch 20/85, seg_loss=0.4887, inst_loss=0.0112, total_loss=0.5110
Epoch 84/100, Batch 30/85, seg_loss=0.4869, inst_loss=0.0077, total_loss=0.5023
Epoch 84/100, Batch 40/85, seg_loss=0.4848, inst_loss=0.0064, total_loss=0.4976
Epoch 84/100, Batch 50/85, seg_loss=0.4846, inst_loss=0.0089, total_loss=0.5024
Epoch 84/100, Batch 60/85, seg_loss=0.4831, inst_loss=0.0045, total_loss=0.4921
Epoch 84/100, Batch 70/85, seg_loss=0.4791, inst_loss=0.0039, total_loss=0.4868
Epoch 84/100, Batch 80/85, seg_loss=0.4822, inst_loss=0.0068, total_loss=0.4958
Epoch 84/100, Average Loss: 0.5012
loading annotations i