# LoRA Baseline for MobileNet

In [1]:
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2

from pathlib import Path
import yaml
import torch
from transformers import SamModel, SamProcessor
from utils.mobile_sam import sam_model_registry
from utils.mobile_sam.predictor import SamPredictor
from utils.datasets import SA1B_Dataset
from utils.utils import *
from utils.distill_utils import *

from minlora import (add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, get_lora_state_dict,
                     merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora, LoRAParametrization)

In [2]:
with open('../config_distillation.yaml', 'r') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)
cfg['DATA_DIR'] = Path('../').joinpath(cfg['DATA_DIR'])
cfg['OUTPUT_DIR'] = Path('../').joinpath(cfg['OUTPUT_DIR'])
cfg['MODEL_DIR'] = Path('../').joinpath(cfg['MODEL_DIR'])
cfg['PROMPT_DIR'] = cfg['OUTPUT_DIR'].joinpath(f"prompts.pkl")
cfg['DEVICE'] = torch.device(f"cuda:{cfg['GPU']}" if torch.cuda.is_available() else "cpu")
cfg['PRETRAINED'] = True if cfg['MODE'] in ['decoder', 'prompt'] else False

# DATASET
dataset = SA1B_Dataset(root=cfg['DATA_DIR'], split=["sa_00000" + str(i) for i in range(cfg['TRAIN_SPLITS'])],
                        features=None, labels=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=cfg['SHUFFLE'], num_workers=cfg['WORKERS'], pin_memory=False)
test_dataset = SA1B_Dataset(root=cfg['DATA_DIR'], split=[cfg['SPLIT']], 
                            features=None, labels=True, max_samples=cfg['MAX_TEST'])
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=cfg['WORKERS'], pin_memory=False)

# MODEL
teacher = SamModel.from_pretrained("facebook/sam-vit-huge").to(cfg['DEVICE'])
teacher.eval()
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_checkpoint = cfg['MODEL_DIR'].joinpath(cfg['CKPT']) if cfg['PRETRAINED'] else None
model = sam_model_registry["vit_t"](checkpoint=sam_checkpoint, add_prompt=cfg['ADD_PROMPT']) #.to(cfg['DEVICE'])
model.eval()
for m in model.image_encoder.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.eval()
        m.weight.requires_grad_(False)
        m.bias.requires_grad_(False)
student = SamPredictor(model)

In [3]:
# Step 1: Add LoRA to the model
add_lora(student.model.mask_decoder)

# Step 2: Collect the parameters, pass them to the optimizer
parameters = [{"params": list(get_lora_params(student.model.mask_decoder))}]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: Train the model
# ...

# Step 4: export the LoRA parameters
lora_state_dict = get_lora_state_dict(model)
print(lora_state_dict.keys())

