In [None]:
import evaluate
import json
import numpy as np
import torch

from datasets import load_dataset
from huggingface_hub import hf_hub_download
from os import path
from PIL import Image as PImage
from tqdm.auto import tqdm
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

In [None]:
# Per-channel RGB normalization statistics used by the torchvision transforms
# (the values match the widely used ImageNet mean/std).
ade_mean = [0.485, 0.456, 0.406]
ade_std = [0.229, 0.224, 0.225]

# One RGB colour per class id: palette[i] is the overlay colour for label i.
palette = [
    [120, 120, 120],
    [4, 200, 4],
    [4, 4, 250],
    [6, 230, 230],
    [80, 50, 50],
    [120, 120, 80],
    [140, 140, 140],
    [204, 5, 255],
]


def np_from_tensor(img_t, mean=(0., 0., 0.), std=(1., 1., 1.)):
    """Convert a normalized CHW image into a displayable HWC uint8 array.

    Undoes ``T.Normalize`` by computing ``img * std + mean`` per channel, then
    maps [0, 1] to [0, 255] (truncating).

    Args:
        img_t: image of shape (C, H, W); a torch tensor (any device, may track
            gradients) or a numpy array.
        mean: per-channel mean that was subtracted during normalization.
        std: per-channel std that was divided out during normalization.

    Returns:
        np.ndarray of shape (H, W, C) and dtype uint8.
    """
    if isinstance(img_t, torch.Tensor):
        # .numpy() only works on CPU tensors that do not track gradients.
        img_t = img_t.detach().cpu().numpy()
    img = np.asarray(img_t, dtype=np.float64)
    std_arr = np.asarray(std, dtype=np.float64)[:, None, None]
    mean_arr = np.asarray(mean, dtype=np.float64)[:, None, None]
    img = img * std_arr + mean_arr
    # Clip before the uint8 cast: out-of-range values (e.g. when mean/std do not
    # match the normalization actually applied) would otherwise wrap around.
    img = np.clip(img, 0.0, 1.0)
    return np.moveaxis((255 * img).astype(np.uint8), 0, -1)


def add_segmentations_to_image(img, segs):
    """Blend a colour-coded label map 50/50 over an RGB image.

    Args:
        img: RGB image (H, W, 3), anything ``np.array`` accepts.
        segs: 2-D array of integer class ids matching the image's H x W.
            Pixels with id ``i`` are painted ``palette[i]``; ids without a
            palette entry (e.g. the 255 ignore index) stay black in the overlay.

    Returns:
        np.ndarray (H, W, 3) uint8 blend of the image and the colours.
    """
    height, width = segs.shape[0], segs.shape[1]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)
    for class_id in range(len(palette)):
        overlay[segs == class_id, :] = palette[class_id]
    blended = 0.5 * np.array(img) + 0.5 * overlay
    return blended.astype(np.uint8)


def mask_from_label(masks, labels, label_name):
    """Return the binary mask of one named class as a displayable uint8 image.

    Args:
        masks: per-class masks of one image, indexable so that ``masks[i]`` is
            the (H, W) mask belonging to ``labels[i]``.
        labels: class names in the same order as ``masks``.
        label_name: class to extract.

    Returns:
        np.ndarray (H, W) uint8 with 255 where the mask is set and 0 elsewhere.

    Raises:
        ValueError: if ``label_name`` is not in ``labels`` (the class does not
            appear in this image).
    """
    print("Label:", label_name)
    try:
        idx = labels.index(label_name)
    except ValueError:
        raise ValueError(
            f"label {label_name!r} not present in this image; available labels: {labels}"
        ) from None

    # detach/cpu so this also works for masks that live on the GPU.
    mask = masks[idx].detach().cpu().bool().numpy()
    return (mask * 255).astype(np.uint8)


def add_mask_label_to_image(img, mask_label, label_idx):
    """Blend one class mask, painted in a palette colour, 50/50 over an image.

    Args:
        img: RGB image (H, W, 3); a numpy array or anything ``np.asarray``
            accepts (e.g. a PIL image), like ``add_segmentations_to_image``.
        mask_label: (H, W) mask with 255 marking the class (the format returned
            by ``mask_from_label``).
        label_idx: index into ``palette`` choosing the overlay colour. Pass the
            class id (``label2id[name]``) to get the same colour that
            ``add_segmentations_to_image`` uses for that class.

    Returns:
        np.ndarray (H, W, 3) uint8 blend of the image and the coloured mask.
    """
    overlay = np.zeros((mask_label.shape[0], mask_label.shape[1], 3), dtype=np.uint8)
    overlay[mask_label == 255, :] = palette[label_idx]
    blended = 0.5 * np.asarray(img) + 0.5 * overlay
    return blended.astype(np.uint8)


