In [None]:
from os import path

def get_body(fpath):
    """Return the file name of ``fpath`` without its directory and last extension.

    Only the final extension is removed, so ``"a/b/archive.tar.gz"`` yields
    ``"archive.tar"``.
    """
    filename = path.basename(fpath)
    stem, _extension = path.splitext(filename)
    return stem

In [None]:
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from torch import nn

def setup(model_name):
    """Load the SegFormer feature extractor and segmentation model for a checkpoint.

    Args:
        model_name: Identifier forwarded unchanged to both ``from_pretrained``
            calls.

    Returns:
        tuple: ``(feature_extractor, model)``, in that order.
    """
    # Tuple elements are evaluated left to right: extractor first, then model.
    return (
        SegformerFeatureExtractor.from_pretrained(model_name),
        SegformerForSemanticSegmentation.from_pretrained(model_name),
    )

def predict(feature_extractor, model, image):
    """Segment ``image`` and return one class index per pixel.

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(images=image, return_tensors="pt")``; its
            result is unpacked as keyword arguments into ``model``.
        model: Callable whose result exposes a ``logits`` tensor laid out as
            (batch_size, num_labels, height, width).
        image: Input image exposing ``size`` as ``(width, height)``
            (the PIL convention).

    Returns:
        numpy.ndarray: 2-D array of shape ``(height, width)`` holding, for the
        first batch item, the label with the highest upsampled logit.
    """
    import torch  # local import: the notebook only imports ``nn`` from torch

    inputs = feature_extractor(images=image, return_tensors="pt")
    width, height = image.size

    # Inference only: with autograd enabled the forward pass would record a
    # graph that is never used, which just wastes memory.
    with torch.no_grad():
        logits = model(**inputs).logits  # shape (batch_size, num_labels, height/4, width/4)

        # Bring the logits back to the input resolution before taking the argmax.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=(height, width),
            mode='bilinear',
            align_corners=False,
        )
        predicted_mask = upsampled_logits.argmax(dim=1).cpu().numpy()

    return predicted_mask[0]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

plt.ioff()

def visualize(labels, image, prediction):
    """Build a figure with the colored mask, its overlay and a class-share chart.

    Args:
        labels: Mapping with ``'classes'`` (class names indexed by class id)
            and ``'palette'`` (one color per class id; entries are divided by
            255 for plotting, so presumably 0-255 RGB triples).
        image: PIL image; it is converted to RGB and blended 50/50 with the
            mask, so it is expected to have the same size as ``prediction``.
        prediction: 2-D integer array of class ids, shape (height, width).

    Returns:
        matplotlib.figure.Figure: mask (top-left), overlay (top-right) and a
        bar chart (bottom) of the classes covering more than 0.1 % of the
        pixels, ordered from largest to smallest share.
    """
    classes, palette = labels['classes'], labels['palette']

    # Paint each pixel with its class color; ids missing from the palette stay black.
    vis = np.zeros(prediction.shape + (3,), dtype=np.uint8)
    for class_index, color in enumerate(palette):
        vis[prediction == class_index] = color
    mask = Image.fromarray(vis)
    overlayed_img = Image.blend(image.convert("RGB"), mask.convert("RGB"), 0.5)

    # One unit-wide bin per class needs n_classes + 1 edges. Passing only
    # n_classes edges yields n_classes - 1 bins and, because numpy's last bin
    # is closed on both sides, silently merges the last two classes.
    n_classes = len(classes)
    hist, _ = np.histogram(prediction, bins=np.arange(n_classes + 1))
    ranked = sorted(zip(hist.tolist(), range(n_classes)), reverse=True)
    n_o_pixels = prediction.size

    stats = []
    for count, class_index in ranked:
        ratio = count * 100 / n_o_pixels
        if ratio > 0.1:  # hide classes below 0.1 % of the image
            color = (np.array(palette[class_index]) / 255).tolist()
            stats.append((classes[class_index], color, ratio))

    fig = plt.figure(figsize=(20, 10), layout="constrained")
    spec = fig.add_gridspec(2, 2)
    ax00 = fig.add_subplot(spec[0, 0])
    ax01 = fig.add_subplot(spec[0, 1])
    ax0 = fig.add_subplot(spec[1, :])
    ax00.imshow(mask)
    ax01.imshow(overlayed_img)
    if stats:  # unpacking zip(*[]) would raise; leave the chart empty instead
        kinds, colors, ratios = zip(*stats)
        ax0.bar(kinds, ratios, color=colors)
    return fig

In [None]:
# Each xxx_labels.json file was made from the code at the following links:
# ade_labels.json: https://github.com/NVlabs/SegFormer/blob/master/mmseg/datasets/ade.py
# cityscapes_labels.json: https://github.com/NVlabs/SegFormer/blob/master/mmseg/datasets/cityscapes.py

In [None]:
import json
from functools import partial
import requests

