# 1. Setup


In [None]:
import huggingface_hub

# Authenticate this notebook session with the Hugging Face Hub; the later
# cells push the processor, the trained model and the dataset to the Hub.
huggingface_hub.notebook_login()

In [None]:
hf_username: str = "samitizerxu"  # Hugging Face Hub account / namespace for this project

In [None]:
import wandb

# Log in to Weights & Biases for run tracking; the project name is
# configured through WANDB_PROJECT in the next cell.
wandb_logged_in = wandb.login()
wandb_logged_in

In [None]:
import os

# Route W&B runs from this notebook into the "kelp-segmentation" project.
os.environ.update(WANDB_PROJECT="kelp-segmentation")

# 2. Dataset Loading


In [None]:
from datasets import load_dataset

# Load the "rgb_int" configuration of the local "kelp_data" dataset.
# trust_remote_code=True lets `datasets` execute the dataset's loading script.
ds = load_dataset(
    "kelp_data",
    name="rgb_int",
    trust_remote_code=True,
)

In [None]:
# Hold out 10% of the original "train" split for evaluation.
# `train_test_split` already shuffles before splitting, so a separate
# `shuffle` call is redundant. Seeding the split itself is what makes the
# train/test membership reproducible across runs (it was unseeded before).
ds = ds["train"].train_test_split(test_size=0.1, seed=1)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
# Inspect the shape of one training image.
# Index the row before the column so only that one example is decoded.
# (`ds['train']['pixel_values']` would materialise the whole column.)
# `datasets` returns Python objects by default (e.g. PIL images or nested
# lists), which may not have `.shape`, so go through NumPy.
import numpy as np

np.asarray(ds['train'][0]['pixel_values']).shape

In [None]:
# Shape of the first training image (decodes only that example; converts
# via NumPy because the raw value may be a PIL image or nested list).
import numpy as np

np.asarray(ds['train'][0]['pixel_values']).shape

## Image processor & data augmentation

In [None]:
from torchvision.transforms import ColorJitter
from transformers import SegformerImageProcessor

# SegFormer pre-processing with library defaults; shared by both splits.
processor = SegformerImageProcessor()

# Random photometric augmentation, applied to training images only.
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)

def train_transforms(example_batch):
    """Augment and encode a batch of training examples.

    Each image in ``example_batch['pixel_values']`` is passed through the
    module-level ``jitter`` augmentation. The jittered images and the
    ``label`` masks are then encoded together by ``processor``.

    Args:
        example_batch: batch dict from ``datasets`` with ``pixel_values``
            and ``label`` columns.

    Returns:
        The ``processor`` output for the augmented images and their masks.
    """
    # NOTE(review): torchvision's saturation/hue adjustments accept only 1- or
    # 3-channel input, while the model is created with num_channels=5;
    # confirm how many bands these images actually have.
    augmented = [jitter(img) for img in example_batch['pixel_values']]
    masks = list(example_batch['label'])
    return processor(augmented, masks)


def val_transforms(example_batch):
    """Encode a batch of evaluation examples without augmentation.

    Args:
        example_batch: batch dict from ``datasets`` with ``pixel_values``
            and ``label`` columns.

    Returns:
        The ``processor`` output for the images and their masks.
    """
    return processor(
        list(example_batch['pixel_values']),
        list(example_batch['label']),
    )


# Attach the transforms: `datasets` runs them on the fly whenever rows are
# accessed, so training images get fresh jitter on every read.
for split_ds, transform in ((train_ds, train_transforms), (test_ds, val_transforms)):
    split_ds.set_transform(transform)

# 3. Fine-tune


In [None]:
from transformers import SegformerForSemanticSegmentation, SegformerConfig

# Class-id -> name mapping for the binary kelp mask.
id2label = {
    1: 'kelp',
    0: 'not_kelp',
}
# Inverse mapping, derived so the two dicts can never drift apart.
label2id = {name: class_id for class_id, name in id2label.items()}

pretrained_model_name = "nvidia/mit-b0"

# Start from the `nvidia/mit-b0` checkpoint. With ignore_mismatched_sizes=True,
# weights whose shapes differ from this config (e.g. a 5-channel input layer,
# a 2-class head) are re-initialised instead of raising an error.
# NOTE(review): the processor/jitter pipeline presumably produces 3-channel
# images; confirm that the inputs really carry 5 channels.
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    num_channels=5,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

## Set up the Trainer

In [None]:
from transformers import TrainingArguments

epochs = 30
lr = 6e-5
batch_size = 8

hub_model_id = "segformer-b0-finetuned-kelp-segments-5-channel-jan-20"

# Evaluate and checkpoint on the same cadence. load_best_model_at_end needs
# the save and eval schedules to line up.
eval_every = 30

training_args = TrainingArguments(
    hub_model_id,  # output_dir: the local run folder shares the Hub repo name
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # checkpointing / evaluation schedule
    evaluation_strategy="steps",
    eval_steps=eval_every,
    eval_accumulation_steps=5,
    save_strategy="steps",
    save_steps=eval_every,
    save_total_limit=3,
    load_best_model_at_end=True,
    logging_steps=1,
    # Hub upload happens once, at the end of training
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
)