In [None]:
# Hub repository ids: the training data, the pretrained starting checkpoint, and
# the name the fine-tuned model is saved/pushed under. (Plain strings; the
# f-prefix had no placeholders.)
dataset_id = "thiagohersan/satellite-trees"
base_model_id = "facebook/maskformer-swin-base-ade"
result_model_id = "maskformer-satellite-trees"

In [None]:
# Class-id -> name mapping published with the dataset. JSON object keys are
# always strings, so convert them back to int ids; label2id is the inverse map.
# The `with` block closes the downloaded file (the bare open() leaked the handle).
id2label_file = hf_hub_download(dataset_id, "id2label.json", repo_type="dataset")
with open(id2label_file, "r", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}
label2id = {v: k for k, v in id2label.items()}

In [None]:
# Load the pretrained MaskFormer (`base_model_id`) configured with this dataset's
# labels. ignore_mismatched_sizes=True lets from_pretrained re-initialize weights
# whose shapes no longer match the checkpoint (presumably the class-prediction
# head, if the label count differs from the base model's).
pretrained_kwargs = dict(
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
model = MaskFormerForInstanceSegmentation.from_pretrained(base_model_id, **pretrained_kwargs)

### Download and Split Dataset

In [None]:
# Only the Hub dataset's "train" split is used; it is re-split into a
# reproducible 80/20 train/test partition (fixed shuffle seed).
dataset = load_dataset(dataset_id)["train"].train_test_split(
    test_size=0.2, shuffle=True, seed=1010
)
train_ds, test_ds = dataset["train"], dataset["test"]

### Create PyTorch Dataset

In [None]:
# Photometric-only augmentation: nothing here moves pixels, so the segmentation
# map can be used unchanged (ImageSegmentationDataset does not transform it).
# RandomPosterize / RandomEqualize expect uint8 input, so they run on the PIL
# image before ToTensor converts it to a float tensor in [0, 1].
train_transform = T.Compose(
    [
        T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.25),
        T.RandomPosterize(bits=2, p=0.2),
        T.RandomAdjustSharpness(sharpness_factor=3, p=0.2),
        T.RandomAutocontrast(p=0.3),
        T.RandomEqualize(p=0.3),
    ]
    + [T.ToTensor(), T.Normalize(mean=ade_mean, std=ade_std)]
)

# Evaluation: only convert to a [0, 1] tensor and normalize.
test_transform = T.Compose(
    [T.ToTensor(), T.Normalize(mean=ade_mean, std=ade_std)]
)

In [None]:
class ImageSegmentationDataset(Dataset):
    """Torch ``Dataset`` yielding transformed images plus untouched originals.

    Each row of ``dataset`` must be a mapping with a ``'pixel_values'`` image and
    a ``'label'`` segmentation map.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, segmentation_map, original_image, original_segmentation_map)``.

        ``image`` is ``transform(row['pixel_values'])``. The segmentation map is
        not transformed (the transforms are expected to be photometric only), so
        ``segmentation_map`` is the same array object as
        ``original_segmentation_map``.
        """
        # Fetch the row once: Hugging Face datasets decode image columns on every
        # row access, so repeated indexing repeats the decoding work.
        row = self.dataset[idx]
        raw_image = row['pixel_values']

        original_image = np.array(raw_image)
        original_segmentation_map = np.array(row['label'])

        image = self.transform(raw_image)
        segmentation_map = original_segmentation_map

        return image, segmentation_map, original_image, original_segmentation_map

In [None]:
# Wrap each split; only the training split gets the augmenting transform.
train_dataset, test_dataset = (
    ImageSegmentationDataset(train_ds, transform=train_transform),
    ImageSegmentationDataset(test_ds, transform=test_transform),
)

### Create PyTorch DataLoader

