## Fine-tune the YOLO11 segmentation model

In [None]:
from ultralytics import YOLO

# Fine-tuning starts from the pretrained checkpoint "yolo11m-seg.pt".
model = YOLO("yolo11m-seg.pt")

# Where the run is written, which dataset YAML to use, and the schedule/hardware.
_yolo11_run = {
    "project": ".",
    "name": "yolo_training_run",
    "data": "../data_meta/yolo_train_ds.yaml",
    "epochs": 300,
    "imgsz": 640,
    "device": "cuda:1",
}

# Augmentation overrides (see the ultralytics augmentation settings reference).
_yolo11_augment = {
    "degrees": 360,
    "copy_paste": 0.4,
    "copy_paste_mode": "mixup",
    "shear": 10,
    "perspective": 0.0003,
    "flipud": 0.5,
    "hsv_h": 0.3,
    "bgr": 0.1,
}

model.train(**_yolo11_run, **_yolo11_augment)


### Sanity check
Run the fine-tuned model on a sample validation image and show the predicted masks.

In [None]:
# Predict on a single validation image; retina_masks=True asks ultralytics for
# high-resolution masks instead of the network-resolution ones.
results = model.predict(
    source="../data/yolo_train/images/val/TOU-J-3/rep_06/1000269_2022_05_23_12_01_21-6-23-TB06-RGB1_pot_D2_TOU-J-3-06.png",
    retina_masks=True,
)
results[0].show()

## Fine-tune the YOLOv8 segmentation model

In [None]:
from ultralytics import YOLO

# Fine-tuning starts from the pretrained checkpoint "yolov8m-seg.pt".
model = YOLO("yolov8m-seg.pt")

# Run location, dataset YAML, schedule/hardware and early-stopping patience.
_yolo8_run = {
    "project": "./runs/segment/yolov8",
    "name": "yolo_v8_training_run",
    "data": "../data_meta/yolo_train_ds.yaml",
    "epochs": 300,
    "imgsz": 640,
    "device": "cuda:1",
    "patience": 20,
}

# Augmentation overrides (see the ultralytics augmentation settings reference).
_yolo8_augment = {
    "degrees": 360,
    "copy_paste": 0.4,
    "copy_paste_mode": "mixup",
    "shear": 10,
    "perspective": 0.0003,
    "flipud": 0.5,
    "hsv_h": 0.3,
    "bgr": 0.1,
}

model.train(**_yolo8_run, **_yolo8_augment)

## Sanity check
- Run the fine-tuned YOLOv8 model on a sample validation image and show the predicted masks.

In [None]:
results = model(
    "../data/yolo_train/images/val/TOU-J-3/rep_06/1000269_2022_05_23_12_01_21-6-23-TB06-RGB1_pot_D2_TOU-J-3-06.png", retina_masks=True
)
results[0].show()

# Fine-tune SAM2 model

### Set data parameters

