Application for fine-tuned SAM

In [9]:
#backend
import torch
from transformers import SamModel, SamProcessor
import torch.nn.functional as F

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base SAM processor and weights, then overwrite the weights with the
# fine-tuned state dict stored next to `checkpoint_path` (".pt" is appended below).
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
checkpoint_path = "/vol/data/models/custom5e-05 lr,1e-04 wd,2 bs, diceCE loss, grayscale, 24-02-23_17.35.30_24-02-23_17.35.30"
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path + ".pt", map_location=device))

<All keys matched successfully>

In [10]:
def inference(img, pixel, prompt_type):
    """Predict a binary segmentation mask for `img` from one point or box prompt.

    Uses the module-level `processor`, `model` and `device`.

    Args:
        img: Image handed unchanged to `processor`.
        pixel: Prompt coordinates. For "points" this is one `[x, y]` pair; for a
            box prompt it is the two corners `[x1, y1, x2, y2]`, in any order.
        prompt_type: "points" selects a point prompt; every other value is
            treated as a box prompt (callers in this notebook pass "bboxes").

    Returns:
        numpy uint8 array with 1 where sigmoid(mask logit) > 0.5 and 0 elsewhere,
        resized to the original image size reported by the processor.
    """
    model.eval()
    with torch.no_grad():
        if prompt_type == "points":
            inputs = processor(img, input_points=[[pixel]], return_tensors="pt").to(device)
        else:
            # Order the corners as (x_min, y_min, x_max, y_max) so the prompt is a
            # valid box no matter which corner was supplied first.
            x_a, y_a, x_b, y_b = pixel
            box = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]
            inputs = processor(img, input_boxes=[[box]], return_tensors="pt").to(device)
        outputs = model(**inputs, multimask_output=False)

        # Plain Python ints for slicing / target sizes instead of 0-d device tensors.
        resized_h, resized_w = inputs["reshaped_input_sizes"][0].tolist()
        original_h, original_w = inputs["original_sizes"][0].tolist()

        # Upsample the low-resolution logits to the 1024x1024 model input, crop
        # to the resized image region, then resize to the original image size.
        masks = F.interpolate(outputs.pred_masks.squeeze(2), (1024, 1024), mode="bilinear", align_corners=False)
        masks = masks[..., :resized_h, :resized_w]
        masks = F.interpolate(masks, (original_h, original_w), mode="bilinear", align_corners=False)

        # Threshold in torch so this function does not rely on `np` having been
        # imported by a later notebook cell.
        binary_masks = (torch.sigmoid(masks) > 0.5).to(torch.uint8).cpu().squeeze().numpy()
    return binary_masks

In [11]:
#application points
import gradio as gr
import numpy as np

with gr.Blocks() as demo:
    with gr.Row():
        # Left: clickable input image. Right: the same image with the mask overlaid.
        input_img = gr.Image(label="Input")
        img_output = gr.AnnotatedImage(color_map={"red": "#ff0000"})

    def get_select_coords(img, evt: gr.SelectData):
        """Segment `img` using the clicked position as a single point prompt.

        Returns the (image, annotations) pair expected by `gr.AnnotatedImage`,
        with the predicted mask labelled "red".
        """
        clicked = evt.index
        annotations = [(inference(img, clicked, "points"), "red")]
        return (img, annotations)

    # Every click on the input image triggers a new prediction.
    input_img.select(fn=get_select_coords, inputs=input_img, outputs=img_output)

if __name__ == "__main__":
    demo.launch(share=True)


Running on local URL:  http://127.0.0.1:7864
Running on public URL: https://79716390a443ca3b95.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


In [12]:
#application bboxes
import gradio as gr
import numpy as np

# First corner of the box currently being drawn, or None when no box is in progress.
# NOTE(review): this is a module-level global, so it is shared by every session
# connected to the demo; per-session state would need gr.State.
previous_point = None
with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        img_output = gr.AnnotatedImage(
            color_map={"mask": "#ff0000", "box": "#00ff00", "first corner": "#00ff00"}
        )

    def get_select_coords(img, evt: gr.SelectData):
        """Collect two clicks as box corners and segment `img` with that box.

        The first click stores the corner and returns a small marker around it.
        The second click builds the box, runs `inference` with a box prompt and
        returns the box together with the predicted mask. Both return values are
        (image, annotations) pairs for `gr.AnnotatedImage`.
        """
        global previous_point
        if previous_point is not None:
            x_a, y_a = previous_point
            x_b, y_b = evt.index
            # Reset before running the model so a failing inference does not
            # leave a stale first corner behind for the next click.
            previous_point = None
            # Order the corners as (x_min, y_min, x_max, y_max); the user may
            # click them in any order and both the box prompt and the box
            # overlay need ordered coordinates.
            pixel = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]
            mask = inference(img, pixel, "bboxes")
            return (img, [(pixel, "box"), (mask, "mask")])
        else:
            previous_point = list(evt.index)
            x, y = previous_point
            mask = np.zeros(img.shape[:2])
            # 3x3 marker around the click; clamp the start at 0 because a
            # negative slice start would wrap around and select nothing.
            mask[max(y - 1, 0):y + 2, max(x - 1, 0):x + 2] = 1
            return (img, [(mask, "first corner")])

    input_img.select(get_select_coords, input_img, img_output)

if __name__ == "__main__":
    demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7865