In [None]:
# Training-time processor. Pixel values arrive already scaled to [0, 1] and
# normalized by the torchvision transforms, so resizing, rescaling and
# normalization are all switched off. It is used to turn each segmentation map
# into mask_labels / class_labels (collate_fn) and to post-process predictions.
preprocessor_options = {
    "do_resize": False,
    "do_rescale": False,
    "do_normalize": False,
    "ignore_index": 255,
    "reduce_labels": False,
}
preprocessor = MaskFormerImageProcessor(**preprocessor_options)

def collate_fn(batch):
    """Collate dataset items into a MaskFormer batch.

    ``batch`` is a list of ``(image, segmentation_map, original_image,
    original_segmentation_map)`` tuples from ``ImageSegmentationDataset``. Images
    and maps go through ``preprocessor`` (which yields ``pixel_values``,
    ``mask_labels``, ``class_labels``, ...); the originals are attached as plain
    lists for visualization and metric computation.
    """
    images, segmentation_maps, original_images, original_segmentation_maps = zip(*batch)

    encoded = preprocessor(
        images=images,
        segmentation_maps=segmentation_maps,
        return_tensors="pt",
    )
    encoded["original_images"] = list(original_images)
    encoded["original_segmentation_maps"] = list(original_segmentation_maps)
    return encoded

In [None]:
# Both loaders share batch size and collate_fn; only the training loader shuffles.
loader_kwargs = dict(batch_size=4, collate_fn=collate_fn)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Inspect one training batch. Batched tensors print their full shape; per-image
# lists (mask_labels, class_labels, original_*) print the entry for sample b_idx.
# An explicit hasattr check replaces the bare `except:`, which also swallowed
# unrelated errors (even KeyboardInterrupt).
batch = next(iter(train_dataloader))
b_idx = 0

for k, v in batch.items():
    if hasattr(v, "shape"):
        print(k, v.shape, v.dtype)
    else:
        print(k, v[b_idx].shape, v[b_idx].dtype)

# Class names present in sample b_idx, in the same order as its mask_labels.
b_labels = [id2label[label] for label in batch["class_labels"][b_idx].tolist()]

print(b_labels)

In [None]:
# Augmented training image as the model sees it, with the T.Normalize step undone
# (without mean/std the normalized values wrap around in the uint8 cast).
PImage.fromarray(np_from_tensor(batch['pixel_values'][b_idx], mean=ade_mean, std=ade_std))

In [None]:
# Binary mask (white = pixel belongs to the class) for 'vegetation' in sample b_idx.
vegetation_mask = mask_from_label(batch["mask_labels"][b_idx], b_labels, 'vegetation')
PImage.fromarray(vegetation_mask)

In [None]:
# Overlay the 'water' mask on the de-normalized image. The colour is picked by
# the water class id (label2id), not by its position in this image's label list,
# so it matches the colour add_segmentations_to_image uses for water.
PImage.fromarray(
    add_mask_label_to_image(
        np_from_tensor(batch['pixel_values'][b_idx], mean=ade_mean, std=ade_std),
        mask_from_label(batch["mask_labels"][b_idx], b_labels, 'water'),
        label2id['water']
    )
)

In [None]:
# The un-augmented source image of sample b_idx, as stored in the dataset.
PImage.fromarray(np.asarray(batch['original_images'][b_idx]))

In [None]:
# Ground-truth labels of the same sample, overlaid on its un-augmented image.
gt_overlay = add_segmentations_to_image(
    batch['original_images'][b_idx],
    batch["original_segmentation_maps"][b_idx],
)
PImage.fromarray(gt_overlay)

### Check Model

In [None]:
# Sanity check: a forward pass with labels should produce a scalar loss.
# No backward pass happens here, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(
        batch["pixel_values"],
        class_labels=batch["class_labels"],
        mask_labels=batch["mask_labels"]
    )

outputs.loss

### Train

In [None]:
# Fine-tune the model. Prints a running training loss, and after every epoch the
# mean IoU on the first `max_eval_batches` test batches.
metric = evaluate.load("mean_iou")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

num_epochs = 16
log_every = 10         # print the running loss every N training batches
max_eval_batches = 8   # evaluate on a subset of the test split to keep epochs short

