In [None]:
import albumentations as A
import evaluate
import json
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image as PImage
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor


### Download and Check Dataset

In [None]:
def color_palette():
    """
    Abbreviated version of the ADE20k colour palette.

    Returns:
        list[list[int]]: eight ``[R, G, B]`` triples. The position of a
        triple in the list is the class id it is painted for when a
        segmentation map is colourised. A new list is built on each call.
    """
    rgb_rows = (
        (120, 120, 120), (4, 200, 4), (180, 120, 120), (6, 230, 230),
        (80, 50, 50), (120, 120, 80), (140, 140, 140), (204, 5, 255),
    )
    return [list(rgb) for rgb in rgb_rows]


# Built once at module level; read by add_segmentations_to_image.
palette = color_palette()

In [None]:
def add_segmentations_to_image(img, segs):
    """
    Blend an image 50/50 with a colour-coded rendering of its segmentation map.

    Args:
        img: anything ``np.array`` can convert; it must broadcast against an
            ``(H, W, 3)`` array (assumed to be an RGB image of the same
            height/width as ``segs`` -- TODO confirm for other callers).
        segs: array whose first two dimensions are height and width and whose
            values are class ids.

    Returns:
        np.ndarray: ``uint8`` array holding the blended image.

    Class ids are coloured by their position in the module-level ``palette``;
    ids without a palette entry keep the all-zero (black) overlay colour.
    """
    height, width = segs.shape[0], segs.shape[1]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every pixel of each known class with that class's colour.
    for class_id in range(len(palette)):
        overlay[segs == class_id] = palette[class_id]

    blended = 0.5 * np.array(img) + 0.5 * overlay
    return blended.astype(np.uint8)


def visualize_mask(masks, labels, label_name):
    """
    Render the mask belonging to ``label_name`` as a black/white PIL image.

    Args:
        masks: indexable collection of masks; each element must support
            ``.bool().numpy()`` (presumably torch tensors -- verify against caller).
        labels: sequence of label names, parallel to ``masks``.
        label_name: the label whose mask should be shown.

    Returns:
        PIL image in which truthy mask entries are 255 and the rest 0.

    The label name is printed before the lookup; ``labels.index`` fails
    (``ValueError`` for a list) when ``label_name`` is not present.
    """
    print("Label:", label_name)
    position = labels.index(label_name)

    binary = masks[position].bool().numpy()
    return PImage.fromarray((binary * 255).astype(np.uint8))

In [None]:
# Hugging Face Hub repository id of the dataset used throughout the notebook.
# Plain string literal: there is nothing to interpolate, so no f-prefix.
dataset_id = "thiagohersan/satellite-trees-dataset"

In [None]:
# Fetch the class-id -> class-name table stored in the dataset repository.
id2label_path = hf_hub_download(dataset_id, "id2label.json", repo_type="dataset")

# The context manager closes the file deterministically instead of leaving the
# handle open until it happens to be garbage collected.
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    # JSON object keys are always strings; turn them into integer class ids.
    id2label = {int(k): v for k, v in json.load(id2label_file).items()}

# Reverse lookup (name -> id). The keys of id2label are already ints here.
label2id = {v: k for k, v in id2label.items()}

In [None]:
# One seed for both random steps so the train/test membership is identical on
# every "Restart & Run All".
SPLIT_SEED = 101010

dataset = load_dataset(dataset_id)
dataset = dataset.shuffle(seed=SPLIT_SEED)
# train_test_split shuffles again by default; without an explicit seed the
# split would differ from run to run despite the seeded shuffle above.
dataset = dataset["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)

train_ds = dataset["train"]
test_ds = dataset["test"]

In [None]:
# Look at the first training record: its image, its label map and which
# classes occur in it.
example = train_ds[0]
ex_image = example["pixel_values"]
ex_segmentation_map = np.array(example["label"])

ex_class_ids = np.unique(ex_segmentation_map)
ex_labels = [id2label[class_id] for class_id in ex_class_ids]

# Cell output: ((class ids, pixel counts), class names)
np.unique(ex_segmentation_map, return_counts=True), ex_labels

In [None]:
# Overlay the colour-coded label map on the raw example image and show it.
ex_im_seg = add_segmentations_to_image(ex_image, ex_segmentation_map)

ex_fig, ex_ax = plt.subplots(figsize=(15, 10))
ex_ax.imshow(ex_im_seg)
plt.show()

### Create PyTorch Dataset

In [None]:
# Per-channel mean / standard deviation, rescaled from the 0-255 range to 0-1.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# Side length, in pixels, of the square both pipelines resize to.
IMAGE_SIZE = 512

# Training pipeline: resize, random horizontal flip (50 %), normalise.
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
)

