In [None]:
import evaluate
import json
import numpy as np
import torch

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image as PImage
from torch import nn
from torchvision import transforms as T
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from transformers import Trainer, TrainingArguments

In [None]:
# Per-channel (R, G, B) statistics passed to T.Normalize in the train/test
# transforms defined further down.
ade_mean = [0.485, 0.456, 0.406]
ade_std = [0.229, 0.224, 0.225]

# RGB colours used to tint mask overlays; add_mask_label_to_image indexes this
# list with its label_idx argument.
palette = [
    [120, 120, 120],
    [4, 200, 4],
    [180, 120, 120],
    [6, 230, 230],
    [80, 50, 50],
    [120, 120, 80],
    [140, 140, 140],
    [204, 5, 255],
]


def np_from_tensor(img_t, mean=(0., 0., 0.), std=(1., 1., 1.)):
    """Convert a channel-first image into a channel-last uint8 array.

    Undoes a per-channel ``(x - mean) / std`` normalisation, scales the result
    to 0..255 and moves the channel axis to the end so the array can be handed
    to ``PIL.Image.fromarray``.

    Args:
        img_t: image laid out as (channels, height, width); a torch tensor or
            anything ``np.asarray`` accepts.
        mean: per-channel mean that was subtracted during normalisation.
        std: per-channel standard deviation the image was divided by.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` laid out as (height, width, channels).
    """
    # Work in NumPy from the start so the arithmetic does not depend on how
    # torch and NumPy operands mix.
    if hasattr(img_t, "detach"):
        img = img_t.detach().cpu().numpy()
    else:
        img = np.asarray(img_t)
    img = img.astype(np.float64)

    scale = np.asarray(std, dtype=np.float64)[:, None, None]
    shift = np.asarray(mean, dtype=np.float64)[:, None, None]
    img = img * scale + shift

    # Clip before casting: anything outside [0, 1] would otherwise wrap around
    # when converted to uint8 and show up as speckled wrong-colour pixels.
    img = np.clip(img, 0.0, 1.0)
    return np.moveaxis((255 * img).astype(np.uint8), 0, -1)


def mask_from_label(masks, labels, label_name):
    """Return the mask belonging to ``label_name`` as a uint8 image.

    Args:
        masks: indexable collection of mask tensors, in the same order as
            ``labels``.
        labels: list of label names; the position of ``label_name`` in this
            list selects the mask.
        label_name: name of the label to extract.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` holding 255 where the selected mask
        is non-zero and 0 elsewhere.

    Raises:
        ValueError: if ``label_name`` is not in ``labels``.
    """
    print("Label:", label_name)
    selected = masks[labels.index(label_name)]
    # Any non-zero entry counts as "inside" the mask.
    inside = selected.bool().numpy()
    return np.where(inside, 255, 0).astype(np.uint8)


def add_mask_label_to_image(img, mask_label, label_idx, alpha=0.5, colors=None):
    """Blend a single-colour mask overlay onto an image.

    An overlay image is built that is ``colors[label_idx]`` wherever
    ``mask_label == 255`` and black everywhere else, then mixed with ``img``.
    Because the overlay is black outside the mask, unmasked pixels come out
    darkened by a factor of ``1 - alpha``.

    Args:
        img: (height, width, 3) image array.
        mask_label: (height, width) array; pixels equal to 255 are tinted.
        label_idx: index into ``colors`` selecting the tint colour.
        alpha: weight of the overlay in the blend; the image gets
            ``1 - alpha``. The default of 0.5 gives an even mix.
        colors: sequence of RGB triples to pick the tint from. Defaults to the
            module-level ``palette``.

    Returns:
        (height, width, 3) ``uint8`` array containing the blended image.

    Raises:
        IndexError: if ``label_idx`` is outside ``colors``.
    """
    if colors is None:
        colors = palette

    overlay = np.zeros((mask_label.shape[0], mask_label.shape[1], 3), dtype=np.uint8)
    overlay[mask_label == 255, :] = colors[label_idx]

    blended = (1.0 - alpha) * img + alpha * overlay
    return blended.astype(np.uint8)


