In [1]:
import sys
sys.path.append("..")
from pathlib import Path
from tqdm import tqdm

import torch
from transformers import (SamModel, SamProcessor)
from mobile_sam import sam_model_registry
from utils.predictor import SamPredictor

from datasets import SA1B_Dataset
from utils import *

%load_ext autoreload
%autoreload 2

ImportError: cannot import name 'ResizeLongestSide' from 'utils' (unknown location)

In [None]:
# Config: where the datasets live and which CUDA device to use (CPU if CUDA is unavailable).
DATA_DIR = Path('../Datasets/')
GPU = 0

if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{GPU}")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Dataset: a single SA-1B shard, without precomputed teacher features, shuffled one image at a time.
dataset = SA1B_Dataset(
    split='sa_000009',
    features=None,
    root=DATA_DIR / 'SA_1B/images/',
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Number of samples in the dataset built above.
len(dataset)

In [None]:
# Peek at the first 12 batches (j = 0..11) and print their shapes.
# Presumably i/l/n are image, label map and file name — TODO confirm against SA1B_Dataset.
# The loop variables i, l, n stay bound to the last batch; later cells use them.
for j, batch in enumerate(dataloader):
    i, l, n = batch
    print(i.shape, l.shape, n)
    if j == 11:
        break

In [None]:
# Print how many distinct values the last label batch `l` contains (presumably instance ids),
# then display the first label map of the batch.
# NOTE(review): `plt` is not imported explicitly; it presumably comes from `from utils import *`.
print(len(l.unique()))
plt.imshow(l[0])
plt.show()

In [None]:
# Teacher: SAM ViT-H from the Hugging Face hub in eval mode, plus the processor that prepares its inputs.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
teacher = SamModel.from_pretrained("facebook/sam-vit-huge")
teacher = teacher.to(DEVICE)
teacher.eval()

In [None]:
# Student: MobileSAM "vit_t" built with checkpoint=None (presumably random initialisation),
# moved to DEVICE and put in training mode. `sam_checkpoint` is defined but not passed below.
model_type = "vit_t"
sam_checkpoint = "bin/mobile_sam.pt"

model = sam_model_registry[model_type](checkpoint=None)
model = model.to(DEVICE)
model.train()
student = SamPredictor(model)

In [None]:
# Teacher image embeddings for the last peeked batch `i`, computed without gradients and moved to CPU.
with torch.no_grad():
    inputs = processor(i, input_points=None, return_tensors="pt")
    inputs = inputs.to(DEVICE)
    t_features = teacher.get_image_embeddings(inputs["pixel_values"])
    t_features = t_features.cpu()

In [None]:
# mse
# NOTE(review): `l` is the label batch from the dataloader loop (displayed with imshow above),
# not student features, so this compares mismatched tensors. Presumably a scratch check;
# confirm the intended target.
mse = torch.nn.MSELoss()
mse(t_features, l)

In [None]:
# Inspect the label batch `l` from the last loop iteration.
l

In [None]:
# Inspect the teacher image embeddings computed above (already moved to CPU).
t_features

In [None]:
class Distiller():
    """Image-encoder distillation: regress student image features onto teacher embeddings.

    Args:
        teacher: teacher model exposing ``get_image_embeddings(pixel_values)``; run without gradients.
        student: predictor whose ``set_image`` populates ``.features`` for the given image.
        processor: callable turning raw images into teacher inputs (``pixel_values``).
        dataloader: iterable yielding ``(img, _, _)`` triples; only ``img`` is used.
        optimizer: optimizer over the student parameters being trained.
        device: device the teacher inputs are moved to.
    """

    def __init__(self, teacher, student, processor, dataloader, optimizer, device):
        self.teacher = teacher
        self.student = student
        self.processor = processor
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.device = device

    def get_distillation_loss(self, img):
        """Return the MSE between student features and teacher embeddings for ``img``.

        Only ``img[0]`` is given to the student, permuted with (2, 0, 1); presumably ``img`` is a
        (1, H, W, C) batch (TODO confirm). The whole batch goes through the teacher processor.
        """
        # Use the components this instance was built with, not notebook globals; otherwise the
        # class silently trains/compares whatever `student`/`teacher` happen to be in scope.
        self.student.set_image(img[0].permute((2, 0, 1)))
        s_features = self.student.features
        # NOTE(review): if set_image runs under torch.no_grad() (as in upstream segment_anything),
        # s_features carries no grad and loss.backward() will fail; confirm for utils.predictor.

        with torch.no_grad():
            inputs = self.processor(img, input_points=None, return_tensors="pt").to(self.device)
            t_features = self.teacher.get_image_embeddings(inputs["pixel_values"])

        return torch.nn.functional.mse_loss(s_features, t_features)

    def distill(self):
        """Run one pass over ``self.dataloader``, taking one optimizer step per batch."""
        progress = tqdm(self.dataloader, desc='Distillation:')
        for img, _, _ in progress:
            self.optimizer.zero_grad()
            loss = self.get_distillation_loss(img)
            loss.backward()
            self.optimizer.step()
            progress.set_postfix({'Loss': loss.item()})

In [None]:
# Optimise every parameter of the student model with Adam and wire up the distiller.
optimizer = torch.optim.Adam(params=student.model.parameters(), lr=1e-3)
distiller = Distiller(
    teacher=teacher,
    student=student,
    processor=processor,
    dataloader=dataloader,
    optimizer=optimizer,
    device=DEVICE,
)

In [None]:
# Run one pass of encoder feature distillation over the dataloader.
distiller.distill()

In [None]:
# Save the student model's weights (state_dict) after distillation.
torch.save(distiller.student.model.state_dict(), 'bin/distilled_mobile_sam.pt')

In [None]:
import csv

In [None]:
# Read the IDs of the saved teacher features, one CSV row per entry.
# `with` closes the file (the bare open() leaked the handle); newline='' is what the csv module expects.
# NOTE: this rebinds `l`, which earlier cells used for the label batch.
with open(Path('results/feature_ids.csv'), 'r', newline='') as feature_ids_file:
    teacher_features_ids = csv.reader(feature_ids_file)
    l = list(teacher_features_ids)
#l = [i[0] for i in l]
len(l)

In [None]:
# First five rows read from feature_ids.csv.
l[:5]

In [None]:
# Load the saved teacher features. torch.load unpickles the file, so only load trusted files.
# NOTE(review): tensors are restored to the device they were saved from unless map_location is given.
f = torch.load('results/teacher_features.pt')

In [None]:
# Number of entries in the loaded teacher-features object.
len(f)

## Decoder Distillation

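Plan for distilling the mask decoder:

- Prompt a random point belonging to an instance.
- Get the corresponding instance mask and its size.
- Use the saved SAM (teacher) image features.
- Freeze the MobileSAM image encoder (backbone).
- Prompt both SAM and MobileSAM and collect the output masks (**TODO:** use all three mask outputs?).
- Compute focal and dice loss in a 20:1 ratio (focal:dice).
- Weight the loss by mask size.
- **Important:** obtain logits, not binarized masks, from both MobileSAM and SAM (`return_logits=True`, `binarize=False`).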

In [None]:
# get_prompt()
# get_instance_label()
# get_mask_size()
# size coefficient = 1 - (mask_size / image_size)
# get_output(SAM, saved_features, prompt)
# get_output(MobileSAM, saved_features, prompt)
# dice_loss()
# focal_loss()
# focal:dice = 20:1, matching the "20 focal, 1 dice" note on LOSS_WEIGHTS in the training cell
# loss = (20 * focal_loss() + dice_loss()) * size_coefficient

In [2]:
from distill import *

In [None]:
# Distillation run: configuration first, then data, teacher, student and optimizer,
# and finally the selected MODE.
DATA_DIR = Path('../Datasets/')
SPLIT = 'sa_000020'
GPU = 2
DEVICE = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")

# The DataLoader below always yields one image. BATCH_SIZE is passed to distill() as `accumulate`,
# presumably the number of gradient-accumulation steps per optimizer update (confirm in distill.py).
BATCH_SIZE = 8
SHUFFLE = True
LOAD_FEATURES = True  # passed as use_saved_features; FEATURES below is the cached-features file
FEATURES = 'results/teacher_features.pt' if LOAD_FEATURES else None

EPOCHS = 16
LR = 1e-3
OPTIM = 'adamw'  # 'adamw' or 'adam'
WD = 1e-5  # weight decay, only used by AdamW
# Presumably ordered [focal, dice, bce, size]. The SAM-style 20:1 focal:dice setting would be
# [20, 1, 0, 0]; the current value uses only the third term. TODO confirm the order in distill.py.
LOSS_WEIGHTS = [0,0,1,0]

MODE = 'decoder'  # one of: 'encoder', 'decoder', 'save_features'
if MODE not in ('encoder', 'decoder', 'save_features'):
    raise ValueError(f"Unknown MODE {MODE!r}; expected 'encoder', 'decoder' or 'save_features'")
# Decoder distillation starts from bin/mobile_sam.pt; other modes build the student with checkpoint=None.
PRETRAINED = MODE == 'decoder'

# NOTE(review): with MODE='save_features' and LOAD_FEATURES=True, the dataset is pointed at the same
# file that save_teacher_features() writes; presumably LOAD_FEATURES should be False in that mode.
dataset = SA1B_Dataset(root=DATA_DIR.joinpath('SA_1B/images/'), split=SPLIT,  features=FEATURES, labels=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=SHUFFLE, num_workers=16, pin_memory=True)

teacher = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
teacher.eval()
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

model_type = "vit_t"
sam_checkpoint = "bin/mobile_sam.pt" if PRETRAINED else None

model = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(DEVICE)
model.eval()
# Freeze the image encoder's BatchNorm layers: keep them in eval mode and stop gradients to their
# affine parameters. Iterating m.parameters() also handles affine=False layers, whose weight/bias
# are None and would crash a direct m.weight.requires_grad_ call.
for m in model.image_encoder.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.eval()
        for p in m.parameters():
            p.requires_grad_(False)
student = SamPredictor(model)

# Only the component being distilled is handed to the optimizer ('save_features' also takes the
# decoder branch).
if MODE == 'encoder':
    DISTILLER = EncDistiller
    params = student.model.image_encoder.parameters()
else:
    DISTILLER = DecDistiller
    params = student.model.mask_decoder.parameters()

if OPTIM == 'adamw':
    optimizer = torch.optim.AdamW(params, lr=LR, weight_decay=WD)
elif OPTIM == 'adam':
    optimizer = torch.optim.Adam(params, lr=LR)
else:
    # Without this branch an unknown OPTIM would silently reuse the `optimizer` defined in an
    # earlier cell, which is bound to a different student model.
    raise ValueError(f"Unknown OPTIM {OPTIM!r}; expected 'adamw' or 'adam'")

# Multiply the learning rate by 0.9 at every scheduler step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

distiller = DISTILLER(teacher, student, processor, dataloader, optimizer, scheduler, loss_weights=LOSS_WEIGHTS, device=DEVICE)

if MODE == 'save_features':
    distiller.save_teacher_features(Path('results/teacher_features.pt'))
else:
    distiller.distill(epochs=EPOCHS, accumulate=BATCH_SIZE, use_saved_features=LOAD_FEATURES, name=MODE)

## Adapters

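Reference implementation of adapters inside the SAM image encoder (SAM-Adapter-PyTorch): [image_encoder.py](https://github.com/tianrun-chen/SAM-Adapter-PyTorch/blob/60bd2770b1c4fcb38d98822cae9a17781601ff9e/models/mmseg/models/sam/image_encoder.py#L263)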