## Setup

In [None]:
# Clone the repository and install it in editable mode so `import semseg` works in later cells.
# NOTE(review): not idempotent — on a re-run the clone target may already exist and `%cd`
# is relative to the current working directory; restart the kernel from a clean directory.
!git clone https://github.com/sithu31296/semantic-segmentation
%cd semantic-segmentation
%pip install -e .

In [None]:
import torch
from torchvision import io
from torchvision import transforms as T
from PIL import Image

def show_image(image):
    """Convert an image tensor to a PIL image so the notebook renders it inline.

    Args:
        image: tensor holding pixel values (no scaling is applied, so it should
            already be uint8 for a normal image). Accepted layouts:
            - (C, H, W) channels-first, e.g. the tensor from ``io.read_image`` in this notebook;
            - (H, W, 3) channels-last RGB, e.g. the colored segmentation map;
            - (H, W) or (1, H, W) single-channel images.
            A 3-D tensor whose last dimension is not 3 is treated as channels-first.

    Returns:
        A ``PIL.Image.Image`` built from the tensor's values.
    """
    # .numpy() requires a CPU tensor that does not track gradients.
    image = image.detach().cpu()
    if image.ndim == 3 and image.shape[2] != 3:
        # Last dim isn't RGB, so assume channels-first (C, H, W) and move channels last.
        image = image.permute(1, 2, 0)
    if image.ndim == 3 and image.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        image = image[..., 0]
    return Image.fromarray(image.numpy())

## Show Available Pretrained Models

In [None]:
from semseg import show_models

# Show the pretrained models that semseg provides (output format is defined by semseg).
show_models()

## Load a Pretrained Model

Download a pretrained model's weights from the results table (ADE20K, Cityscapes, ...) and put them in `checkpoints/pretrained/model_name/`. The next two cells do this automatically for SegFormer-B3 trained on ADE20K.

In [None]:
# gdown downloads files from Google Drive; the next cell uses it to fetch the checkpoint.
# NOTE(review): unpinned `-U` upgrade — consider pinning a version for reproducibility.
%pip install -U gdown

In [None]:
import gdown
from pathlib import Path

# Directory the model-loading cell reads the checkpoint from.
ckpt = Path('./checkpoints/pretrained/segformer')
ckpt.mkdir(exist_ok=True, parents=True)

# Google Drive link to the SegFormer-B3 / ADE20K weights (per the output file name).
url = 'https://drive.google.com/uc?id=1-OmW3xRD3WAbJTzktPC-VMOF5WMsN8XT'
# Derive the file path from `ckpt` instead of repeating the directory string.
output = str(ckpt / 'segformer.b3.ade.pth')

# Skip the download when the file is already present so "Restart & Run All" stays cheap.
if Path(output).exists():
    print(f'Using cached weights at {output}')
else:
    gdown.download(url, output, quiet=False)

In [None]:
from semseg.models import SegFormer

CKPT_PATH = 'checkpoints/pretrained/segformer/segformer.b3.ade.pth'

# SegFormer with a MiT-B3 backbone and 150 classes, matching the ADE20K checkpoint
# (segformer.b3.ade.pth) downloaded above.
model = SegFormer(
    backbone='MiT-B3',
    num_classes=150
)
# Put the model in eval mode for the inference cells below.
model.eval()

try:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
    # only load checkpoints from a source you trust.
    state_dict = torch.load(CKPT_PATH, map_location='cpu')
except FileNotFoundError:
    # Best effort: continue with the freshly constructed (untrained) weights.
    print("Download a pretrained model's weights from the result table.")
else:
    # Kept outside the try so key/shape mismatches surface instead of being swallowed.
    model.load_state_dict(state_dict)
    print('Loaded Model')

## Simple Image Inference

### Load Image

