# LightlyTrain - Instance Segmentation with DINOv3 EoMT

This notebook demonstrates how to use LightlyTrain for instance segmentation with our
state-of-the-art [EoMT](https://arxiv.org/abs/2503.19108) model built on [DINOv3](https://github.com/facebookresearch/dinov3)
backbones. It first runs inference with our publicly released weights trained on the [COCO](https://arxiv.org/abs/1612.03716)
dataset and then shows how to train the model on your own data.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/eomt_instance_segmentation.ipynb)

> **Important**: When running on Google Colab, make sure to select a GPU runtime for faster processing. To do this, go to `Runtime` > `Change runtime type` and select a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install lightly-train

> **Important**: LightlyTrain is officially supported on
> - Linux: CPU or CUDA
> - macOS: CPU only
> - Windows (experimental): CPU or CUDA
>
> We plan to add MPS support for macOS.
>
> Check the [installation instructions](https://docs.lightly.ai/train/stable/installation.html) for more details on installation.
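Because LightlyTrain runs on PyTorch, a quick way to confirm that the GPU runtime is actually active (for example on Colab) is to ask PyTorch which device it sees. The helper below is a minimal sketch; `describe_accelerator` is a hypothetical name and not part of LightlyTrain. It only checks for CUDA, matching the support matrix above.

```python
import importlib.util


def describe_accelerator() -> str:
    """Report which PyTorch device is available, without failing if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch

    if torch.cuda.is_available():
        # Name of the first visible CUDA device, e.g. the GPU assigned by Colab.
        return f"cuda ({torch.cuda.get_device_name(0)})"
    return "cpu"


print(describe_accelerator())
```

If this prints `cpu` on Colab, the runtime type has most likely not been switched to a GPU yet.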

## Prediction using LightlyTrain's model weights

### Download an example image

Download an example image for inference with the following command:

In [None]:
!wget -O image.jpg http://images.cocodataset.org/val2017/000000039769.jpg

### Load the model weights

Then load the model weights with LightlyTrain's `load_model` function:

In [None]:
import lightly_train

model = lightly_train.load_model("dinov3/vits16-eomt-inst-coco")

### Predict the instances

Run `model.predict` on the image. The method accepts file paths, URLs, PIL Images, or tensors as input.

In [None]:
prediction = model.predict("image.jpg", threshold=0.8)
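The `threshold=0.8` argument most likely drops instances whose confidence score falls below 0.8; this is an assumption, so check the LightlyTrain documentation for the exact behavior. To try a stricter cutoff without calling the model again, you can filter the outputs yourself. The sketch below uses a hypothetical helper, `filter_by_score`, and assumes that `masks`, `labels`, and `scores` contain one entry per predicted instance.

```python
import numpy as np


def filter_by_score(masks, labels, scores, min_score):
    """Keep only instances whose score is at least `min_score`.

    Assumes masks has shape (N, H, W) and labels/scores have shape (N,).
    """
    keep = scores >= min_score  # Boolean array of shape (N,).
    return masks[keep], labels[keep], scores[keep]


# Toy data: three 2x2 instance masks with made-up labels and scores.
toy_masks = np.zeros((3, 2, 2), dtype=bool)
toy_labels = np.array([15, 15, 57])
toy_scores = np.array([0.95, 0.82, 0.91])
kept_masks, kept_labels, kept_scores = filter_by_score(
    toy_masks, toy_labels, toy_scores, min_score=0.9
)
print(kept_labels, kept_scores)  # → [15 57] [0.95 0.91]
```

Boolean indexing behaves the same way on PyTorch tensors, so the helper can also be called directly with `prediction["masks"]`, `prediction["labels"]`, and `prediction["scores"]`.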

### Visualize the results

Visualize the image and predicted instance masks to inspect the segmentation output.

In [None]:
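import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision.utils import draw_segmentation_masks

# Load the image as a uint8 tensor of shape (3, height, width).
image = read_image("image.jpg")
# Move the masks to the CPU in case the model ran on a GPU, so that they are
# on the same device as the image.
masks = prediction["masks"].cpu()
labels = prediction["labels"]  # Not used for drawing below.
scores = prediction["scores"]  # Not used for drawing below.
# draw_segmentation_masks expects boolean masks of shape (N, height, width).
image_with_masks = draw_segmentation_masks(
    image,
    masks=masks,
    alpha=1.0,  # Fully opaque masks; lower values let the image show through.
)
# Matplotlib expects channels last: (height, width, 3).
plt.imshow(image_with_masks.permute(1, 2, 0))
plt.axis("off")
plt.show()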

`prediction["masks"]` is a tensor of shape `(N, height, width)`, where `N` is the number of predicted instances. Each mask has the same height and width as the input image, so it can be overlaid on the image directly.
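Because each mask is a boolean image with the same size as the input, simple per-instance statistics follow directly from it. The sketch below computes pixel areas and tight bounding boxes in `(x_min, y_min, x_max, y_max)` pixel coordinates with NumPy; `mask_stats` is a hypothetical helper, and it assumes boolean masks converted with `prediction["masks"].cpu().numpy()`.

```python
import numpy as np


def mask_stats(masks: np.ndarray):
    """Return per-instance pixel areas and boxes for boolean masks of shape (N, H, W).

    Boxes use inclusive pixel indices (x_min, y_min, x_max, y_max); empty masks
    get a box of all -1.
    """
    areas = masks.reshape(masks.shape[0], -1).sum(axis=1)
    boxes = np.full((masks.shape[0], 4), -1, dtype=np.int64)
    for i, mask in enumerate(masks):
        ys, xs = np.nonzero(mask)  # Row (y) and column (x) indices of mask pixels.
        if ys.size > 0:
            boxes[i] = [xs.min(), ys.min(), xs.max(), ys.max()]
    return areas, boxes


toy = np.zeros((2, 4, 5), dtype=bool)
toy[0, 1:3, 2:5] = True  # Rows 1-2, columns 2-4; the second mask stays empty.
areas, boxes = mask_stats(toy)
print(areas)  # → [6 0]
print(boxes)
```

For PyTorch tensors, `torchvision.ops.masks_to_boxes` offers a similar box computation without converting to NumPy.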

## Train an instance segmentation model

Training your own instance segmentation model is straightforward with LightlyTrain.

### Download dataset

First, download a dataset in YOLO segmentation format. This example uses `coco128-seg`, a small subset of COCO.

In [None]:
!wget -O coco128-seg.zip https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128-seg.zip && unzip coco128-seg.zip
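In the YOLO segmentation format, each image has a `.txt` label file with one line per object: a class index followed by the polygon vertex coordinates `x1 y1 x2 y2 ...`, normalized to `[0, 1]` by the image width and height. The sketch below parses one such line into pixel coordinates; `parse_yolo_seg_line` is a hypothetical helper for inspecting the downloaded labels, not part of LightlyTrain.

```python
def parse_yolo_seg_line(line: str, width: int, height: int):
    """Parse 'class x1 y1 x2 y2 ...' (normalized) into (class_id, [(x, y), ...]) in pixels."""
    values = line.split()
    class_id = int(values[0])
    coords = [float(v) for v in values[1:]]
    if len(coords) % 2 != 0 or len(coords) < 6:
        raise ValueError("expected at least three (x, y) pairs")
    # Pair up consecutive values and scale them from [0, 1] to pixel units.
    points = [
        (coords[i] * width, coords[i + 1] * height) for i in range(0, len(coords), 2)
    ]
    return class_id, points


class_id, points = parse_yolo_seg_line("15 0.25 0.5 0.75 0.5 0.5 1.0", width=640, height=480)
print(class_id, points)  # → 15 [(160.0, 240.0), (480.0, 240.0), (320.0, 480.0)]
```

To inspect a real annotation, read a line from one of the label files in the extracted `coco128-seg` folder, which is expected to keep labels under `labels/` next to `images/`.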

Then start the training with the `train_instance_segmentation` function. You can specify various training parameters such as the model architecture, number of training steps, batch size, learning rate, and more.

In [None]:
lightly_train.train_instance_segmentation(
    out="out/my_experiment",
    model="dinov3/vits16-eomt-inst-coco",
    steps=100,  # Small number of steps for demonstration, default is 90_000.
    batch_size=4,  # Small batch size for demonstration, default is 16.
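    data={
        "path": "coco128-seg",
        "train": "images/train2017",
        # coco128-seg ships only a train2017 split, so the training images are
        # reused for validation in this demo. Use a separate split in practice.
        "val": "images/train2017",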
        "names": {
            0: "person",
            1: "bicycle",
            2: "car",
            3: "motorcycle",
            4: "airplane",
            5: "bus",
            6: "train",
            7: "truck",
            8: "boat",
            9: "traffic light",
            10: "fire hydrant",
            11: "stop sign",
            12: "parking meter",
            13: "bench",
            14: "bird",
            15: "cat",
            16: "dog",
            17: "horse",
            18: "sheep",
            19: "cow",
            20: "elephant",
            21: "bear",
            22: "zebra",
            23: "giraffe",
            24: "backpack",
            25: "umbrella",
            26: "handbag",
            27: "tie",
            28: "suitcase",
            29: "frisbee",
            30: "skis",
            31: "snowboard",
            32: "sports ball",
            33: "kite",
            34: "baseball bat",
            35: "baseball glove",
            36: "skateboard",
            37: "surfboard",
            38: "tennis racket",
            39: "bottle",
            40: "wine glass",
            41: "cup",
            42: "fork",
            43: "knife",
            44: "spoon",
            45: "bowl",
            46: "banana",
            47: "apple",
            48: "sandwich",
            49: "orange",
            50: "broccoli",
            51: "carrot",
            52: "hot dog",
            53: "pizza",
            54: "donut",
            55: "cake",
            56: "chair",
            57: "couch",
            58: "potted plant",
            59: "bed",
            60: "dining table",
            61: "toilet",
            62: "tv",
            63: "laptop",
            64: "mouse",
            65: "remote",
            66: "keyboard",
            67: "cell phone",
            68: "microwave",
            69: "oven",
            70: "toaster",
            71: "sink",
            72: "refrigerator",
            73: "book",
            74: "clock",
            75: "vase",
            76: "scissors",
            77: "teddy bear",
            78: "hair drier",
            79: "toothbrush",
        },
    },
)

Once training completes, the final model checkpoint is saved in `out/my_experiment/exported_models/exported_last.pt`. If you have a validation dataset, the best model according to the validation mask mAP is saved in `out/my_experiment/exported_models/exported_best.pt`.

In [None]:
model = lightly_train.load_model("out/my_experiment/exported_models/exported_last.pt")

In [None]:
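# No threshold is passed here, so the model's default threshold is used.
prediction = model.predict("image.jpg")

image = read_image("image.jpg")
# Move the masks to the CPU in case the model ran on a GPU.
masks = prediction["masks"].cpu()
image_with_masks = draw_segmentation_masks(
    image,
    masks=masks,
    alpha=1.0,
)
plt.imshow(image_with_masks.permute(1, 2, 0))
plt.axis("off")
plt.show()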