# Test for SegFormer usage

## Libraries

In [None]:
! pip install transformers datasets accelerate evaluate pillow

In [None]:
from transformers import pipeline
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from transformers import Trainer, TrainingArguments, default_data_collator

## Run the built-in segmentation pipeline

In [None]:
# Load a SegFormer-B0 model already fine-tuned on ADE20K.
# The checkpoint id must match a repository on the Hugging Face Hub that ships
# weights loadable by `transformers.pipeline`.
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
)

# Segment your image (local path, URL, or PIL Image)
output = segmenter("path/to/your/image.jpg")

# output is a list of dicts, one per predicted class:
# [{"score": None, "label": "water", "mask": <PIL.Image>}, ...]
# ("score" is None for semantic segmentation models such as SegFormer).
for obj in output:
    print(obj["label"], obj["score"])
    # Save each class mask as "<label>.png" in the current directory.
    obj["mask"].save(f"{obj['label']}.png")

## Fine-tune SegFormer on Datasets

### Dataset preparation

In [None]:
# Load the dataset from the Hugging Face Hub (or prepare it from a downloaded copy).
# Build a DatasetDict named `dataset` (e.g. with the `datasets` library) with "train" and "validation" splits, holding images + mask PNGs in columns {"image":…, "label":…}

### Preprocess & tokenize

In [None]:
def preprocess(examples):
    """Turn a batch of images and segmentation masks into model inputs.

    Args:
        examples: Batch mapping with an "image" column and a "label" column.
            Each image must support ``.convert("RGB")`` (presumably PIL
            images); each label is the matching integer class mask.

    Returns:
        A dict with the "pixel_values" and "labels" entries produced by the
        module-level ``feature_extractor``.
    """
    # Force three channels so every image reaches the processor as RGB.
    rgb_images = []
    for picture in examples["image"]:
        rgb_images.append(picture.convert("RGB"))

    encoded = feature_extractor(
        images=rgb_images,
        segmentation_maps=examples["label"],
        return_tensors="pt",
    )
    # Keep only the two fields the model consumes.
    return {key: encoded[key] for key in ("pixel_values", "labels")}

# Pretrained checkpoint shared by the image processor and the model.
checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"

# Preprocessing object used by `preprocess` to turn images and masks into
# model inputs.
# NOTE: recent transformers releases deprecate SegformerFeatureExtractor in
# favour of SegformerImageProcessor (same call signature).
# NOTE(review): the ADE20K checkpoint's preprocessing config may enable label
# reduction (class 0 treated as background and ignored) — confirm this matches
# how your masks are encoded.
feature_extractor = SegformerFeatureExtractor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(
    checkpoint,
    num_labels=3,                    # your classes: water/sky/obstacle
    ignore_mismatched_sizes=True,    # head size differs from the checkpoint's, so it is re-initialized
)

# `dataset` is a DatasetDict, so `dataset.column_names` is a dict of
# {split: [columns]}; `remove_columns` needs the plain column list of a split.
tokenized = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

### Setup Trainer and train

In [None]:
# Training configuration. The first positional argument is the output
# directory for checkpoints and logs.
args = TrainingArguments(
    "segformer-maritime",
    eval_strategy="epoch",  # renamed from `evaluation_strategy` (transformers >= 4.41)
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    save_strategy="epoch",
    logging_steps=50,
    push_to_hub=False,
)
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=default_data_collator,
    # `processing_class` replaces the deprecated `tokenizer=` argument
    # (transformers >= 4.46); it lets the Trainer save the preprocessing
    # config alongside the model checkpoints.
    processing_class=feature_extractor,
)

In [None]:
# Start fine-tuning with the Trainer configured in the previous cell.
trainer.train()