In [None]:
import sys
import os
# Make Hydra print full stack traces when config composition/instantiation fails.
os.environ["HYDRA_FULL_ERROR"] = "1"
# Project helpers (saveload, masks, dataset) live in ../src and the SAM2 repo
# (its `training` package) in SAM_PATH; both must be on sys.path before the
# imports below.
sys.path.insert(0, "../src")
SAM_PATH = "../thirdparty/segment-anything-2/"
sys.path.insert(1, SAM_PATH)
from omegaconf import OmegaConf
from hydra.utils import instantiate
from hydra import compose, initialize, initialize_config_module
from PIL import Image as PILImage
from saveload import read_image, read_masks, mask_joined_to_masks_dict, _imread_func
from masks import OUT_OF_LIST_COLOR, DEFAULT_COLORS
import torch
import pandas as pd
import numpy as np
from dataset import load_dataset
from training.dataset.vos_raw_dataset import VOSRawDataset, VOSVideo, VOSFrame
# Custom interpolation resolvers, presumably referenced by the SAM2 training
# YAML composed later. `replace=True` lets this cell be re-run: OmegaConf
# otherwise raises because the resolver name is already registered.
# NOTE(review): `divide` uses floor division (`//`); if the config divides a
# float such as a learning rate, this truncates — confirm against the
# resolver SAM2's own training script registers.
OmegaConf.register_new_resolver("times", lambda a, b: a * b, replace=True)
OmegaConf.register_new_resolver("divide", lambda a, b: a // b, replace=True)

In [None]:
# Placeholders: fill in the dataset locations before running.
IMAGES_ROOT = ""  # root folder with images
MASKS_ROOT = ""  # root folder with labels

ds = load_dataset(images_root=IMAGES_ROOT, masks_root=MASKS_ROOT)
# Keep only the rows assigned to the training split.
ds = ds.loc[ds["nn_role"].eq("train")]
print(f"Prepared ds with {len(ds)} images")

In [None]:
def mask_joined_to_masks_dict_no_error(
    mask: np.ndarray,
    *,
    known_colors=None,
    out_of_list_color=None,
) -> dict[int, dict]:
    """Split a joined colour-coded mask into one binary mask per colour.

    Lenient variant of ``mask_joined_to_masks_dict``. Colours that are neither
    in the palette, nor the out-of-list colour, nor black are reported with
    ``print`` instead of failing. They are appended after the palette so they
    still produce a segment.

    Args:
        mask (np.ndarray): joined mask whose pixels are read as 3-channel
            colours (``mask.reshape(-1, 3)``).
        known_colors: sequence of palette colour tuples. Defaults to
            ``DEFAULT_COLORS``.
        out_of_list_color: colour tuple placed right after the palette.
            Defaults to ``OUT_OF_LIST_COLOR``.

    Returns:
        dict[int, dict]: maps a colour's position in
        ``known_colors + [out_of_list_color] + unknown colours`` to
        ``{"segmentation": <bool array>, "_detection_index": <same position>}``.
        Only colours that occur in ``mask`` get an entry. Black ``(0, 0, 0)``
        gets none unless it is part of ``known_colors``.
    """
    if known_colors is None:
        known_colors = DEFAULT_COLORS
    if out_of_list_color is None:
        out_of_list_color = OUT_OF_LIST_COLOR

    # np.unique(axis=0) returns the distinct pixel colours in sorted order.
    # Keeping that order makes the indices given to unknown colours
    # deterministic.
    present_colors = [
        tuple(row.tolist()) for row in np.unique(mask.reshape(-1, 3), axis=0)
    ]
    present_set = set(present_colors)

    strange_colors = []
    for c in present_colors:
        if not (c in known_colors or c == out_of_list_color or c == (0, 0, 0)):
            print(f"Problem with parsing color {c}")
            strange_colors.append(c)

    palette = list(known_colors) + [out_of_list_color] + strange_colors
    masks = {}
    for i, color in enumerate(palette):
        # A colour absent from the image cannot yield a non-empty mask, so
        # skip the full-image comparison for it.
        if color not in present_set:
            continue
        mask_i = np.all(mask == color, axis=-1)
        if mask_i.any():
            masks[i] = {"segmentation": mask_i, "_detection_index": i}
    return masks


class LeafPalettisedPNGSegmentLoader:
    """Segment loader for one leaf "video" whose masks are colour-coded images.

    The mask file of each frame is not discovered from a folder. It is taken
    from the ``mask_path`` column of that video's slice of the dataset. Used
    by ``LeafPNGRawDataset.get_video``.
    """

    def __init__(self, video_ds_part: pd.DataFrame):
        """
        Args:
            video_ds_part: rows of a single video; must provide the
                ``image_num`` (frame id) and ``mask_path`` columns.
        """
        # Attribute name kept unchanged in case other code reads it.
        self.video_ds_part2 = video_ds_part
        # frame id -> mask file path (a later row wins on duplicate ids).
        self.frame_id_to_png_filename = dict(
            zip(video_ds_part["image_num"], video_ds_part["mask_path"])
        )

    def load(self, frame_id):
        """Load and split the mask image of one frame.

        Args:
            frame_id: value from the ``image_num`` column identifying the frame.

        Returns:
            dict: palette index -> boolean ``torch.Tensor`` segmentation.
        """
        mask_path = self.frame_id_to_png_filename[frame_id]
        masks = mask_joined_to_masks_dict_no_error(_imread_func(mask_path))
        # Keys are the palette indices produced by the splitter (equal to each
        # entry's "_detection_index").
        return {
            idx: torch.from_numpy(entry["segmentation"])
            for idx, entry in masks.items()
        }

    def __len__(self):
        raise NotImplementedError()


class LeafPNGRawDataset(VOSRawDataset):
    """SAM2 raw video dataset built from the in-memory leaf dataframe ``ds``.

    Every distinct (plant, rep) pair is one "video", and its frames are the
    dataframe rows of that pair. The constructor keeps the SAM2 raw-dataset
    argument list so the class can be selected via ``_target_`` in the
    training config. Only the default/None values checked below are supported.
    """

    # Evaluated once, when this class statement runs. If ``ds`` is rebuilt,
    # re-run this cell so the dataset sees the new dataframe.
    leaf_ds = ds

    def __init__(
        self,
        img_folder,
        gt_folder,
        file_list_txt=None,
        excluded_videos_list_txt=None,
        sample_rate=1,
        is_palette=True,
        single_object_mode=False,
        truncate_video=-1,
        frames_sampling_mult=False,
    ):
        """Validate the config arguments and index the videos.

        Args:
            img_folder, gt_folder, file_list_txt: must be None; images and
                masks are located through ``leaf_ds`` instead.
            excluded_videos_list_txt: accepted for signature compatibility;
                not read by this class.
            sample_rate: stored on the instance; not read by this class.
            is_palette: must be truthy.
            single_object_mode, frames_sampling_mult: must be falsy.
            truncate_video: must be -1.

        Raises:
            ValueError: if an unsupported argument value is passed.
        """
        self.img_folder = img_folder
        self.gt_folder = gt_folder
        self.sample_rate = sample_rate
        self.is_palette = is_palette
        self.single_object_mode = single_object_mode
        self.truncate_video = truncate_video

        # Explicit exceptions instead of asserts, which `python -O` strips.
        if self.img_folder is not None:
            raise ValueError(f"img_folder {self.img_folder} is not None")
        if self.gt_folder is not None:
            raise ValueError(f"gt_folder {self.gt_folder} is not None")
        if file_list_txt is not None:
            raise ValueError(f"file_list_txt {file_list_txt} is not None")
        if frames_sampling_mult:
            raise ValueError(
                f"frames_sampling_mult {frames_sampling_mult} is not False"
            )
        if self.single_object_mode:
            raise ValueError(
                f"single_object_mode {self.single_object_mode} is not False"
            )
        if not self.is_palette:
            raise ValueError(f"is_palette {self.is_palette} is not True")
        if self.truncate_video != -1:
            raise ValueError(f"truncate_video {self.truncate_video} is not -1")

        # One "plant/rep" name per video, sorted so indices are stable.
        self.video_names2 = sorted(
            {
                f"{plant}/{rep}"
                for plant, rep in zip(self.leaf_ds["plant"], self.leaf_ds["rep"])
            }
        )

    def get_video(self, idx):
        """Return the VOSVideo at ``idx`` and a segment loader for its masks."""
        video_name = self.video_names2[idx]
        plant, rep = video_name.split("/")
        # Order frames by image_num. SAM2's frame samplers pick consecutive
        # entries of ``video.frames``, so the list must be in temporal order.
        # NOTE(review): assumes image_num increases with capture time — confirm.
        video_part = self.leaf_ds[
            (self.leaf_ds["plant"] == plant) & (self.leaf_ds["rep"] == rep)
        ].sort_values("image_num", kind="stable")
        segment_loader = LeafPalettisedPNGSegmentLoader(video_ds_part=video_part)

        frames = [
            VOSFrame(fid, image_path=path)
            for fid, path in zip(video_part["image_num"], video_part["image_path"])
        ]
        video = VOSVideo(video_name, idx, frames)
        return video, segment_loader

    def __len__(self):
        return len(self.video_names2)

### Build trainer

In [None]:
# Compose the SAM2.1 base+ MOSE fine-tuning config and redirect its video
# dataset to LeafPNGRawDataset (defined in this notebook, hence "__main__").
with initialize(config_path=os.path.join(SAM_PATH, "sam2"), version_base="1.2"):
    cfg = compose(
        config_name="configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml"
    )
    cfg.launcher.experiment_log_dir = "sam2_training_run"
    video_ds = cfg.trainer.data.train.datasets[0].dataset.datasets[0].video_dataset
    video_ds._target_ = "__main__.LeafPNGRawDataset"
    # LeafPNGRawDataset rejects these arguments unless they are None; frames
    # and masks come from the dataframe instead.
    video_ds.img_folder = None
    video_ds.gt_folder = None
    video_ds.file_list_txt = None
    cfg.scratch.num_train_workers = 5
    cfg.trainer.checkpoint.model_weight_initializer.state_dict.checkpoint_path = (
        os.path.join(SAM_PATH, "checkpoints", "sam2.1_hiera_base_plus.pt")
    )
    cfg.trainer.max_epochs = 20
    # The GPU is not chosen in the config: the trainer derives it from
    # LOCAL_RANK, set below.

# Single-GPU run: provide the environment variables a distributed launcher
# would normally set, for a world of exactly one process.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(8085)
# RANK is this process's global index among WORLD_SIZE processes (0 = the only one).
os.environ["RANK"] = str(0)
# LOCAL_RANK is the process index on this machine; it selects the CUDA device,
# so 1 means cuda:1.
os.environ["LOCAL_RANK"] = str(1)
os.environ["WORLD_SIZE"] = str(1)
trainer = instantiate(cfg.trainer, _recursive_=False)

### Train model

In [None]:
# Start SAM2 fine-tuning with the trainer instantiated above; the run's log
# directory was set via cfg.launcher.experiment_log_dir ("sam2_training_run").
trainer.run()

### Sanity check: load the fine-tuned SAM2 checkpoint

In [None]:
from sam2.build_sam import build_sam2

# Rebuild SAM2.1 base+ from its model config and load the checkpoint written
# by the fine-tuning run, on cuda:1, without post-processing.
sam2 = build_sam2(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "sam2_training_run/checkpoints/checkpoint.pt",
    apply_postprocessing=False,
    device="cuda:1",
)

# Detectron2

In [None]:
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
import os
import cv2
import numpy as np
import torch, detectron2
from datetime import datetime
from detectron2.data.datasets import register_coco_instances
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.engine import DefaultTrainer
import matplotlib.pyplot as plt
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances

In [None]:
# Smoke-test the model-zoo Mask R-CNN R50-FPN (COCO weights, no fine-tuning)
# on one image. `image` must be defined here: the cell that loads a training
# image comes later in the notebook.
SAMPLE_IMAGE_PATH = ""  # Replace with the path to an image to run the model on
image = cv2.imread(SAMPLE_IMAGE_PATH)  # BGR, the default input format of DefaultPredictor
if image is None:
    raise FileNotFoundError(f"Could not read image {SAMPLE_IMAGE_PATH!r}")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# Fall back to CPU when no GPU is present, as the inference cell below does.
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
predictor = DefaultPredictor(cfg)
outputs = predictor(image)

In [None]:
# Visualizer expects an RGB image, so reverse the channels of the BGR image
# read by cv2.
visualizer = Visualizer(image[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
# get_image() already returns an RGB uint8 array, which is what plt.imshow
# expects. Flipping it again would swap red and blue in the plot.
output_image = out.get_image()
plt.figure(figsize=(12, 8))  # Adjust the figure size as needed
plt.imshow(output_image)
plt.axis("off")  # Hide axis
plt.show()

In [None]:
# Register the COCO-format train/test/validation splits with detectron2 under
# "<DATA_SET_NAME>-<split>" names.
DATA_SET_NAME = "leaf"  # Replace with your dataset name

# TRAIN SET - Manually specify the paths
TRAIN_DATA_SET_NAME = f"{DATA_SET_NAME}-train"
TRAIN_DATA_SET_IMAGES_DIR_PATH = ""  # Replace with the path to your training images directory
TRAIN_DATA_SET_ANN_FILE_PATH = ""  # Replace with your training JSON file path

# TEST SET - Manually specify the paths
TEST_DATA_SET_NAME = f"{DATA_SET_NAME}-test"
TEST_DATA_SET_IMAGES_DIR_PATH = ""  # Replace with the path to your test images directory
TEST_DATA_SET_ANN_FILE_PATH = ""  # Replace with your test JSON file path

# VALIDATION SET - Manually specify the paths
VALID_DATA_SET_NAME = f"{DATA_SET_NAME}-valid"
VALID_DATA_SET_IMAGES_DIR_PATH = ""  # Replace with the path to your validation images directory
VALID_DATA_SET_ANN_FILE_PATH = ""  # Replace with your validation JSON file path


def _register_coco_split(name, json_file, image_root):
    """Register one COCO split, replacing an earlier registration of ``name``.

    detectron2 refuses to register a dataset name twice, so without this the
    cell could not be re-run in the same kernel.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(
        name=name,
        metadata={},
        json_file=json_file,
        image_root=image_root,
    )


_register_coco_split(TRAIN_DATA_SET_NAME, TRAIN_DATA_SET_ANN_FILE_PATH, TRAIN_DATA_SET_IMAGES_DIR_PATH)
_register_coco_split(TEST_DATA_SET_NAME, TEST_DATA_SET_ANN_FILE_PATH, TEST_DATA_SET_IMAGES_DIR_PATH)
_register_coco_split(VALID_DATA_SET_NAME, VALID_DATA_SET_ANN_FILE_PATH, VALID_DATA_SET_IMAGES_DIR_PATH)

# Show the registered splits. A bare expression that is not the last line of
# a cell is never displayed by Jupyter, so print it explicitly.
print([data_set for data_set in MetadataCatalog.list() if data_set.startswith(DATA_SET_NAME)])
metadata = MetadataCatalog.get(TRAIN_DATA_SET_NAME)


In [None]:
# Draw the ground-truth annotations of the first training example.
metadata = MetadataCatalog.get(TRAIN_DATA_SET_NAME)
dataset_train = DatasetCatalog.get(TRAIN_DATA_SET_NAME)

dataset_entry = dataset_train[0]
image = cv2.imread(dataset_entry["file_name"])  # BGR; None if unreadable
if image is None:
    raise FileNotFoundError(f"Could not read image {dataset_entry['file_name']!r}")

# Visualizer expects RGB, so reverse the BGR channels from cv2.
# NOTE(review): per detectron2's ColorMode docs, IMAGE_BW only affects
# draw_instance_predictions, so it likely has no effect on draw_dataset_dict.
visualizer = Visualizer(
    image[:, :, ::-1],
    metadata=metadata,
    scale=0.8,
    instance_mode=ColorMode.IMAGE_BW,
)
out = visualizer.draw_dataset_dict(dataset_entry)

# get_image() is already RGB, which is what plt.imshow expects.
output_image = out.get_image()
plt.figure(figsize=(12, 8))  # Adjust the figure size as needed
plt.imshow(output_image)
plt.axis("off")  # Hide axis
plt.show()

In [None]:
# Hyperparameters of the Detectron2 fine-tuning run.
ARCHITECTURE = "mask_rcnn_R_101_FPN_3x"
CONFIG_FILE_PATH = f"COCO-InstanceSegmentation/{ARCHITECTURE}.yaml"
BASE_LR = 0.001
MAX_ITER = 2000
EVAL_PERIOD = 200
# NOTE(review): the inference cell hard-codes ROI_HEADS.NUM_CLASSES = 1; both
# should equal the number of categories in the COCO annotation files.
NUM_CLASSES = 3

# Each run writes to its own folder: <dataset>/<architecture>/<timestamp>.
OUTPUT_DIR_PATH = os.path.join(
    DATA_SET_NAME, ARCHITECTURE, f"{datetime.now():%Y-%m-%d-%H-%M-%S}"
)
os.makedirs(OUTPUT_DIR_PATH, exist_ok=True)

In [None]:
# Training config: model-zoo Mask R-CNN config and weights, retargeted at the
# registered leaf splits.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(CONFIG_FILE_PATH))

# Model
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CONFIG_FILE_PATH)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = NUM_CLASSES
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64

# Data
cfg.DATASETS.TRAIN = (TRAIN_DATA_SET_NAME,)
cfg.DATASETS.TEST = (TEST_DATA_SET_NAME,)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.INPUT.MASK_FORMAT = "bitmask"

# Optimisation and evaluation schedule
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = BASE_LR
cfg.SOLVER.MAX_ITER = MAX_ITER
cfg.TEST.EVAL_PERIOD = EVAL_PERIOD

cfg.OUTPUT_DIR = OUTPUT_DIR_PATH

In [None]:
class LeafTrainer(DefaultTrainer):
    """DefaultTrainer that can actually evaluate on ``cfg.DATASETS.TEST``.

    ``DefaultTrainer.build_evaluator`` raises NotImplementedError. Its
    ``test`` method then only logs a warning and skips the dataset, so the
    periodic evaluation requested via ``cfg.TEST.EVAL_PERIOD`` never ran.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a COCO-metrics evaluator for ``dataset_name``."""
        from detectron2.evaluation import COCOEvaluator

        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        return COCOEvaluator(dataset_name, output_dir=output_folder)


trainer = LeafTrainer(cfg)
trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS, not a previous run
trainer.train()

In [None]:
# Build a predictor from the fine-tuned Mask R-CNN weights and pick validation
# samples to run it on.
ARCHITECTURE = "mask_rcnn_R_101_FPN_3x"
CONFIG_FILE_PATH = f"COCO-InstanceSegmentation/{ARCHITECTURE}.yaml"
cfg = get_cfg()
# merge_from_file() returns None, so it must not be assigned to CONFIG_FILE_PATH.
cfg.merge_from_file(model_zoo.get_config_file(CONFIG_FILE_PATH))
cfg.MODEL.WEIGHTS = "./leaf/mask_rcnn_R_101_FPN_3x/2025-02-01-10-12-17/model_final.pth"  # Path to trained model
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # Adjust threshold as needed
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Use GPU if available
# NOTE(review): must equal the class count the checkpoint was trained with
# (the training cell uses NUM_CLASSES = 3). On a mismatch, detectron2's
# checkpointer skips the incompatible head weights with only a warning.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
predictor = DefaultPredictor(cfg)

num_samples = 20
# DatasetCatalog.get re-runs the dataset loader on every call, so fetch once.
dataset_valid = DatasetCatalog.get(VALID_DATA_SET_NAME)
metadata = MetadataCatalog.get(VALID_DATA_SET_NAME)
samples_to_visualize = dataset_valid[:num_samples]

# Run the fine-tuned predictor on each selected validation sample.
for i, d in enumerate(samples_to_visualize):
    img = cv2.imread(d["file_name"])  # BGR; None if unreadable
    if img is None:
        raise FileNotFoundError(f"Could not read image {d['file_name']!r}")
    outputs = predictor(img)
    instances = outputs["instances"].to("cpu")
    pred_masks = instances.pred_masks.numpy().astype(np.int32)  # per-instance masks as 0/1 ints
    pred_boxes = instances.pred_boxes.tensor.numpy().astype(np.int32)  # box coordinates truncated to int
    pred_scores = instances.scores.numpy()  # Get confidence scores
    pred_classes = instances.pred_classes.numpy()  # Get class labels
    print("Image shape:", img.shape)
    print("pred_masks shape:", pred_masks.shape)
    print("pred_boxes shape:", pred_boxes.shape)