In [None]:
# Sample ADE20K validation image from the cloned repository.
# NOTE(review): keep the `assests` spelling unless the repo's directory name is confirmed
# to differ — changing it here would break the path.
image_path = './assests/ade/ADE_val_00000049.jpg'
image = io.read_image(image_path)
print(image.shape)
# show_image moves channels last (when the last dim isn't 3) before converting to PIL.
show_image(image)

### Preprocess

In [None]:
# Per-channel normalization statistics (the standard ImageNet mean/std values).
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

# Start from the file rather than the current `image` variable, so re-running this cell
# does not crop/normalize/unsqueeze an already-preprocessed tensor a second time.
image = io.read_image(image_path)
# center-crop to 512x512 (this crops; it does not resize)
image = T.CenterCrop((512, 512))(image)
# scale to [0.0, 1.0]
image = image.float() / 255
# normalize each channel with MEAN/STD
image = T.Normalize(MEAN, STD)(image)
# add batch dimension
image = image.unsqueeze(0)
image.shape

### Model Forward

In [None]:
# Forward pass without autograd tracking. `seg` presumably holds per-class scores shaped
# (batch, num_classes, H, W) — the next cell reduces over dim 1; confirm against the model.
with torch.inference_mode():
    seg = model(image)
seg.shape

### Postprocess

In [None]:
# Per-pixel class index: argmax over the class dimension (dim 1). Softmax is monotonic,
# so applying it first cannot change the winner and only costs an extra pass over the
# tensor; argmax already returns int64 indices suitable for indexing the palette.
seg = seg.argmax(1)
seg.unique()

In [None]:
from semseg.datasets import ADE20K

# Per-class colors used to render the predicted segmentation map below.
# The class is referenced directly; there is no need to eval() a literal name.
palette = ADE20K.PALETTE

In [None]:
# Look up each pixel's class index in the palette to get its color, drop size-1 axes
# (here the batch axis), and convert to uint8 for PIL.
# NOTE(review): assumes PALETTE is a tensor of shape (num_classes, 3) — confirm in semseg.datasets.
seg_map = palette[seg].squeeze().to(torch.uint8)
show_image(seg_map)

## Show Available Backbones

In [None]:
from semseg import show_backbones

# Show the backbones available in semseg (output format is defined by semseg).
show_backbones()

## Show Available Heads

In [None]:
from semseg import show_heads

# Show the segmentation heads available in semseg.
show_heads()

## Show Available Datasets

In [None]:
from semseg import show_datasets

# Show the datasets supported by semseg.
show_datasets()

## Construct a Custom Model

### Choose a Backbone

In [None]:
from semseg.models.backbones import ResNet

# The string argument presumably selects the ResNet depth variant (here ResNet-18).
backbone = ResNet('18')

In [None]:
# init random input batch, presumably (batch, channels, H, W) = (2, 3, 224, 224)
# Unseeded, but only the output shapes are inspected below, so the values don't matter.
x = torch.randn(2, 3, 224, 224)

In [None]:
# get features from the backbone
# The backbone returns an iterable of feature maps (presumably one per stage); print each shape.
features = backbone(x)
for out in features:
    print(out.shape)

### Choose a Head

In [None]:
from semseg.models.heads import UPerHead

# Build a head that consumes the backbone's feature maps, sized from `backbone.channels`.
# NOTE(review): 128 appears to be the head's internal channel width and num_classes=10 the
# number of output classes for this demo — confirm against UPerHead's signature.
head = UPerHead(backbone.channels, 128, num_classes=10)

In [None]:
# Fuse the feature list into a single prediction map; it is presumably at reduced
# resolution, which is why the next cell upsamples it to the input size.
seg = head(features)
seg.shape

In [None]:
from torch.nn import functional as F
# upsample the output
# Bilinearly resize the head's output to the input's spatial size (the last two dims of x).
seg = F.interpolate(seg, size=x.shape[-2:], mode='bilinear', align_corners=False)
seg.shape

See `semseg/models/custom_cnn.py` and `semseg/models/custom_vit.py` for complete examples of constructing a custom model.