In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [22]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import os
from PIL import Image
import numpy as np 
import cv2
import json
import sys
import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
# NOTE(review): absolute, machine-specific path appended to the module search
# path. It runs after the imports above, so it can only affect imports that are
# executed later. Consider making it configurable (e.g. relative to the repo).
sys.path.append('/root/jango_ws/src/grasp/src/')

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [4]:
# Checkpoint and model-config paths used to build the SAM2 model for this run.
test_config = dict(
    sam_checkpoint="../checkpoints/sam2.1_hiera_large.pt",
    # sam_checkpoint="sam2_model_epoch_1.pth",
    sam_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
)

# Build the model from (config, checkpoint), then wrap it in the automatic
# mask generator with a 128-points-per-side prompt grid and a 0.90 stability
# score threshold.
model_cfg_path = test_config["sam_model_cfg"]
checkpoint_path = test_config["sam_checkpoint"]
sam = build_sam2(model_cfg_path, checkpoint_path)

all_masks_predictor = SAM2AutomaticMaskGenerator(
    sam,
    points_per_side=128,
    stability_score_thresh=0.90,
)

In [5]:
def _list_files(directory):
    """Return the sorted names of the regular files directly inside `directory`.

    Sub-directories (for example a `.ipynb_checkpoints` folder) are skipped, and
    sorting makes the resulting dict insertion order deterministic across runs.
    """
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


# imgs: file name -> RGB image array (cv2 reads BGR, so convert).
imgs = {}
for file_name in _list_files('../data/imgs'):
    img_path = os.path.join('../data/imgs', file_name)
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals an unreadable / non-image file by returning None;
        # fail here with the offending path instead of inside cvtColor.
        raise ValueError(f"Could not read image file: {img_path}")
    imgs[file_name] = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

# labels: image file name (last component of the record's 'image_path')
# -> that record's 'labels' entry.
labels = {}
for file_name in _list_files('../data/labels'):
    # Use a context manager so each JSON file handle is closed promptly.
    with open(os.path.join('../data/labels', file_name)) as label_file:
        record = json.load(label_file)
    labels[record['image_path'].split('/')[-1]] = record['labels']

In [6]:
# Pair images with labels by file name. Sorting the two dicts independently
# would silently shift every pair if one side had an entry the other lacks,
# so only keys present in BOTH dicts are used, in sorted order.
paired_keys = sorted(imgs.keys() & labels.keys())
unmatched_keys = sorted(imgs.keys() ^ labels.keys())
if unmatched_keys:
    print(f"Skipping {len(unmatched_keys)} file(s) without an image/label pair: {unmatched_keys}")

# sorted_imgs[i] and sorted_labels[i] always refer to the same file.
sorted_imgs = [imgs[key] for key in paired_keys]
sorted_labels = [labels[key] for key in paired_keys]

In [7]:
# Adam (lr=0.01) over every parameter of the predictor's underlying model,
# plus a binary cross-entropy criterion.
sam_model = all_masks_predictor.predictor.model
optimizer = optim.Adam(params=sam_model.parameters(), lr=0.01)
criterion = nn.BCELoss()

In [None]:
# Build a dense 128 x 128 grid of (x, y) prompt points over the first image.
img = list(imgs.values())[0]
shape = img.shape
height, width = shape[0], shape[1]

# linspace includes both endpoints, so the grid spans 0..width and 0..height.
xs = np.linspace(0, width, 128)
ys = np.linspace(0, height, 128)
grid_x, grid_y = np.meshgrid(xs, ys)

# (128, 128, 2) -> (16384, 2) -> float32 CUDA tensor of shape (16384, 1, 2).
points = np.stack((grid_x, grid_y), axis=-1).reshape(-1, 2)
points = torch.from_numpy(points.astype(np.float32)).cuda().unsqueeze(1)
len(points)

16384

In [12]:
# Smoke test: run the predictor on the first batch of grid points, giving every
# point the label 1, and display the shape of the last element it returns.
num_points_per_batch = 128
all_masks_predictor.predictor.set_image(img)

batch_points = points[:num_points_per_batch]
batch_point_labels = torch.ones((num_points_per_batch, 1), dtype=torch.int64)
prediction = all_masks_predictor.predictor.predict(batch_points, batch_point_labels)
prediction[-1].shape

(128, 256)

In [13]:
# Training length, and the device to train on (CUDA when available, else CPU).
num_epochs = 5
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [14]:
# Run the automatic mask generator once per image and cache, for each image,
# the stacked 'shadow_input_tokens' of all its generated masks on `device`.
# NOTE(review): the meaning of the second positional argument (False) to
# generate() is not visible in this notebook - confirm against the generator.
img_raw_states = []
image_label_pairs = zip(sorted_imgs, sorted_labels)
progress_bar = tqdm.tqdm(
    image_label_pairs,
    total=len(sorted_imgs),
    desc="Loading mask tokens",
)

for img, mask_labels in progress_bar:
    outputs = all_masks_predictor.generate(img, False)
    tokens_per_mask = [o['shadow_input_tokens'] for o in outputs]
    stacked_tokens = torch.stack(tokens_per_mask)
    img_raw_states.append(stacked_tokens.to(device))

