## RaFFM for Segment Anything (SAM)

In [1]:
from transformers import SamModel, SamProcessor, SamVisionConfig
from raffm import RaFFM


  from .autonotebook import tqdm as notebook_tqdm


## Load SAM via SamModel use transformers models

In [2]:
model = SamModel.from_pretrained("facebook/sam-vit-huge")


## Convert original SAM to scalable RaFFM model

In [3]:
## Subnetwork space
elastic_config = {
    "atten_out_space": [512, 768, 1280], # keep attention layer fixed
    "inter_hidden_space": [512, 1024, 1280, 2048],
    "residual_hidden_space": [1024, 1280, 2048],
}
raffm_model = RaFFM(model.to("cpu"),elastic_config=elastic_config)
print("Original FM number of parameters:",raffm_model.total_params)



Original FM number of parameters: 635.177828


### Randomly sample a subnetwork from original model

In [13]:
#Random sample a scaled FM
submodel, params, config = raffm_model.random_resource_aware_model()
print("subnetwork params",params)

subnetwork params 256.379748


In [24]:
import torch

torch.save(submodel.state_dict(),"./ckpt/ckpts.pt")
torch.save(config,"./ckpt/arc_config.pt")

AttributeError: 'SamModel' object has no attribute 'save_state_dict'

In [18]:
sub_config = torch.load("./ckpt/arc_config.pt")
loaded_net,_ = raffm_model.resource_aware_model(sub_config)

In [21]:
submodel

SamModel(
  (shared_image_embedding): SamPositionalEmbedding()
  (vision_encoder): SamVisionEncoder(
    (patch_embed): SamPatchEmbeddings(
      (projection): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (layers): ModuleList(
      (0): SamVisionLayer(
        (layer_norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): SamVisionAttention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (layer_norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): SamMLPBlock(
          (lin1): Linear(in_features=1280, out_features=2048, bias=True)
          (lin2): Linear(in_features=2048, out_features=1280, bias=True)
          (act): GELUActivation()
        )
      )
      (1-2): 2 x SamVisionLayer(
        (layer_norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): SamVisionAttention(
      

In [23]:
loaded_net.load_state_dict(torch.load("./ckpt/pytorch_model.bin"))

TypeError: Expected state_dict to be dict-like, got <class 'str'>.

### Sample a new subnetwork

In [5]:
#Random sample a scaled FM
submodel, params, config = raffm_model.random_resource_aware_model()
print("subnetwork params",params)

subnetwork params 224.267108


## Test inference the subnetwork

In [6]:
import torch
from PIL import Image
import requests


processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, return_tensors="pt")


In [7]:
with torch.no_grad():
    outputs = submodel(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores

In [8]:
scores

tensor([[[0.4543, 0.5698, 0.5852]]])