# 1. Setup


In [None]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Hugging Face Hub namespace of the dataset and model repos used in this notebook.
hf_username = "samitizerxu"

In [None]:
import wandb
# Authenticate with Weights & Biases; the Trainer reports to it (report_to='wandb').
wandb.login()

In [None]:
import os
# W&B project that the Trainer's wandb integration logs runs under.
os.environ["WANDB_PROJECT"]="kelp-segmentation"

# 2. Dataset Loading


In [None]:
from datasets import load_dataset

# Download the kelp dataset from the Hub; only its "train" split is used below
# (it is re-split into train/eval in the next cell).
ds = load_dataset("samitizerxu/kelp_data_rgbaa_swin_nir")

In [None]:
import torch
# Shuffle, then carve a 70/30 train/eval split out of the Hub "train" split.
ds = ds.shuffle(seed=1)
ds = ds["train"].train_test_split(test_size=0.3, seed=1)
train_ds = ds["train"]
test_ds = ds["test"]
# `ds["train"]` / `ds["test"]` return the *same* Dataset objects every time, so a
# plain second lookup would only alias `train_ds` / `test_ds` and share any
# in-place transform (`set_transform`) attached to them later. `with_format(None)`
# returns an independent copy with default python formatting and no transform,
# so these keep the raw samples for visualization.
train_orig_ds = train_ds.with_format(None)
test_orig_ds = test_ds.with_format(None)

# Release cached CUDA memory held by earlier work in this kernel.
torch.cuda.empty_cache()


In [None]:
import numpy as np

In [None]:
# Visual sanity check: display the first five training images.
for i in range(5):
    display(train_ds[i]['pixel_values'])

## Image processor & data augmentation

In [None]:
from torchvision.transforms import ColorJitter, RandomAffine, InterpolationMode, RandomHorizontalFlip, RandomVerticalFlip, Compose
from transformers import (
    SegformerImageProcessor,
)
import torch

# SegFormer preprocessing with the library defaults (nothing overridden here).
processor = SegformerImageProcessor()

def _geometric_augmentation(interpolation):
    """Build the random affine + flip pipeline shared by images and label masks.

    Args:
        interpolation: torchvision ``InterpolationMode`` used by ``RandomAffine``.
    """
    return Compose([
        RandomAffine(degrees=90, translate=(0.3, 0.3), scale=(0.7, 1.3), interpolation=interpolation),
        RandomHorizontalFlip(p=0.5),
        RandomVerticalFlip(p=0.5),
    ])


def train_transforms(example_batch):
    """Randomly augment a training batch and run it through the SegFormer processor.

    Each image and its label mask receive the *same* random affine/flip
    parameters: the torch RNG state is captured before augmenting the images and
    restored before augmenting the labels, so both passes draw the same random
    numbers in the same order (building a pipeline draws none).

    Args:
        example_batch: batch dict with ``'pixel_values'`` and ``'label'`` lists
            (assumed to be PIL images — TODO confirm against the dataset features).

    Returns:
        The processor output for the augmented images and labels.
    """
    image_aug = _geometric_augmentation(InterpolationMode.BILINEAR)
    # Masks hold discrete class ids: bilinear resampling would blend neighbouring
    # ids (e.g. 0 and the 255 ignore index) into invalid values, so use nearest.
    label_aug = _geometric_augmentation(InterpolationMode.NEAREST)
    # NOTE(review): pixels rotated/translated in from outside the frame get fill=0
    # (background) in the mask; fill=255 would exclude them from the loss instead.

    state = torch.get_rng_state()
    images = [image_aug(x) for x in example_batch['pixel_values']]
    torch.set_rng_state(state)
    labels = [label_aug(x) for x in example_batch['label']]
    inputs = processor(images, labels)
    return inputs


def val_transforms(example_batch):
    """Preprocess an evaluation batch without any augmentation.

    Args:
        example_batch: batch dict with ``'pixel_values'`` and ``'label'`` lists.

    Returns:
        The SegFormer processor output for the unchanged images and labels.
    """
    return processor(list(example_batch['pixel_values']), list(example_batch['label']))

# Attach the on-the-fly preprocessing. `with_transform` returns a new Dataset
# rather than mutating in place (as `set_transform` does), so any other variable
# that still refers to the original split object keeps its raw, untransformed
# samples. Re-running this cell simply replaces the transform.
train_ds = train_ds.with_transform(train_transforms)
test_ds = test_ds.with_transform(val_transforms)

In [None]:
from PIL import Image

# 3. Fine-tune


In [None]:
from transformers import SegformerModel, SegformerDecodeHead, SegformerPreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from typing import Optional, Union, Tuple