In [None]:
# Hub id of the dataset used for fine-tuning.
dataset_id = "thiagohersan/satellite-trees"
# Hub id of the pretrained checkpoint that fine-tuning starts from.
base_model_id = "facebook/maskformer-swin-base-ade"
# Name for the fine-tuned model; also used to build the training output directory.
result_model_id = "maskformer-satellite-trees"

In [None]:
# Fetch the id -> label-name mapping that ships with the dataset repository.
id2label_path = hf_hub_download(dataset_id, "id2label.json", repo_type="dataset")

# Use a context manager so the file handle is closed as soon as it is parsed.
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)

# JSON object keys are always strings; convert them to the integer ids used below.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

In [None]:
# Load the pretrained checkpoint and attach this dataset's label maps to it.
# NOTE(review): ignore_mismatched_sizes=True is presumably needed because the
# checkpoint's classification head was sized for a different label set than
# id2label and has to be re-initialised — confirm against the load warnings.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    base_model_id,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

### Download Dataset

In [None]:
# Only the hub dataset's "train" split is used. It is shuffled with a fixed
# seed and 20% of it is held out, so the same split is produced on every run.
dataset = load_dataset(dataset_id)
dataset = dataset["train"].train_test_split(test_size=0.2, shuffle=True, seed=1010)

# train_ds is used for fitting; test_ds is handed to the Trainer as eval_dataset.
train_ds = dataset["train"]
test_ds = dataset["test"]

### Create MaskFormer Dataset

In [None]:
# Resizing, rescaling and normalisation are all switched off here: the
# torchvision pipelines defined below already apply T.ToTensor and T.Normalize
# before images reach this processor, so it is only used to build the model
# inputs from images and segmentation maps.
# ignore_index=255 is the same value passed as ignore_index in compute_metrics.
preprocessor = MaskFormerImageProcessor(
    ignore_index=255,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False
)

In [None]:
def get_transform(transform):
    """Build a batch-level transform for ``Dataset.set_transform``.

    Args:
        transform: callable applied to each image of the batch individually.

    Returns:
        A function that takes a batch dict with ``"pixel_values"`` (images)
        and ``"label"`` (segmentation maps), applies ``transform`` to every
        image, and returns the output of the module-level ``preprocessor``
        for the transformed images and the unmodified segmentation maps.
    """
    def apply_transform(batch_in):
        # Segmentation maps are forwarded as-is; only the images are transformed.
        return preprocessor(
            images=list(map(transform, batch_in["pixel_values"])),
            segmentation_maps=list(batch_in["label"]),
            return_tensors="pt",
        )

    return apply_transform