Running on public URL: https://11396e4f3f1e6be80e.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/vol/data/miniconda3/envs/DILab3.10/lib/python3.10/site-packages/gradio/queueing.py", line 495, in call_prediction
    output = await route_utils.call_process_api(
  File "/vol/data/miniconda3/envs/DILab3.10/lib/python3.10/site-packages/gradio/route_utils.py", line 230, in call_process_api
    output = await app.get_blocks().process_api(
  File "/vol/data/miniconda3/envs/DILab3.10/lib/python3.10/site-packages/gradio/blocks.py", line 1590, in process_api
    result = await self.call_function(
  File "/vol/data/miniconda3/envs/DILab3.10/lib/python3.10/site-packages/gradio/blocks.py", line 1176, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/anyio/to_thread.py", line 49, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2103, in run_sync_in_worker_thread
 

Comparison of SAM and fine-tuned SAM

In [6]:
#backend
import torch
from transformers import SamModel, SamProcessor
import torch.nn.functional as F

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# `model`: base SAM weights overwritten with the fine-tuned state dict stored
# next to `checkpoint_path` (".pt" is appended below).
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
checkpoint_path = "/vol/data/models/custom5e-05 lr,1e-04 wd,2 bs, diceCE loss, grayscale, 24-02-23_17.35.30_24-02-23_17.35.30"
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU run can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path + ".pt", map_location=device))

# `sam_model`: the unmodified pretrained weights, kept as the comparison baseline.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)

In [7]:
def inference(img, pixel, prompt_type, fine_tuned=True):
    """Predict a binary segmentation mask for `img` from one point or box prompt.

    Uses the module-level `processor`, `model`, `sam_model` and `device`.

    Args:
        img: Image handed unchanged to `processor`.
        pixel: Prompt coordinates. For "points" this is one `[x, y]` pair; for a
            box prompt it is the two corners `[x1, y1, x2, y2]`, in any order.
        prompt_type: "points" selects a point prompt; every other value is
            treated as a box prompt (callers in this notebook pass "bboxes").
        fine_tuned: When true, run `model` (fine-tuned checkpoint); otherwise run
            `sam_model` (pretrained weights). Defaults to True so three-argument
            calls keep working.

    Returns:
        numpy uint8 array with 1 where sigmoid(mask logit) > 0.5 and 0 elsewhere,
        resized to the original image size reported by the processor.
    """
    # Put the network that is actually used into eval mode, not always `model`.
    net = model if fine_tuned else sam_model
    net.eval()
    with torch.no_grad():
        if prompt_type == "points":
            inputs = processor(img, input_points=[[pixel]], return_tensors="pt").to(device)
        else:
            # Order the corners as (x_min, y_min, x_max, y_max) so the prompt is a
            # valid box no matter which corner was supplied first.
            x_a, y_a, x_b, y_b = pixel
            box = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]
            inputs = processor(img, input_boxes=[[box]], return_tensors="pt").to(device)
        outputs = net(**inputs, multimask_output=False)

        # Plain Python ints for slicing / target sizes instead of 0-d device tensors.
        resized_h, resized_w = inputs["reshaped_input_sizes"][0].tolist()
        original_h, original_w = inputs["original_sizes"][0].tolist()

        # Upsample the low-resolution logits to the 1024x1024 model input, crop
        # to the resized image region, then resize to the original image size.
        masks = F.interpolate(outputs.pred_masks.squeeze(2), (1024, 1024), mode="bilinear", align_corners=False)
        masks = masks[..., :resized_h, :resized_w]
        masks = F.interpolate(masks, (original_h, original_w), mode="bilinear", align_corners=False)

        # Threshold in torch so this function does not rely on `np` having been
        # imported by a later notebook cell.
        binary_masks = (torch.sigmoid(masks) > 0.5).to(torch.uint8).cpu().squeeze().numpy()
    return binary_masks

In [8]:
#application bboxes
import gradio as gr
import numpy as np

# First corner of the box currently being drawn, or None when no box is in progress.
# NOTE(review): this is a module-level global, so it is shared by every session
# connected to the demo; per-session state would need gr.State.
previous_point = None
with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        img_output = gr.AnnotatedImage(
            label="fine-tuned SAM",
            color_map={"mask": "#ff0000", "box": "#00ff00", "first corner": "#00ff00"}
        )
        img_sam = gr.AnnotatedImage(
            label ="default SAM",
            color_map={"mask": "#ff0000", "box": "#00ff00", "first corner": "#00ff00"}
        )

    def get_select_coords(img, evt: gr.SelectData):
        """Collect two clicks as box corners and segment `img` with both models.

        The first click stores the corner and returns a small marker around it
        for both outputs. The second click builds the box and runs `inference`
        once with the fine-tuned model and once with the default model. The
        return value is a pair of (image, annotations) tuples: fine-tuned first,
        default SAM second.
        """
        global previous_point
        if previous_point is not None:
            x_a, y_a = previous_point
            x_b, y_b = evt.index
            # Reset before running the models so a failing inference does not
            # leave a stale first corner behind for the next click.
            previous_point = None
            # Order the corners as (x_min, y_min, x_max, y_max); the user may
            # click them in any order and both the box prompt and the box
            # overlay need ordered coordinates.
            pixel = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]
            mask = inference(img, pixel, "bboxes", True)
            sam_mask = inference(img, pixel, "bboxes", False)
            return ((img, [(pixel, "box"), (mask, "mask")]), (img, [(pixel, "box"), (sam_mask, "mask")]))
        else:
            previous_point = list(evt.index)
            x, y = previous_point
            mask = np.zeros(img.shape[:2])
            # 3x3 marker around the click; clamp the start at 0 because a
            # negative slice start would wrap around and select nothing.
            mask[max(y - 1, 0):y + 2, max(x - 1, 0):x + 2] = 1
            return ((img, [(mask, "first corner")]), (img, [(mask, "first corner")]))

    input_img.select(get_select_coords, input_img, [img_output, img_sam])

if __name__ == "__main__":
    demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7863
Running on public URL: https://e58955382a71b0ee72.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
