[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)

# Fine-Tune Segment Anything 2.1 (SAM-2.1)

---

[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/facebookresearch/sam2)

The Segment Anything Model (SAM) is a computer vision model from Meta AI that can "cut out" any object in any image with a single click.

In September 2024, Meta Research released SAM-2.1, the latest model in the Segment Anything series. When evaluated on the SA-V test set, the MOSE validation set, and the LVOSv2 dataset, every SAM-2.1 model size performs better than its SAM-2 counterpart.

SAM-2.1 was released with training instructions that you can use to fine-tune the model for a specific use case. This is ideal if you want SAM-2.1 to segment objects in a domain where the base model struggles.

Here is an example of results from SAM-2, sourced from the Meta SAM-2 GitHub repository:

![segment anything model](https://github.com/facebookresearch/sam2/raw/main/assets/sa_v_dataset.jpg?raw=true)

We recommend that you follow along in this notebook while reading the blog post on [SAM-2.1 fine-tuning](https://blog.roboflow.com/sam-2-1-fine-tuning).

## Pro Tip: Use GPU Acceleration

If you are running this notebook in Google Colab, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`. This will ensure your notebook uses a GPU, which will significantly speed up model training times.

Meta recommends training SAM-2.1 on an A100. Thus, if possible, select an A100 GPU in Google Colab for use in training the model.
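To confirm which GPU your runtime actually provides before you start training, you can query `nvidia-smi`. The sketch below assumes `nvidia-smi` is on the PATH (as it typically is on Colab GPU runtimes); `parse_gpu_query` and `list_gpus` are hypothetical helper names, not part of SAM-2.1.

```python
import shutil
import subprocess


def parse_gpu_query(output: str) -> list[tuple[str, int]]:
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits` output."""
    gpus = []
    for line in output.strip().splitlines():
        name, memory_mib = [part.strip() for part in line.split(",")]
        gpus.append((name, int(memory_mib)))
    return gpus


def list_gpus() -> list[tuple[str, int]]:
    """Return (name, total memory in MiB) per visible NVIDIA GPU, or [] if none is found."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    return parse_gpu_query(result.stdout)


gpus = list_gpus()
if not gpus:
    print("No NVIDIA GPU detected: enable a GPU runtime before training.")
for name, memory_mib in gpus:
    # Meta recommends an A100 for SAM-2.1 training, so flag anything else
    note = "" if "A100" in name else " (Meta recommends an A100 for training)"
    print(f"{name}: {memory_mib} MiB{note}")
```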

## Steps in this Tutorial

In this tutorial, we are going to cover:

- **Before you start** - Make sure you have access to a GPU
- Download SAM-2.1 and a dataset
- Modify dataset file names
- Train the model
- Visualize results with automated mask generation

Without further ado, let's get started!

## Download SAM-2.1 and Data

Below, we download a dataset for training and then download SAM-2.1 from GitHub. Your dataset must be structured in the format SAM-2.1 expects.

Roboflow can export segmentation datasets in the SAM-2.1 format used in this guide. If your segmentation dataset is in the COCO JSON Segmentation format, you can upload it to Roboflow and then export it in the SAM-2.1 format.
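To sanity-check an exported dataset, you can confirm that every image has an annotation file and vice versa. This is a minimal sketch that assumes the export stores one `.json` annotation file next to each `.jpg` image with the same base name (an assumption based on how the renaming step in this notebook treats both file types); `find_unpaired_files` is a hypothetical helper.

```python
from pathlib import Path


def find_unpaired_files(folder: str) -> dict[str, list[str]]:
    """Report images without a matching .json file and .json files without an image.

    Assumes (hypothetically) one `<name>.jpg` + `<name>.json` pair per image.
    """
    paths = list(Path(folder).iterdir())
    images = {p.stem for p in paths if p.suffix.lower() == ".jpg"}
    annotations = {p.stem for p in paths if p.suffix.lower() == ".json"}
    return {
        "images_without_json": sorted(images - annotations),
        "json_without_image": sorted(annotations - images),
    }


# Only runs once the dataset has been downloaded to /content/data
if Path("/content/data/train").is_dir():
    print(find_unpaired_files("/content/data/train"))
```

Both lists should be empty for a clean export.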

[Learn how to label a dataset in Roboflow](https://blog.roboflow.com/getting-started-with-roboflow/)

[Learn how to export data from Roboflow for training](https://docs.roboflow.com/datasets/exporting-data).

![Export as SAM-2 data](https://media.roboflow.com/sam2export.png)

We then download a SAM-2.1 training YAML file which we will use to configure our model training job.

Finally, we install SAM-2.1 and download the model checkpoints.

Replace the code below with the code to export your own dataset, or run it as-is to fine-tune on the pre-filled dataset. _Note: Even if you use the pre-filled dataset, you still need to add a [Roboflow API key](https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key)._

**Important ⚠️: You must generate a dataset with images stretched to 1024x1024 for training. This is because our training configuration is set to use images of this resolution.**
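Because the training configuration assumes 1024x1024 inputs, it can help to verify image sizes after the dataset is downloaded. This is a sketch using Pillow (already used later in this notebook); `find_wrong_size_images` is a hypothetical helper, and it only checks `.jpg` files.

```python
from pathlib import Path

from PIL import Image


def find_wrong_size_images(folder: str, expected=(1024, 1024)) -> list[str]:
    """Return names of .jpg files in `folder` whose (width, height) differs from `expected`."""
    wrong = []
    for path in sorted(Path(folder).glob("*.jpg")):
        with Image.open(path) as img:  # Image.open reads the header; pixel data loads lazily
            if img.size != tuple(expected):
                wrong.append(path.name)
    return wrong


# Only runs once the dataset has been downloaded to /content/data
if Path("/content/data/train").is_dir():
    mismatched = find_wrong_size_images("/content/data/train")
    print(f"{len(mismatched)} image(s) are not 1024x1024: {mismatched[:5]}")
```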

In [1]:
!pip install roboflow
import os
from roboflow import Roboflow

# Replace with your own Roboflow API key
rf = Roboflow(api_key="YOUR_ROBOFLOW_API_KEY")
project = rf.workspace("project-zero").project("aerial_river_plastic_wastes")
version = project.version(9)
# Export the dataset in the SAM-2 format
dataset = version.download("sam2")

# Move the dataset to /content/data, the path the training config reads from
os.rename(dataset.location, "/content/data")

Collecting roboflow
  Downloading roboflow-1.1.50-py3-none-any.whl.metadata (9.7 kB)
Collecting idna==3.7 (from roboflow)
  Downloading idna-3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting python-dotenv (from roboflow)
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting filetype (from roboflow)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading roboflow-1.1.50-py3-none-any.whl (81 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 81.5/81.5 kB 3.9 MB/s eta 0:00:00
Downloading idna-3.7-py3-none-any.whl (66 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.8/66.8 kB 5.6 MB/s eta 0:00:00
Downloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)
Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Installing collected packages: filetype, python-dotenv, idna, roboflow
  Attempting uninstall: idna
    Found existing installation: idna 3.10

Downloading Dataset Version Zip in Aerial_River_Plastic_Wastes-9 to sam2:: 100%|██████████| 134879/134879 [00:08<00:00, 15795.45it/s]

Extracting Dataset Version Zip to Aerial_River_Plastic_Wastes-9 in sam2:: 100%|██████████| 4931/4931 [00:01<00:00, 3128.24it/s]

In [2]:
!git clone https://github.com/facebookresearch/sam2.git

Cloning into 'sam2'...
remote: Enumerating objects: 1052, done.
remote: Counting objects: 100% (460/460), done.
remote: Compressing objects: 100% (205/205), done.
remote: Total 1052 (delta 277), reused 290 (delta 255), pack-reused 592 (from 1)
Receiving objects: 100% (1052/1052), 121.74 MiB | 1.36 MiB/s, done.
Resolving deltas: 100% (378/378), done.


In [3]:
!wget -O /content/sam2/sam2/configs/train.yaml 'https://drive.usercontent.google.com/download?id=11cmbxPPsYqFyWq87tmLgBAQ6OZgEhPG3'

--2024-12-14 12:13:11--  https://drive.usercontent.google.com/download?id=11cmbxPPsYqFyWq87tmLgBAQ6OZgEhPG3
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 74.125.143.132, 2a00:1450:4013:c03::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|74.125.143.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11055 (11K) [application/octet-stream]
Saving to: ‘/content/sam2/sam2/configs/train.yaml’


2024-12-14 12:13:13 (85.6 MB/s) - ‘/content/sam2/sam2/configs/train.yaml’ saved [11055/11055]



In [4]:
%cd ./sam2/

/content/sam2


Next, we install the `sam2` package from the cloned repository.

The installation may take several minutes.

In [5]:
!pip install -e .[dev] -q

  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 3.3 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 74.6/74.6 kB 5.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.2/50.2 kB 4.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 117.0/117.0 kB 8.5 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 52.

In [6]:
!cd ./checkpoints && ./download_ckpts.sh

Downloading sam2.1_hiera_tiny.pt checkpoint...
--2024-12-14 12:17:46--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.33, 13.227.219.10, 13.227.219.59, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 156008466 (149M) [application/vnd.snesdev-page-table]
Saving to: ‘sam2.1_hiera_tiny.pt’


2024-12-14 12:17:47 (229 MB/s) - ‘sam2.1_hiera_tiny.pt’ saved [156008466/156008466]

Downloading sam2.1_hiera_small.pt checkpoint...
--2024-12-14 12:17:47--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.33, 13.227.219.10, 13.227.219.59, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 184

## Modify Dataset File Names

SAM-2.1 requires dataset file names to be in a particular format. Run the code snippet below to format your dataset file names as required.

In [7]:
# Rename Roboflow file names to a format compatible with SAM-2.1.
# It may be possible to skip this step by tweaking sam2/sam2/configs/train.yaml instead.
import os
import re

FOLDER = "/content/data/train"

for filename in os.listdir(FOLDER):
    # Replace all except last dot with underscore
    new_filename = filename.replace(".", "_", filename.count(".") - 1)
    if not re.search(r"_\d+\.\w+$", new_filename):
        # Add an int to the end of base name
        new_filename = new_filename.replace(".", "_1.")
    os.rename(os.path.join(FOLDER, filename), os.path.join(FOLDER, new_filename))
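The renaming rule is easier to check on a few sample names than on a whole folder. The sketch below wraps the same logic as the loop above in a hypothetical helper, `sam2_compatible_name`; the sample names are hypothetical Roboflow-style names. Because images and their `.json` files share a base name, both receive the same new stem and stay paired.

```python
import re


def sam2_compatible_name(filename: str) -> str:
    """Apply the same renaming rule as the cell above to a single file name."""
    # Replace every dot except the final (extension) dot with an underscore
    new_filename = filename.replace(".", "_", filename.count(".") - 1)
    # Append "_1" before the extension unless the stem already ends in "_<digits>"
    if not re.search(r"_\d+\.\w+$", new_filename):
        new_filename = new_filename.replace(".", "_1.")
    return new_filename


for name in ["img_jpg.rf.3f2a9c.jpg", "img_jpg.rf.3f2a9c.json", "frame_0007.jpg"]:
    print(name, "->", sam2_compatible_name(name))
```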

## Start Training

You can now start training a SAM-2.1 model. Training time depends on your GPU and the number of images in your dataset.

For the car part dataset of 38 images, training on an A100 GPU takes ~15 minutes.
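To get a rough sense of how training length scales with dataset size, you can combine the values that appear in the config printed when training starts (`train_batch_size: 1`, `multiplier: 2`, `num_epochs: 40`). This sketch assumes each epoch samples every image `multiplier` times and that batches are split evenly across GPUs; that is an assumption for estimation only, not a description of the SAM-2.1 training loop, and `estimate_train_steps` is a hypothetical helper.

```python
import math


def estimate_train_steps(num_images: int, multiplier: int = 2, num_epochs: int = 40,
                         batch_size: int = 1, num_gpus: int = 1) -> int:
    """Rough optimizer-step estimate under the sampling assumption stated above."""
    samples_per_epoch = num_images * multiplier
    steps_per_epoch = math.ceil(samples_per_epoch / (batch_size * num_gpus))
    return steps_per_epoch * num_epochs


print(estimate_train_steps(38))  # 38 images x 2 x 40 epochs / batch size 1 -> 3040
```

Under this assumption, the step count (and so, roughly, the training time) grows linearly with the number of images.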

In [None]:
!python training/train.py -c 'configs/train.yaml' --use-cluster 0 --num-gpus 1

###################### Train App Config ####################
scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 1
  max_num_objects: 3
  base_lr: 5.0e-06
  vision_lr: 3.0e-06
  phases_per_epoch: 1
  num_epochs: 40
dataset:
  img_folder: /content/data/train
  gt_folder: /content/data/train
  multiplier: 2
vos:
  train_transforms:
  - _target_: training.dataset.transforms.ComposeAPI
    transforms:
    - _target_: training.dataset.transforms.RandomHorizontalFlip
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomAffine
      degrees: 25
      shear: 20
      image_interpolation: bilinear
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomResizeAPI
      sizes: ${scratch.resolution}
      square: true
      consistent_transform: true
    - _target_: training.dataset.transforms.ColorJitter
      consistent_transform: true
      brightness: 0.1
      contrast: 0.03
      saturation: 0.03
   

You can visualize the training graphs with TensorBoard:

In [None]:
%load_ext tensorboard
%tensorboard --bind_all --logdir ./sam2_logs/

## Visualize Model Results

With a trained model ready, we can test it on an image from our validation set.

To help visualize model predictions, we use Roboflow supervision, an open-source Python package with utilities for working with computer vision model outputs.


In [None]:
!pip install supervision -q

### Load SAM-2.1

In [None]:
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import supervision as sv
import os
import random
from PIL import Image
import numpy as np

# Use bfloat16 autocast for the rest of the notebook and enable TF32 on
# GPUs with compute capability 8.0 or higher (settings from Meta's SAM-2 notebook)
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Fine-tuned checkpoint saved by the training run above
checkpoint = "/content/sam2/sam2_logs/configs/train.yaml/checkpoints/checkpoint.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
sam2 = build_sam2(model_cfg, checkpoint, device="cuda")
mask_generator = SAM2AutomaticMaskGenerator(sam2)

# Base (not fine-tuned) SAM-2.1 checkpoint, loaded for comparison
checkpoint_base = "/content/sam2/checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg_base = "configs/sam2.1/sam2.1_hiera_b+.yaml"
sam2_base = build_sam2(model_cfg_base, checkpoint_base, device="cuda")
mask_generator_base = SAM2AutomaticMaskGenerator(sam2_base)

### Run Inference on an Image in Automatic Mask Generation Mode

In [None]:
validation_set = os.listdir("/content/data/valid")

# Pick a random .jpg image from the validation set
image = random.choice([img for img in validation_set if img.endswith(".jpg")])
image = os.path.join("/content/data/valid", image)
opened_image = np.array(Image.open(image).convert("RGB"))
result = mask_generator.generate(opened_image)

detections = sv.Detections.from_sam(sam_result=result)

mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
annotated_image = opened_image.copy()
annotated_image = mask_annotator.annotate(annotated_image, detections=detections)

base_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

base_result = mask_generator_base.generate(opened_image)
base_detections = sv.Detections.from_sam(sam_result=base_result)
base_annotated_image = opened_image.copy()
base_annotated_image = base_annotator.annotate(base_annotated_image, detections=base_detections)

sv.plot_images_grid(images=[annotated_image, base_annotated_image], titles=["Fine-Tuned SAM-2.1", "Base SAM-2.1"], grid_size=(1, 2))
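Beyond the side-by-side plot, a couple of numbers can make the comparison more concrete. This sketch assumes each entry returned by `SAM2AutomaticMaskGenerator.generate()` is a dict with a boolean HxW `"segmentation"` array and a `"predicted_iou"` score; `summarize_masks` is a hypothetical helper, demonstrated here on toy masks.

```python
import numpy as np


def summarize_masks(sam_result: list[dict], min_iou: float = 0.0) -> dict:
    """Count masks with predicted_iou >= min_iou and the fraction of pixels their union covers."""
    kept = [m for m in sam_result if m["predicted_iou"] >= min_iou]
    if not kept:
        return {"num_masks": 0, "coverage": 0.0}
    union = np.zeros_like(kept[0]["segmentation"], dtype=bool)
    for m in kept:
        union |= m["segmentation"].astype(bool)
    return {"num_masks": len(kept), "coverage": float(union.mean())}


toy = [
    {"segmentation": np.array([[1, 1], [0, 0]], dtype=bool), "predicted_iou": 0.9},
    {"segmentation": np.array([[0, 1], [0, 1]], dtype=bool), "predicted_iou": 0.5},
]
print(summarize_masks(toy))               # {'num_masks': 2, 'coverage': 0.75}
print(summarize_masks(toy, min_iou=0.8))  # {'num_masks': 1, 'coverage': 0.5}
```

In the notebook, you could call `summarize_masks(result)` and `summarize_masks(base_result)` to compare the fine-tuned and base models on the same image.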

## Exporting Model