class SegformerForKelpSemanticSegmentation(SegformerPreTrainedModel):
    """SegFormer encoder + decode head with a kelp-specific loss.

    Mirrors ``transformers.SegformerForSemanticSegmentation`` (same ``segformer``
    and ``decode_head`` submodules). When ``config.num_labels == 1`` the loss is a
    class-balanced binary cross-entropy; see ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.segformer = SegformerModel(config)
        self.decode_head = SegformerDecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        r"""
        Run the encoder and decode head, optionally computing a segmentation loss.

        Args:
            pixel_values (`torch.FloatTensor`): preprocessed images, e.g. from
                `SegformerImageProcessor`.
            labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
                Ground-truth segmentation maps. Pixels equal to
                `config.semantic_loss_ignore_index` do not contribute to the loss.
                - `config.num_labels > 1`: cross-entropy over the class dimension;
                  indices should be in `[0, ..., config.num_labels - 1]`.
                - `config.num_labels == 1`: binary cross-entropy with logits on 0/1
                  masks; the positive class is weighted by the batch's
                  (#label==0) / max(#label==1, 1) pixel ratio.
            output_attentions, output_hidden_states, return_dict: as in
                `SegformerForSemanticSegmentation`. The encoder is always asked for
                hidden states because the decode head consumes them.

        Returns:
            `SemanticSegmenterOutput` (a tuple when `return_dict` is false) whose
            `logits` are `(batch_size, num_labels, height/4, width/4)`; `loss` is
            set only when `labels` is given.

        Raises:
            ValueError: if `labels` is given and `config.num_labels < 1`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.decode_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # Upsample logits to the labels' resolution. `torch.nn` is spelled out
            # because the bare `nn` alias is not imported before this class.
            upsampled_logits = torch.nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            if self.config.num_labels > 1:
                loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                loss = loss_fct(upsampled_logits, labels)
            elif self.config.num_labels == 1:
                valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
                num_neg = (labels == 0).sum()
                num_pos = (labels == 1).sum()
                # Scalar pos_weight (broadcast by the loss) computed on labels' own
                # device; clamping avoids division by zero in kelp-free batches.
                pos_weight = num_neg / num_pos.clamp(min=1)
                loss_fct = BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
                loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
                # Zero out ignored pixels before averaging.
                loss = (loss * valid_mask).mean()
            else:
                raise ValueError(f"Number of labels should be >= 1: {self.config.num_labels}")

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


In [None]:
from transformers import SegformerForSemanticSegmentation, SegformerConfig

# Binary task: background vs. kelp.
id2label = {0: 'background', 1: 'kelp'}
# Inverse mapping, derived so the two dictionaries cannot drift apart.
label2id = {name: idx for idx, name in id2label.items()}

# Checkpoint whose weights initialise the model.
pretrained_model_name = "nvidia/mit-b2"

# Label value 255 is excluded from the training loss.
model = SegformerForKelpSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    semantic_loss_ignore_index=255,
    id2label=id2label,
    label2id=label2id,
)

## Set up the Trainer

In [None]:
from transformers import TrainingArguments

# Optimisation hyper-parameters.
epochs = 40
lr = 6e-5
batch_size = 8

# Hub repo name; also used as the local output directory below.
hub_model_id = "segformer-b2-kelp-rgb-agg-imgaug-jan-27"

training_args = TrainingArguments(
    hub_model_id,  # output_dir (same string as the Hub repo name)
    # Optimisation.
    learning_rate=lr,
    num_train_epochs=epochs,
    warmup_ratio=0.2,
    weight_decay=0.1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate and checkpoint on the same 30-step cadence;
    # load_best_model_at_end picks among the saved checkpoints.
    evaluation_strategy="steps",
    eval_steps=30,
    eval_accumulation_steps=5,
    save_strategy="steps",
    save_steps=30,
    save_total_limit=5,
    metric_for_best_model='eval_iou_kelp',
    load_best_model_at_end=True,
    # Logging and publishing.
    logging_steps=1,
    report_to='wandb',
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
)

In [None]:
import torch
from torch import nn
import evaluate
import multiprocessing