dict_keys(['mask_decoder.transformer.layers.0.self_attn.q_proj.parametrizations.weight.0.lora_A', 'mask_decoder.transformer.layers.0.self_attn.q_proj.parametrizations.weight.0.lora_B', 'mask_decoder.transformer.layers.0.self_attn.k_proj.parametrizations.weight.0.lora_A', 'mask_decoder.transformer.layers.0.self_attn.k_proj.parametrizations.weight.0.lora_B', 'mask_decoder.transformer.layers.0.self_attn.v_proj.parametrizations.weight.0.lora_A', 'mask_decoder.transformer.layers.0.self_attn.v_proj.parametrizations.weight.0.lora_B', 'mask_decoder.transformer.layers.0.self_attn.out_proj.parametrizations.weight.0.lora_A', 'mask_decoder.transformer.layers.0.self_attn.out_proj.parametrizations.weight.0.lora_B', 'mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.parametrizations.weight.0.lora_A', 'mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.parametrizations.weight.0.lora_B', 'mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.parametrization

In [4]:
sum = 0
for i in lora_state_dict.keys():
    if 'decoder' in i:
        # print(i, lora_state_dict[i].shape[0] * lora_state_dict[i].shape[1])
        sum += lora_state_dict[i].shape[0] * lora_state_dict[i].shape[1]
print(len(lora_state_dict), sum)

94 110096


In [5]:
for n, m in student.model.mask_decoder.named_modules():
    print(n)


transformer
transformer.layers
transformer.layers.0
transformer.layers.0.self_attn
transformer.layers.0.self_attn.q_proj
transformer.layers.0.self_attn.q_proj.parametrizations
transformer.layers.0.self_attn.q_proj.parametrizations.weight
transformer.layers.0.self_attn.q_proj.parametrizations.weight.0
transformer.layers.0.self_attn.k_proj
transformer.layers.0.self_attn.k_proj.parametrizations
transformer.layers.0.self_attn.k_proj.parametrizations.weight
transformer.layers.0.self_attn.k_proj.parametrizations.weight.0
transformer.layers.0.self_attn.v_proj
transformer.layers.0.self_attn.v_proj.parametrizations
transformer.layers.0.self_attn.v_proj.parametrizations.weight
transformer.layers.0.self_attn.v_proj.parametrizations.weight.0
transformer.layers.0.self_attn.out_proj
transformer.layers.0.self_attn.out_proj.parametrizations
transformer.layers.0.self_attn.out_proj.parametrizations.weight
transformer.layers.0.self_attn.out_proj.parametrizations.weight.0
transformer.layers.0.norm1
trans

In [6]:
for n, m in student.model.mask_decoder.named_modules():
    if 'parametrization' in n:
        m.train()
        #print(n)
    else:
        m.eval()
        print(n)


transformer
transformer.layers
transformer.layers.0
transformer.layers.0.self_attn
transformer.layers.0.self_attn.q_proj
transformer.layers.0.self_attn.k_proj
transformer.layers.0.self_attn.v_proj
transformer.layers.0.self_attn.out_proj
transformer.layers.0.norm1
transformer.layers.0.cross_attn_token_to_image
transformer.layers.0.cross_attn_token_to_image.q_proj
transformer.layers.0.cross_attn_token_to_image.k_proj
transformer.layers.0.cross_attn_token_to_image.v_proj
transformer.layers.0.cross_attn_token_to_image.out_proj
transformer.layers.0.norm2
transformer.layers.0.mlp
transformer.layers.0.mlp.lin1
transformer.layers.0.mlp.lin2
transformer.layers.0.mlp.act
transformer.layers.0.norm3
transformer.layers.0.norm4
transformer.layers.0.cross_attn_image_to_token
transformer.layers.0.cross_attn_image_to_token.q_proj
transformer.layers.0.cross_attn_image_to_token.k_proj
transformer.layers.0.cross_attn_image_to_token.v_proj
transformer.layers.0.cross_attn_image_to_token.out_proj
transforme

In [46]:
o = torch.load('../test.pt')

In [45]:
l = torch.load('merged.pt')

In [52]:
sum = 0
for i in o.keys():
    test = bool((o[i].cpu() == l[i].cpu()).all())
    if not test:
        print(i, test)
        sum += o[i].shape[0] * o[i].shape[1]
print(sum)

mask_decoder.transformer.layers.0.self_attn.q_proj.weight False
mask_decoder.transformer.layers.0.self_attn.k_proj.weight False
mask_decoder.transformer.layers.0.self_attn.v_proj.weight False
mask_decoder.transformer.layers.0.self_attn.out_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight False
mask_decoder.transformer.layers.0.mlp.lin1.weight False
mask_decoder.transformer.layers.0.mlp.lin2.weight False
mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight False
mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj

In [37]:
for i in l.keys():
    print(i)

image_encoder.patch_embed.seq.0.c.weight
image_encoder.patch_embed.seq.0.bn.weight
image_encoder.patch_embed.seq.0.bn.bias
image_encoder.patch_embed.seq.0.bn.running_mean
image_encoder.patch_embed.seq.0.bn.running_var
image_encoder.patch_embed.seq.0.bn.num_batches_tracked
image_encoder.patch_embed.seq.2.c.weight
image_encoder.patch_embed.seq.2.bn.weight
image_encoder.patch_embed.seq.2.bn.bias
image_encoder.patch_embed.seq.2.bn.running_mean
image_encoder.patch_embed.seq.2.bn.running_var
image_encoder.patch_embed.seq.2.bn.num_batches_tracked
image_encoder.layers.0.blocks.0.conv1.c.weight
image_encoder.layers.0.blocks.0.conv1.bn.weight
image_encoder.layers.0.blocks.0.conv1.bn.bias
image_encoder.layers.0.blocks.0.conv1.bn.running_mean
image_encoder.layers.0.blocks.0.conv1.bn.running_var
image_encoder.layers.0.blocks.0.conv1.bn.num_batches_tracked
image_encoder.layers.0.blocks.0.conv2.c.weight
image_encoder.layers.0.blocks.0.conv2.bn.weight
image_encoder.layers.0.blocks.0.conv2.bn.bias
imag

In [50]:
l['prompt_encoder.point_embeddings.4.weight'].shape

torch.Size([1, 256])

In [39]:
student.model.load_state_dict(l)

<All keys matched successfully>

In [41]:
merge_lora(student.model)

In [43]:
torch.save(student.model.state_dict(), 'merged.pt')

In [51]:
(3.82 - 3.75) / 3.82

0.018324607329842892