## RaFFM for Segment Anything (SAM)

In [1]:
from transformers import SamModel, SamProcessor, SamVisionConfig
from raffm import RaFFM


  from .autonotebook import tqdm as notebook_tqdm


## Load SAM with `SamModel` from Hugging Face Transformers

In [2]:
# Load the pretrained `facebook/sam-vit-huge` checkpoint from the Hugging Face Hub.
model = SamModel.from_pretrained(pretrained_model_name_or_path="facebook/sam-vit-huge")


## Convert the original SAM into a scalable RaFFM model

In [3]:
## Subnetwork space
# Candidate widths RaFFM may choose from for each elastic dimension when it
# samples a subnetwork (see `random_resource_aware_model` below).
# NOTE(review): three widths are listed for "atten_out_space", so the
# attention output dimension IS elastic as written -- a previous comment
# claimed it was kept fixed. To actually pin it, give a single value
# (presumably the backbone's full width; 1280 is the largest listed) --
# TODO confirm which behavior is intended.
elastic_config = {
    "atten_out_space": [512, 768, 1280],
    "inter_hidden_space": [512, 1024, 1280, 2048],
    "out_hidden_space": [1024, 1280, 2048],
}
# Wrap the CPU-resident SAM in RaFFM so scaled subnetworks can be sampled from it.
raffm_model = RaFFM(model.to("cpu"), elastic_config=elastic_config)
# `total_params` appears to be reported in millions (recorded output: 635.18) -- TODO confirm.
print("Original FM number of parameters:", raffm_model.total_params)



Original FM number of parameters: 635.177828


### Randomly sample a subnetwork from the original model

In [4]:
# Randomly sample a scaled FM (subnetwork) from the elastic search space.
# Seed the RNGs first so the sampled architecture -- and hence the parameter
# count printed below -- is reproducible under Restart & Run All.
# `transformers.set_seed` seeds Python's `random`, NumPy and PyTorch.
# NOTE(review): assumes RaFFM draws its configuration from one of these
# RNGs -- confirm against raffm's sampler.
# Only this first sampling cell is seeded, so later samples still differ
# from this one while remaining reproducible as a sequence.
from transformers import set_seed

SEED = 42
set_seed(SEED)

submodel, params, config = raffm_model.random_resource_aware_model()
print("subnetwork params", params)

subnetwork params 235.408228


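### Sample another subnetwork

Each call draws a fresh random configuration, so the parameter count can differ from the previous sample.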

In [5]:
# Draw another random scaled FM: every call samples a new configuration from
# the elastic space, so this subnetwork may differ from the previous one.
sampled = raffm_model.random_resource_aware_model()
submodel, params, config = sampled
print("subnetwork params", params)

subnetwork params 239.995748


## Test inference with the subnetwork

In [6]:
import torch
from PIL import Image
import requests


# Use the processor from the same checkpoint as the model.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
# Time out instead of hanging forever on a stalled connection, surface HTTP
# errors instead of handing an error page to PIL, and close the streamed
# response once the image has been decoded.
with requests.get(img_url, stream=True, timeout=30) as response:  # timeout in seconds
    response.raise_for_status()
    # `.raw` bypasses requests' transparent decompression; have urllib3 undo
    # any Content-Encoding (e.g. gzip) so PIL receives the actual image bytes.
    response.raw.decode_content = True
    # `.convert()` makes PIL load the whole image while the stream is still open.
    raw_image = Image.open(response.raw).convert("RGB")
input_points = [[[450, 600]]]  # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, return_tensors="pt")


In [7]:
# Forward pass only -- no gradients are needed to inspect the predictions.
with torch.no_grad():
    outputs = submodel(**inputs)

# Map the predicted masks back to the input image using the original and
# reshaped sizes recorded by the processor.
masks = processor.image_processor.post_process_masks(
    masks=outputs.pred_masks.cpu(),
    original_sizes=inputs["original_sizes"].cpu(),
    reshaped_input_sizes=inputs["reshaped_input_sizes"].cpu(),
)
# IoU scores predicted by the model for its masks.
scores = outputs.iou_scores

In [8]:
# Display the predicted IoU scores -- presumably one per candidate mask; the
# recorded output shows three values for the single prompt point.
scores

tensor([[[0.4596, 0.5067, 0.5046]]])