# Mean-IoU metric from the `evaluate` library; used by compute_metrics.
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    """Compute mean-IoU metrics for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(logits, labels)`` pair of numpy arrays from the Trainer.
            Logits are presumably ``(batch, num_labels, h, w)`` at the decode
            head's reduced resolution and labels ``(batch, H, W)`` — TODO confirm.

    Returns:
        dict with the ``mean_iou`` scalar outputs (per-category arrays removed)
        plus ``eval_accuracy_kelp``, ``eval_iou_kelp``, ``eval_accuracy_bg`` and
        ``eval_iou_bg``. ``eval_iou_kelp`` is the Trainer's ``metric_for_best_model``.
    """
    with torch.no_grad():
        logits, labels = eval_pred
        # Upsample logits to the label resolution, then take the per-pixel class.
        pred_labels = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1).numpy()

    # NOTE(review): `_compute` is private API; `metric.compute(...)` is the public
    # equivalent but was presumably avoided for speed on large arrays — confirm
    # before switching.
    metrics = metric._compute(
        predictions=pred_labels,
        references=labels,
        num_labels=len(id2label),
        ignore_index=255,
    )

    # Flatten the per-category arrays into scalar entries the Trainer can log
    # (index 0 = background, 1 = kelp).
    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()
    metrics.update({
        "eval_accuracy_kelp": per_category_accuracy[1],
        "eval_iou_kelp": per_category_iou[1],
        "eval_accuracy_bg": per_category_accuracy[0],
        "eval_iou_bg": per_category_iou[0],
    })
    return metrics

In [None]:
from transformers import Trainer

# Combine the model, training arguments, datasets (which carry the on-the-fly
# train/val transforms) and the metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
# Dataset recorded on the model card. It should name the dataset actually used for
# training, i.e. the repo passed to `load_dataset` in the loading cell.
hf_dataset_identifier = 'samitizerxu/kelp_data_rgbaa_swin_nir'

In [None]:
# Fine-tune. With load_best_model_at_end=True, the checkpoint with the best
# eval_iou_kelp is reloaded when training finishes.
trainer.train()

In [None]:
# Model-card metadata passed to Trainer.push_to_hub.
kwargs = {
    "tags": ["vision", "image-segmentation"],
    "finetuned_from": pretrained_model_name,
    "dataset": hf_dataset_identifier,
}

# Upload the image processor config under the same repo name, then the model and card.
processor.push_to_hub(hub_model_id) 
trainer.push_to_hub(**kwargs)  

# 4. Inference

## Use the model from the hub

In [None]:
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Reload the fine-tuned weights from the Hub for inference. The stock
# SegformerForSemanticSegmentation class shares the `segformer` / `decode_head`
# submodule names with the custom training class, which differs only in its loss.
processor = SegformerImageProcessor()
model = SegformerForSemanticSegmentation.from_pretrained(f"{hf_username}/{hub_model_id}")

In [None]:
#@title `def my_palette()`

def my_palette():
    """Return the RGB overlay colour for each class id.

    Index 0 (background) is black and index 1 (kelp) is orange. A fresh list is
    built on every call, so callers may mutate the result safely.
    """
    background_rgb = [0, 0, 0]
    kelp_rgb = [216, 82, 24]
    return [background_rgb, kelp_rgb]

In [None]:
import numpy as np

def get_seg_overlay(image, seg, palette=None, alpha=0.5):
    """Blend a class-id mask onto an RGB image for visual inspection.

    Args:
        image: RGB image (PIL image or array-like of shape ``(H, W, 3)``).
        seg: ``(H, W)`` array-like of integer class ids (numpy array or CPU torch
            tensor); it is converted with ``np.asarray``.
        palette: sequence of RGB colours indexed by class id; defaults to
            ``my_palette()``. Ids without a colour stay black in the mask layer.
        alpha: weight of the mask colour in the blend (0 = image only,
            1 = mask only). The default 0.5 is an even 50/50 mix.

    Returns:
        ``np.uint8`` array of shape ``(H, W, 3)``.
    """
    if palette is None:
        palette = my_palette()
    seg = np.asarray(seg)

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(np.asarray(palette)):
        color_seg[seg == label, :] = color

    # Show image + mask
    blended = np.asarray(image) * (1 - alpha) + color_seg * alpha
    return blended.astype(np.uint8)

In [None]:
import matplotlib.pyplot as plt

# Qualitative check: predicted vs. ground-truth overlays for a few eval samples.
# `test_ds` has the val transform attached, so its samples are processed arrays,
# not images. `with_format(None)` returns a copy with default python formatting
# and no transform, giving back the raw samples (assumed PIL images — TODO confirm).
raw_test_ds = test_ds.with_format(None)

for i in range(90, 100):
    sample = raw_test_ds[i]
    image = sample['pixel_values']
    gt_seg = sample['label']
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (batch_size, num_labels, height/4, width/4)

    # Rescale logits to the original image size, then take the per-pixel class.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
        mode='bilinear',
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    pred_img = get_seg_overlay(image, pred_seg)
    gt_img = get_seg_overlay(image, np.array(gt_seg))

    fig, axs = plt.subplots(1, 2, figsize=(50, 30))
    axs[0].set_title("Prediction", {'fontsize': 40})
    axs[0].imshow(pred_img)
    axs[1].set_title("Ground truth", {'fontsize': 40})
    axs[1].imshow(gt_img)
    # Render now, then free the figure so ten very large figures don't stay open.
    plt.show()
    plt.close(fig)