In [None]:
# Random augmentations applied to training images only. They run before the
# tensor conversion; the segmentation maps are not passed through them.
_train_augmentations = [
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.25),
    T.RandomPosterize(bits=2, p=0.2),
    T.RandomAdjustSharpness(sharpness_factor=3, p=0.2),
    T.RandomAutocontrast(p=0.3),
    T.RandomEqualize(p=0.3),
]


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps; both pipelines end with these."""
    return [T.ToTensor(), T.Normalize(mean=ade_mean, std=ade_std)]


train_transform = T.Compose(_train_augmentations + _tensor_and_normalize())
test_transform = T.Compose(_tensor_and_normalize())

# Transforms are applied lazily, each time an example is read from the dataset.
train_ds.set_transform(get_transform(train_transform))
test_ds.set_transform(get_transform(test_transform))

### Check Data

In [None]:
example = train_ds[0]

# Print the shape (and dtype where available) of every field of one example.
for key, value in example.items():
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        print(key, value.shape, value.dtype)
    else:
        # Not array-like itself: report the shape of its first element.
        # Any failure here is raised rather than silently swallowed.
        print(f"{key}[0]", value[0].shape)

# Human-readable names of the classes present in this example.
ex_labels = [id2label[label] for label in example["class_labels"].tolist()]
print(ex_labels)

In [None]:
# The transforms normalised the image with ade_mean/ade_std, so the same
# statistics must be passed here to recover displayable pixel values.
PImage.fromarray(np_from_tensor(example['pixel_values'], mean=ade_mean, std=ade_std))

In [None]:
# Show the binary mask of the 'tree' class for the example as a grayscale image.
tree_mask = mask_from_label(example["mask_labels"], ex_labels, 'tree')
PImage.fromarray(tree_mask)

In [None]:
# Overlay the 'tree' mask on the example image, tinted with palette entry 1.
# The image is de-normalised with the same ade_mean/ade_std that T.Normalize
# applied; without them the displayed colours are wrong.
PImage.fromarray(
    add_mask_label_to_image(
        np_from_tensor(example['pixel_values'], mean=ade_mean, std=ade_std),
        mask_from_label(example["mask_labels"], ex_labels, 'tree'),
        1
    )
)

### Train

In [None]:
# Training configuration: 2 epochs, batch size 4 per device for both training
# and evaluation, with evaluation and checkpointing every 20 steps.
training_args = TrainingArguments(
    output_dir=f"{result_model_id}-outputs",
    learning_rate=5e-5,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Keep at most 3 checkpoints on disk.
    save_total_limit=3,
    # Evaluation and saving use the same strategy and the same step interval.
    # NOTE(review): load_best_model_at_end presumably requires these to line up — confirm.
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    # Log every single step.
    logging_steps=1,
    eval_accumulation_steps=5,
    # NOTE(review): presumably needed because the datasets produce their columns
    # via set_transform, so the raw columns must not be dropped — confirm.
    remove_unused_columns=False,
    report_to="tensorboard",
    # NOTE(review): no metric_for_best_model is given, so "best" is whatever
    # the library's default selection criterion is — confirm that is intended.
    load_best_model_at_end=True,
    # Upload the result to the hub under result_model_id.
    # NOTE(review): presumably requires an authenticated hub session — confirm.
    push_to_hub=True,
    hub_model_id=result_model_id,
    hub_strategy="end",
)

In [None]:
# Loaded once so every evaluation round reuses the same metric object.
metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for one evaluation round.

    Args:
        eval_pred: pair ``(logits, labels)``. ``logits`` is converted with
            ``torch.from_numpy`` and ``labels`` must expose ``.shape``.
            NOTE(review): this assumes a single NumPy logits array laid out as
            (batch, classes, height, width) and a label array whose last two
            axes are (height, width) — confirm that this matches what the
            Trainer passes for this model.

    Returns:
        dict with the values returned by the metric, where the
        ``per_category_accuracy`` and ``per_category_iou`` arrays are replaced
        by one ``accuracy_<label>`` / ``iou_<label>`` entry per class.
    """
    logits, labels = eval_pred

    with torch.no_grad():
        # Upsample the logits to the label resolution, then take the
        # highest-scoring class per pixel.
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).detach().cpu().numpy()

    # NOTE(review): this calls the private `_compute` rather than the public
    # `compute`; presumably deliberate (to skip `compute`'s input handling) —
    # confirm before switching.
    metrics = metric._compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(id2label),
        ignore_index=255,
        reduce_labels=False,
    )

    # Flatten the per-category arrays into individual key-value pairs.
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
    metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

    return metrics

In [None]:
# Wire the model, the training configuration, both dataset splits and the
# metric callback into a Trainer.
# NOTE(review): no data_collator is passed, so the library's default collator
# is used; confirm it can batch the per-example mask/class label fields when
# examples contain a different number of classes.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
# Start fine-tuning with the Trainer configured above.
trainer.train()