In [None]:
import torch
from torch import nn
import evaluate
import multiprocessing

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(logits, labels)`` pair of NumPy arrays from ``Trainer``.
            ``logits`` is upsampled to the spatial size of ``labels`` (their
            last two dims) and arg-maxed over dim 1, so it is presumably
            ``(batch, num_labels, h, w)``. TODO confirm.

    Returns:
        The ``mean_iou`` result dict. The per-category arrays are replaced by
        scalar ``accuracy_<label>`` / ``iou_<label>`` entries so the Trainer
        can log them individually.
    """
    logits, labels = eval_pred
    # Some model configurations return extra outputs next to the logits;
    # the logits always come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    with torch.no_grad():
        # Logits are at reduced resolution: upsample them to the label size,
        # then take the per-pixel argmax over the class dimension.
        pred_labels = (
            nn.functional.interpolate(
                torch.from_numpy(logits),
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
            .argmax(dim=1)
            .numpy()
        )

    metrics = metric._compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(id2label),
        # `ignore_index` is a required argument of mean_iou; omitting it raises
        # a TypeError. Classes 0 and 1 are both real here, so ignore only 255,
        # the conventional "void" label value.
        ignore_index=255,
        reduce_labels=processor.do_reduce_labels,
    )

    # Flatten the per-category arrays into individually loggable scalars.
    # (Replaces the old debug prints; the values reach the logger instead.)
    for array_key, prefix in (
        ("per_category_accuracy", "accuracy"),
        ("per_category_iou", "iou"),
    ):
        per_category = metrics.pop(array_key).tolist()
        metrics.update(
            {f"{prefix}_{id2label[i]}": v for i, v in enumerate(per_category)}
        )

    return metrics

In [None]:
from transformers import Trainer

# Bring together the model, the on-the-fly transformed splits and the
# metric function.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

In [None]:
# Run fine-tuning; the returned summary is displayed as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Hub repo id for the dataset. The namespace is derived from `hf_username`
# rather than duplicated, so changing the account updates it in one place.
hf_dataset_identifier = f"{hf_username}/kelp_data"

In [None]:
# Upload the train/test DatasetDict to the Hub dataset repo.
ds.push_to_hub(repo_id=hf_dataset_identifier)

In [None]:
# Metadata forwarded by `trainer.push_to_hub` into the generated model card.
model_card_kwargs = dict(
    tags=["vision", "image-segmentation"],
    finetuned_from=pretrained_model_name,
    dataset=hf_dataset_identifier,
)

processor.push_to_hub(hub_model_id)
trainer.push_to_hub(**model_card_kwargs)

# 4. Inference

## Use the model from the hub

In [None]:
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Checkpoint used for the qualitative inference examples below.
# NOTE(review): this is not the checkpoint trained above (hub_model_id is
# "...-5-channel-jan-20"). Confirm that this older "jan-18-10am" model is
# the intended one.
inference_checkpoint = "samitizerxu/segformer-b0-finetuned-kelp-segments-jan-18-10am"

processor = SegformerImageProcessor()
model = SegformerForSemanticSegmentation.from_pretrained(inference_checkpoint)

In [None]:
#@title `def my_palette()`

def my_palette():
    """Return the overlay colour for each class, indexed by class id.

    Index 0 (``not_kelp`` in ``id2label``) is black and index 1 (``kelp``)
    is orange-red, so only kelp pixels are tinted in overlays. A new list is
    returned on every call.
    """
    return [
        [0, 0, 0],
        [216, 82, 24],
    ]

In [None]:
import numpy as np

def get_seg_overlay(image, seg, palette=None):
    """Blend a colour-coded segmentation mask 50/50 over an image.

    Args:
        image: image convertible with ``np.array`` (e.g. a PIL image). It must
            broadcast against an ``(H, W, 3)`` array, so presumably a 3-channel
            channel-last image. TODO confirm for multi-band inputs.
        seg: 2-D class-index mask of shape ``(H, W)``. Anything ``np.asarray``
            accepts works (NumPy array, nested list, CPU torch tensor).
        palette: optional sequence of RGB triples indexed by class id;
            defaults to ``my_palette()``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``. Classes with no palette entry
        contribute black to the mask layer.
    """
    # Normalise the mask once so `.shape` and boolean masking work for lists
    # and tensors as well as ndarrays.
    seg = np.asarray(seg)
    if palette is None:
        palette = my_palette()
    palette = np.asarray(palette)

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    # Show image + mask (each weighted 0.5, so the sum stays within uint8).
    img = np.array(image) * 0.5 + color_seg * 0.5
    return img.astype(np.uint8)

In [None]:
import matplotlib.pyplot as plt

# `test_ds` has `val_transforms` attached via `set_transform`, so indexing it
# returns that transform's (processor) output instead of the raw
# `pixel_values` / `label` row. Read from an unformatted view instead.
raw_test_ds = test_ds.with_format(None)

# Clamp the range so a smaller test split does not raise IndexError.
for i in range(90, min(120, len(raw_test_ds))):
    example = raw_test_ds[i]
    image = example['pixel_values']
    gt_seg = np.asarray(example['label'])
    # Target (height, width) for upsampling. Assumes a channel-last image
    # (as a PIL image converts to); TODO confirm for array-valued images.
    height, width = np.asarray(image).shape[:2]

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(**inputs).logits  # shape (batch_size, num_labels, height/4, width/4)

    # First, rescale logits to original image size
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )

    # Second, apply argmax on the class dimension; hand NumPy to the overlay.
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    pred_img = get_seg_overlay(image, pred_seg)
    gt_img = get_seg_overlay(image, gt_seg)

    f, axs = plt.subplots(1, 2, figsize=(50, 30))
    axs[0].set_title("Prediction", {'fontsize': 40})
    axs[0].imshow(pred_img)
    axs[1].set_title("Ground truth", {'fontsize': 40})
    axs[1].imshow(gt_img)
    # Render each figure now rather than keeping 30 very large figures open
    # until the end of the cell.
    plt.show()