# Evaluation pipeline: same as training but without the random flip.
test_transform = A.Compose(
    [
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
)

In [None]:
class ImageSegmentationDataset(Dataset):
    """
    Wrap an indexable collection of records as a PyTorch ``Dataset``.

    Each record must be a mapping with the keys ``'pixel_values'`` (the image)
    and ``'label'`` (the segmentation map); both are converted with
    ``np.array``.

    Args:
        dataset: indexable, sized collection of such records.
        transform: callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with the keys ``'image'`` and ``'mask'``.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Return ``(image, segmentation_map, original_image, original_segmentation_map)``.

        ``image`` is the transformed image with its axes reordered by
        ``transpose(2, 0, 1)`` (H, W, C -> C, H, W); ``segmentation_map`` is
        the transformed mask; the two ``original_*`` values are the
        untransformed arrays.
        """
        # Fetch the record once: indexing the wrapped dataset can be costly
        # (e.g. it may decode the stored images), so do not repeat it per field.
        record = self.dataset[idx]
        original_image = np.array(record['pixel_values'])
        original_segmentation_map = np.array(record['label'])

        transformed = self.transform(image=original_image, mask=original_segmentation_map)
        image, segmentation_map = transformed['image'], transformed['mask']

        image = image.transpose(2, 0, 1)  # convert to C, H, W

        return image, segmentation_map, original_image, original_segmentation_map

In [None]:
# Pair each split with its own augmentation pipeline.
train_dataset = ImageSegmentationDataset(dataset=train_ds, transform=train_transform)
test_dataset = ImageSegmentationDataset(dataset=test_ds, transform=test_transform)

In [None]:
# Sanity-check one transformed training item (the two original arrays are
# not needed here).
ds_image, ds_segmentation_map, _, _ = train_dataset[10]

ds_class_ids = np.unique(ds_segmentation_map)
ds_labels = [id2label[class_id] for class_id in ds_class_ids]

# Array shapes first, then ((class ids, pixel counts), class names).
print(ds_image.shape, ds_segmentation_map.shape)
print(np.unique(ds_segmentation_map, return_counts=True), ds_labels)

In [None]:
# Show the transformed array as-is: cast to uint8 and move channels last for PIL.
# NOTE(review): ds_image is mean/std-normalised at this point, so this preview
# presumably exists to contrast with the un-normalised image built next -- confirm.
PImage.fromarray(ds_image.astype(np.uint8).transpose(1, 2, 0))

In [None]:
# Undo the mean/std normalisation channel by channel; the [:, None, None]
# indexing broadcasts the 3 per-channel values over the (C, H, W) image.
ds_unnormalized_image = (
    ds_image * np.asarray(ADE_STD)[:, None, None] + np.asarray(ADE_MEAN)[:, None, None]
)
# Round to the nearest integer and clip to the valid byte range before casting:
# a bare astype(np.uint8) truncates, so floating-point round-off (e.g.
# 199.99999) would render many pixels one level too dark, and out-of-range
# values would wrap around.
ds_unnormalized_image = np.clip(np.rint(ds_unnormalized_image * 255), 0, 255).astype(np.uint8)
# Channels last, the layout PIL / matplotlib expect.
ds_unnormalized_image = np.moveaxis(ds_unnormalized_image, 0, -1)

In [None]:
# Display the un-normalised (H, W, C) uint8 image as the cell output.
PImage.fromarray(ds_unnormalized_image)

In [None]:
# Overlay the transformed label map on the un-normalised image and show it.
ds_im_seg = add_segmentations_to_image(ds_unnormalized_image, ds_segmentation_map)

ds_fig, ds_ax = plt.subplots(figsize=(15, 10))
ds_ax.imshow(ds_im_seg)
plt.show()

### Create PyTorch DataLoader

In [None]:
# Image processor used by collate_fn. Resizing, rescaling and normalisation
# are switched off here -- presumably because the albumentations pipelines
# already resize and normalise each image (confirm before changing either side).
preprocessor = MaskFormerImageProcessor(
    ignore_index=0,
    reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)

def collate_fn(batch):
    """
    Collate dataset items into one model-ready batch.

    Args:
        batch: sequence of items, each indexable as
            ``(image, segmentation_map, original_image, original_segmentation_map)``.

    Returns:
        The object returned by the module-level ``preprocessor`` for the
        images and segmentation maps (requested with ``return_tensors="pt"``),
        with two extra entries, ``"original_images"`` and
        ``"original_segmentation_maps"``, holding the untouched originals as
        tuples.
    """
    # Transpose "list of items" into "one sequence per field".
    columns = list(zip(*batch))

    encoded = preprocessor(
        columns[0],
        segmentation_maps=columns[1],
        return_tensors="pt",
    )

    # Carry the originals along unchanged.
    encoded["original_images"] = columns[2]
    encoded["original_segmentation_maps"] = columns[3]

    return encoded

In [None]:
# Settings shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=2, collate_fn=collate_fn)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [None]:
# Pull one batch to inspect what collate_fn produces.
batch = next(iter(train_dataloader))
dl_idx = 0  # which sample of the batch the following cells look at

for k, v in batch.items():
    # Entries that expose .shape are reported directly; for the others the
    # shape of element dl_idx is shown. An explicit hasattr check is used
    # rather than a blanket try/except so that genuine errors (and
    # KeyboardInterrupt) are not silently swallowed.
    shaped = v if hasattr(v, "shape") else v[dl_idx]
    print(k, shaped.shape)

# Names of the classes present in the selected sample.
dl_labels = [id2label[label] for label in batch["class_labels"][dl_idx].tolist()]

print(batch["mask_labels"][dl_idx].shape, dl_labels)

In [None]:
# Selected sample's pixel values as a NumPy array.
dl_image = batch["pixel_values"][dl_idx].numpy()

# Preview of the array as-is (cast to uint8, channels moved last for PIL).
PImage.fromarray(dl_image.astype(np.uint8).transpose(1, 2, 0))

In [None]:
# Undo the mean/std normalisation channel by channel; the [:, None, None]
# indexing broadcasts the 3 per-channel values over the (C, H, W) image.
dl_unnormalized_image = (
    dl_image * np.asarray(ADE_STD)[:, None, None] + np.asarray(ADE_MEAN)[:, None, None]
)
# Round to the nearest integer and clip to the valid byte range before casting:
# a bare astype(np.uint8) truncates, so floating-point round-off (e.g.
# 199.99999) would render many pixels one level too dark, and out-of-range
# values would wrap around.
dl_unnormalized_image = np.clip(np.rint(dl_unnormalized_image * 255), 0, 255).astype(np.uint8)
# Channels last, the layout PIL expects.
dl_unnormalized_image = np.moveaxis(dl_unnormalized_image, 0, -1)
PImage.fromarray(dl_unnormalized_image)

In [None]:
# Show the binary mask of the "tree" class for the selected sample.
# NOTE(review): the lookup inside visualize_mask fails if "tree" is not among
# dl_labels for this (randomly drawn) batch -- confirm every sample has it.
visualize_mask(batch["mask_labels"][dl_idx], dl_labels, "tree")

### Build Model

In [None]:
# Start from the pretrained checkpoint and swap in this dataset's label set.
# Both directions of the mapping are passed so the model config stays
# self-consistent; with id2label alone the config would keep the checkpoint's
# own label2id next to the new id2label.
# ignore_mismatched_sizes=True presumably lets the checkpoint load even though
# the number of classes differs from the pretrained head -- confirm in the docs.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-base-ade",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
# Smoke test: run the inspected batch through the model once, passing the
# targets along with the images.
outputs = model(
    pixel_values=batch["pixel_values"],
    mask_labels=batch["mask_labels"],
    class_labels=batch["class_labels"],
)

In [None]:
# Display the loss attached to the forward pass above.
outputs.loss

### Train

In [None]:
# Load the "mean_iou" metric through the `evaluate` library.
# NOTE(review): presumably consumed by the training/evaluation loop that
# follows -- not visible from here.
metric = evaluate.load("mean_iou")