Loading mask tokens:   0%|          | 0/20 [00:00<?, ?it/s]

Loading mask tokens: 100%|██████████| 20/20 [11:44<00:00, 35.25s/it]


In [None]:
# Train the shadow prediction head on the cached per-mask tokens.
#
# For every image, `img_raw_states[i]` holds the stacked mask tokens and
# `sorted_labels[i]` the per-mask labels loaded from JSON.
# NOTE(review): assumes labels are binary (0/1) and that the head outputs
# probabilities in [0, 1] (required by binary cross-entropy) - confirm.

# Labels come from JSON as plain (possibly ragged) lists, so convert each
# image's labels to its own flat integer array before any numpy arithmetic.
labels_per_image = [np.asarray(image_labels).astype(int).ravel() for image_labels in sorted_labels]
flat_labels = np.concatenate(labels_per_image) if labels_per_image else np.zeros(0, dtype=int)

# "Balanced" class weights: n_samples / (n_classes * class_count), computed over
# all individual mask labels (not over the number of images).
unique_labels, label_counts = np.unique(flat_labels, return_counts=True)
total_samples = flat_labels.size
# Lookup table indexed by label value; float32 to match the predictions.
weight_lookup = np.zeros(max(2, int(flat_labels.max(initial=0)) + 1), dtype=np.float32)
weight_lookup[unique_labels] = total_samples / (len(unique_labels) * label_counts)
class_weights = torch.from_numpy(weight_lookup).to(device)

shadow_head = all_masks_predictor.predictor.model.sam_mask_decoder.shadow_prediction_head

for epoch in range(num_epochs):
    epoch_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    all_predictions = []
    all_labels = []

    progress_bar = tqdm.tqdm(
        zip(img_raw_states, labels_per_image),
        total=len(img_raw_states),
        desc=f"Epoch [{epoch + 1}/{num_epochs}]"
    )

    for mask_states, mask_labels in progress_bar:
        # Forward pass
        optimizer.zero_grad()
        preds = shadow_head(mask_states)
        scores = preds[:, 0]

        if scores.shape[0] != mask_labels.shape[0]:
            raise ValueError(
                f"Got {scores.shape[0]} mask predictions but {mask_labels.shape[0]} labels for one image"
            )

        label_idx = torch.from_numpy(mask_labels).long().to(device)
        targets = label_idx.to(torch.float32)

        # Compute loss. nn.BCELoss instances do not accept a per-call `weight`,
        # so use the functional form, which takes per-element weights.
        loss = nn.functional.binary_cross_entropy(scores, targets, weight=class_weights[label_idx])
        epoch_loss += loss.item()

        # Calculate accuracy
        predictions = (scores > 0.5).to(torch.float32)
        correct_predictions += (predictions == targets).sum().item()
        total_predictions += predictions.numel()
        # Store plain Python ints on the host so numpy / sklearn can consume them.
        all_predictions.extend(predictions.detach().cpu().numpy().astype(int).tolist())
        all_labels.extend(mask_labels.tolist())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update tqdm bar
        progress_bar.set_postfix({"Loss": loss.item(), "Accuracy": correct_predictions / total_predictions})

    # Guard against an empty dataset so the summary line cannot divide by zero.
    final_accuracy = correct_predictions / max(total_predictions, 1)
    final_precision = precision_score(all_labels, all_predictions, zero_division=0)
    final_recall = recall_score(all_labels, all_predictions, zero_division=0)
    final_f1 = f1_score(all_labels, all_predictions, zero_division=0)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss}, Accuracy: {final_accuracy}, Precision: {final_precision}, Recall: {final_recall}, F1: {final_f1}")
    # Checkpoint the full model state after every epoch.
    torch.save(all_masks_predictor.predictor.model.state_dict(), f"sam2_model_epoch_{epoch + 1}.pth")

Epoch [1/5]:   0%|          | 0/20 [00:00<?, ?it/s, Loss=0.37, Accuracy=0.877] 

Epoch [1/5]: 100%|██████████| 20/20 [00:00<00:00, 493.19it/s, Loss=0.281, Accuracy=0.897]


Epoch [1/5], Loss: 5.153079580515623, Accuracy: 0.897


Epoch [2/5]: 100%|██████████| 20/20 [00:00<00:00, 449.23it/s, Loss=0.22, Accuracy=0.936]


Epoch [2/5], Loss: 2.8212541369721293, Accuracy: 0.936


Epoch [3/5]: 100%|██████████| 20/20 [00:00<00:00, 510.17it/s, Loss=0.155, Accuracy=0.945]


Epoch [3/5], Loss: 2.436914478428662, Accuracy: 0.945


Epoch [4/5]: 100%|██████████| 20/20 [00:00<00:00, 460.96it/s, Loss=0.119, Accuracy=0.955]


Epoch [4/5], Loss: 1.8468625340610743, Accuracy: 0.955


Epoch [5/5]: 100%|██████████| 20/20 [00:00<00:00, 489.69it/s, Loss=0.1, Accuracy=0.968]


Epoch [5/5], Loss: 1.6575458988081664, Accuracy: 0.968