for epoch in range(num_epochs):
    print("Epoch:", epoch)

    model.train()
    # Reset every epoch so the printed loss tracks the current epoch rather than
    # an average over everything seen since training started.
    running_loss = 0.0
    num_batches = 0
    for idx, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        outputs = model(
            pixel_values=batch["pixel_values"].to(device),
            mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
            class_labels=[labels.to(device) for labels in batch["class_labels"]],
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # loss.item() is one scalar per batch, so average over batches; dividing
        # the sum by the number of samples would shrink it by ~batch_size.
        running_loss += loss.item()
        num_batches += 1

        if idx % log_every == 0:
            print("Loss: ", running_loss / num_batches)

    model.eval()
    num_eval_batches = min(max_eval_batches, len(test_dataloader))
    for idx, batch in enumerate(tqdm(test_dataloader, total=num_eval_batches)):
        if idx >= max_eval_batches:
            break

        with torch.no_grad():
            outputs = model(pixel_values=batch["pixel_values"].to(device))

        original_images = batch["original_images"]
        target_sizes = [(image.shape[0], image.shape[1]) for image in original_images]

        # Resize predictions to each original image's (height, width).
        predicted_segmentation_maps = preprocessor.post_process_semantic_segmentation(
            outputs,
            target_sizes=target_sizes
        )

        # CPU numpy arrays, matching the numpy ground-truth maps.
        predicted_segmentation_maps = [psm.cpu().numpy() for psm in predicted_segmentation_maps]
        ground_truth_segmentation_maps = batch["original_segmentation_maps"]

        metric.add_batch(references=ground_truth_segmentation_maps, predictions=predicted_segmentation_maps)

    test_metrics = metric.compute(num_labels=len(id2label), ignore_index=255, reduce_labels=False)
    print("Mean IoU:", test_metrics['mean_iou'], "Vegetation IoU:", test_metrics['per_category_iou'][label2id['vegetation']])

### Save Model Locally and Push to Hub

In [None]:
# Keep a local copy of the fine-tuned weights and the processor config side by side.
local_model_dir = path.join("models", result_model_id)
for artifact in (model, preprocessor):
    artifact.save_pretrained(local_model_dir)

In [None]:
# Processor published with the model. During training, images went through
# T.ToTensor() (uint8 -> [0, 1]) and T.Normalize(ade_mean, ade_std) before reaching
# the model; the training-time `preprocessor` skipped both steps because the
# torchvision transforms already did them. Hub users will presumably pass raw
# images, so this processor must do the rescaling AND the same normalization.
# Without normalization, inference inputs would not match the training inputs.
hub_preprocessor = MaskFormerImageProcessor(
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_mean=ade_mean,
    image_std=ade_std,
    ignore_index=255,
    reduce_labels=False
)

model.push_to_hub(result_model_id)
hub_preprocessor.push_to_hub(result_model_id)

### Inspect Predictions on a Test Batch

In [None]:
# Inspect one test batch. Batched tensors print their full shape; per-image
# lists print the entry for sample b_idx. An explicit hasattr check replaces the
# bare `except:`, which also swallowed unrelated errors.
batch = next(iter(test_dataloader))
b_idx = 0

for k, v in batch.items():
    if hasattr(v, "shape"):
        print(k, v.shape, v.dtype)
    else:
        print(k, v[b_idx].shape, v[b_idx].dtype)

# Class names present in sample b_idx, in the same order as its mask_labels.
b_labels = [id2label[label] for label in batch["class_labels"][b_idx].tolist()]

print(b_labels)

In [None]:
# Inference on the test batch. Switch to eval mode explicitly instead of relying
# on the training cell having run to completion (it ends with model.eval()).
model.eval()
with torch.no_grad():
    outputs = model(batch["pixel_values"].to(device))

In [None]:
# Resize the predicted label maps to each original image's (height, width).
original_images = batch["original_images"]
target_sizes = [tuple(image.shape[:2]) for image in original_images]
predicted_segmentation_maps = preprocessor.post_process_semantic_segmentation(
    outputs, target_sizes=target_sizes
)

In [None]:
# The un-augmented source image of test sample b_idx.
PImage.fromarray(np.asarray(batch["original_images"][b_idx]))

In [None]:
# Ground truth for test sample b_idx, overlaid on its image.
test_gt_overlay = add_segmentations_to_image(
    batch["original_images"][b_idx], batch["original_segmentation_maps"][b_idx]
)
PImage.fromarray(test_gt_overlay)

In [None]:
# Model prediction for the same sample, drawn with the same palette as the ground truth.
predicted_map = predicted_segmentation_maps[b_idx].cpu().numpy()
PImage.fromarray(add_segmentations_to_image(batch["original_images"][b_idx], predicted_map))