In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import os
from PIL import Image
import numpy as np 
import cv2
import json
import sys
import tqdm
sys.path.append('/root/jango_ws/src/grasp/src/')
from sam import get_all_masks

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [4]:
# SAM 2.1 Hiera-L weights and the model config that matches them.
# NOTE(review): a fine-tuned file such as "sam2_model_epoch_1.pth" (presumably
# written by the training loop further down) was previously swapped in here.
test_config = dict(
    sam_checkpoint="/root/datasets/checkpoints/sam2.1_hiera_large.pt",
    sam_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
)

sam = build_sam2(test_config["sam_model_cfg"], test_config["sam_checkpoint"])

# Automatic mask generator over the model above.
# points_per_side=128 samples a 128x128 prompt grid (16,384 prompts per image),
# which is dense and slow; stability_score_thresh=0.90 filters out unstable masks.
all_masks_predictor = SAM2AutomaticMaskGenerator(
    sam,
    points_per_side=128,
    stability_score_thresh=0.90,
)

In [5]:
def _read_rgb(path):
    """Read an image file and return it as an RGB array.

    Raises:
        ValueError: if OpenCV cannot decode the file (cv2.imread returns None
            instead of raising, which would otherwise fail later in cvtColor
            with an unhelpful message).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise ValueError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _read_json(path):
    """Load a JSON file, closing the handle deterministically."""
    with open(path) as f:
        return json.load(f)


# Images keyed by their filename in ../data/imgs.
imgs = {p: _read_rgb(os.path.join('../data/imgs', p)) for p in os.listdir("../data/imgs")}

# Labels keyed by the filename of the image they describe (last component of
# each label file's "image_path"), so they can be matched against `imgs`.
labels = {}
for p in os.listdir("../data/labels"):
    label_record = _read_json(os.path.join('../data/labels', p))
    labels[label_record['image_path'].split('/')[-1]] = label_record['labels']

In [6]:
# Pair images with labels by filename. Sorting each dict's keys independently
# and zipping the results would silently shift every later pair as soon as one
# filename exists in only one of the two directories.
paired_keys = sorted(imgs.keys() & labels.keys())
unpaired_keys = sorted(imgs.keys() ^ labels.keys())
if unpaired_keys:
    print(f"Warning: skipping {len(unpaired_keys)} file(s) without a matching image/label: {unpaired_keys[:5]}")
sorted_imgs = [imgs[key] for key in paired_keys]
sorted_labels = [labels[key] for key in paired_keys]

In [None]:
# Optimize every parameter of the underlying SAM 2 model.
# NOTE(review): 1e-2 is a very high Adam learning rate for fine-tuning a
# pretrained network (1e-5 to 1e-4 is more typical); consider lowering it.
LEARNING_RATE = 0.01
sam_trainable_params = all_masks_predictor.predictor.model.parameters()
optimizer = optim.Adam(sam_trainable_params, lr=LEARNING_RATE)

# nn.BCELoss requires inputs already in [0, 1]. This assumes "shadow_preds" are
# sigmoid-activated probabilities; if they are raw logits, use
# nn.BCEWithLogitsLoss instead. TODO confirm.
criterion = nn.BCELoss()

In [None]:
num_epochs = 5

# Module whose parameters the optimizer updates; checkpoints are taken from it.
sam_model = all_masks_predictor.predictor.model
# Build targets on the model's own device instead of hard-coding .cuda(), so
# predictions and targets always match (and CPU-only runs do not crash).
device = next(sam_model.parameters()).device

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_steps = 0
    correct_predictions = 0
    total_predictions = 0

    progress_bar = tqdm.tqdm(
        zip(sorted_imgs, sorted_labels),
        total=len(sorted_imgs),
        desc=f"Epoch [{epoch + 1}/{num_epochs}]"
    )

    for img, mask_labels in progress_bar:
        optimizer.zero_grad()

        # Forward pass: one "shadow_preds" entry per generated mask.
        # NOTE(review): assumes this fork's generate() keeps the autograd graph
        # (upstream SAM2 wraps it in @torch.no_grad()) and that the number of
        # generated masks equals len(mask_labels) — TODO confirm.
        outputs = all_masks_predictor.generate(img, False)
        preds = torch.stack([o["shadow_preds"] for o in outputs])
        targets = torch.tensor(mask_labels, dtype=torch.float32, device=device)

        loss = criterion(preds, targets)
        step_loss = loss.item()
        epoch_loss += step_loss
        num_steps += 1

        # Thresholded accuracy at 0.5; bookkeeping only, so no graph is needed.
        with torch.no_grad():
            predictions = (preds > 0.5).to(torch.float32)
            correct_predictions += (predictions == targets).sum().item()
        total_predictions += predictions.numel()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix({"Loss": step_loss, "Accuracy": correct_predictions / total_predictions})

    # Report the mean per-image loss so it is comparable with the per-step bar value.
    mean_loss = epoch_loss / max(num_steps, 1)
    accuracy = correct_predictions / max(total_predictions, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss}, Accuracy: {accuracy}")

    # Save the trained nn.Module's weights (not the predictor wrapper's) under a
    # "model" key, the layout build_sam2 reads when given a checkpoint path.
    torch.save({"model": sam_model.state_dict()}, f"sam2_model_epoch_{epoch + 1}.pth")

Epoch [1/5]:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch [1/5]:  50%|█████     | 10/20 [06:49<06:43, 40.40s/it, Loss=0.723, Accuracy=0.328]