# Bind the ADE checkpoint's feature extractor and model to ``predict`` so that
# callers only need to pass the image.
ade_predict = partial(predict, *setup("nvidia/segformer-b0-finetuned-ade-512-512"))

# Read the label file inside a ``with`` block so the handle is closed right away.
with open('ade_labels.json') as labels_file:
    ade_labels = json.load(labels_file)

# Bind the labels to ``visualize``; callers pass (image, prediction).
ade_visualize = partial(visualize, ade_labels)

def ade_analyze(url):
    """Download an image, segment it with the ADE model and save the results.

    Two files are written, where ``<name>`` is ``get_body(url)``:

    * ``_segmented/<name>.png``: the figure returned by ``ade_visualize``.
    * ``_segmented/<name>.json``: a list with one entry per class giving the
      share of pixels in percent, truncated to an integer.

    Args:
        url: HTTP(S) address of the image to analyze.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    import os  # local import: the notebook only imports ``path`` from os

    response = requests.get(url, stream=True)
    # Fail on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    # Normalize the mode (grayscale, RGBA, ...) the same way cc_analyze does.
    image = Image.open(response.raw).convert('RGB')
    prediction = ade_predict(image)
    fig = ade_visualize(image, prediction)

    # One unit-wide bin per class needs n_classes + 1 edges; with only
    # n_classes edges numpy merges the last two classes into a single bin.
    n_classes = len(ade_labels['classes'])
    hist, _ = np.histogram(prediction, bins=np.arange(n_classes + 1))
    hist = (hist * 100 / (image.size[0] * image.size[1])).astype(np.int32)

    body = get_body(url)
    os.makedirs('_segmented', exist_ok=True)
    fig.savefig(f'_segmented/{body}.png')
    plt.close(fig)  # release the figure; it is not returned to the caller
    with open(f'_segmented/{body}.json', 'w') as stats_file:
        json.dump(hist.tolist(), stats_file)

In [None]:
# Run the ADE pipeline on a sample image from the COCO val2017 set; results are
# written to _segmented/000000039769.png and _segmented/000000039769.json.
ade_analyze("http://images.cocodataset.org/val2017/000000039769.jpg")

In [None]:
from functools import partial
from PIL import Image
import requests

# Bind the Cityscapes checkpoint's feature extractor and model to ``predict``
# so that callers only need to pass the image.
cc_predict = partial(predict, *setup("nvidia/segformer-b5-finetuned-cityscapes-1024-1024"))

# Read the label file inside a ``with`` block so the handle is closed right away.
with open('cityscapes_labels.json') as labels_file:
    cc_labels = json.load(labels_file)

# Bind the labels to ``visualize``; callers pass (image, prediction).
cc_visualize = partial(visualize, cc_labels)

def cc_analyze(url):
    """Segment a local image with the Cityscapes model and save the results.

    Two files are written, where ``<name>`` is ``get_body(url)``:

    * ``_segmented/<name>.png``: the figure returned by ``cc_visualize``.
    * ``_segmented/<name>.json``: a list with one entry per class giving the
      share of pixels in percent, truncated to an integer.

    Args:
        url: Despite the name, a path that ``Image.open`` can read directly
            (the notebook passes local file paths).
    """
    import os  # local import: the notebook only imports ``path`` from os

    image = Image.open(url).convert('RGB')
    prediction = cc_predict(image)
    fig = cc_visualize(image, prediction)

    # One unit-wide bin per class needs n_classes + 1 edges; with only
    # n_classes edges numpy merges the last two classes into a single bin.
    n_classes = len(cc_labels['classes'])
    hist, _ = np.histogram(prediction, bins=np.arange(n_classes + 1))
    hist = (hist * 100 / (image.size[0] * image.size[1])).astype(np.int32)

    body = get_body(url)
    os.makedirs('_segmented', exist_ok=True)
    fig.savefig(f'_segmented/{body}.png')
    # This function is called in a loop over many files; close the figure so
    # that open figures do not pile up in memory.
    plt.close(fig)
    with open(f'_segmented/{body}.json', 'w') as stats_file:
        json.dump(hist.tolist(), stats_file)

In [None]:
# Run the Cityscapes pipeline on a single captured frame; results are written
# to _segmented/tky-2_0_-80.png and _segmented/tky-2_0_-80.json.
cc_analyze("../scene_capture/_captured/tky-2_0_-80.png")

In [None]:
from glob import glob

# Batch-run cc_analyze over one capture set at a time. Exactly one of the glob
# lines below is meant to be active; swap the commented line to process another
# filename prefix (tky, osk, ngy, sap or fuk).
# for fname in glob("../scene_capture/_captured/tky-*.png"):
# for fname in glob("../scene_capture/_captured/osk-*.png"):
# for fname in glob("../scene_capture/_captured/ngy-*.png"):
# for fname in glob("../scene_capture/_captured/sap-*.png"):
for fname in glob("../scene_capture/_captured/fuk-*.png"):
    cc_analyze(fname)