In [None]:
import io
import cv2
import json
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import os
import os.path as osp
import torch
import torch.nn as nn
import urllib
from mmseg.core.evaluation import mean_iou
import sys

# Put the parent directory on the import path before the project-local
# imports below (models, utils, dataloader) are resolved.
# NOTE(review): presumably the notebook sits one level below the repository
# root, and the kernel's working directory is the notebook's folder — confirm.
sys.path.append('../')
from contextlib import redirect_stdout
from easydict import EasyDict
from models.builder import EncoderDecoder
from utils.pyt_utils import load_model
from utils.transforms import normalize
from dataloader.RGBXDataset import image_decoder

# In-memory text buffer; later passed to redirect_stdout so that anything
# printed inside that `with` block is captured here instead of shown in the cell output.
trap = io.StringIO()


# PanoContext RGB-D Panoramic Sample

In [None]:
# Device every tensor and the model are moved to.
device = 'cuda:0'

# All experiment settings live on one attribute-style dictionary.
config = EasyDict()

# --- dataset location and preprocessing -------------------------------------
config.fold = 'F1'
config.root = '<sfss_mmsi_path>' # TODO: change this to your own path
config.dataset_path = osp.join(config.root, 'datasets', '2D-3D-Semantics-1K')
config.dataset_name = 'Stanford2D3DS'
config.ignore_index = 255
config.image_height = 512
config.image_width = 512
config.norm_mean = np.array([0.485, 0.456, 0.406])
config.norm_std = np.array([0.229, 0.224, 0.225])

# --- network ----------------------------------------------------------------
config.backbone = 'dual_mit_b2' # TODO: change to 'mit_b2', 'dual_mit_b2'
config.pretrained_model = osp.join(config.root, 'pretrained', 'segformers/mit_b2.pth')
config.decoder = 'DMLPDecoderV2'
config.decoder_embed_dim = 512
config.optimizer = 'AdamW'
config.use_dcns = [True, False, False, False]

# --- data sources -----------------------------------------------------------
config.batch_size = 1
config.rgb = 'camera-rgb-1K'
config.ann = 'camera-semantic-1K'
config.modality_x = ['camera-depth-1K']
config.train_source = osp.join(config.dataset_path, f'train_{config.fold}.txt')
config.eval_source = osp.join(config.dataset_path, f'test_{config.fold}.txt')
config.num_classes = 13
config.train_scale_array = [0.5, 0.75, 1, 1.25, 1.5, 1.75]

# --- work directory / checkpoint --------------------------------------------
# Only two set-ups are supported: the RGB-only backbone, and the dual backbone
# whose first extra modality is depth. Anything else is rejected.
if config.backbone == 'mit_b2':
    _run_tag = f'_DMLPDecoderV2_{config.fold}'
elif config.backbone == 'dual_mit_b2' and config.modality_x[0] == 'camera-depth-1K':
    _run_tag = f'_DMLPDecoderV2_Depth_{config.fold}'
else:
    raise NotImplementedError
config.log_dir = osp.abspath(osp.join(
    config.root, 'workdirs', 'Stanford2D3DS_1024x512',
    'log_' + config.dataset_name + '_' + config.backbone + _run_tag))
config.checkpoint_pth = osp.join(osp.abspath(osp.join(config.log_dir, 'checkpoint')), 'epoch-best.pth')
config.eval_crop_size = [512, 1024]  # [height, width]


In [None]:
def _to_batched_tensor(array, norm_mean, norm_std):
    """Normalize one channel-last array and return it as a batched float32 tensor.

    The array is passed through the project's ``normalize`` helper, reordered
    from (H, W, C) to (C, H, W), given a leading batch axis of size 1 and moved
    to the notebook-level ``device``.
    """
    array = normalize(array, norm_mean, norm_std)
    array = array.transpose(2, 0, 1)
    array = np.ascontiguousarray(array[None, :, :, :], dtype=np.float32)
    # `.to(device)` honours whatever `device` is configured (CUDA or CPU),
    # whereas `.cuda(...)` fails outright on a machine without CUDA.
    return torch.from_numpy(array).to(device)


def process_eval_image_rgbX(image, modal_x1, modal_x2, norm_mean, norm_std):
    """Prepare an RGB image and two extra modalities for network inference.

    Each input goes through the same pipeline: ``normalize`` with the given
    mean/std, (H, W, C) -> (C, H, W), add a batch axis, convert to float32 and
    move to the module-level ``device``.

    Args:
        image: RGB image array, channel-last (H, W, C).
        modal_x1: first extra-modality array, channel-last (H, W, C).
        modal_x2: second extra-modality array, channel-last (H, W, C).
        norm_mean: per-channel mean forwarded to ``normalize``.
        norm_std: per-channel std forwarded to ``normalize``.

    Returns:
        Tuple ``(image, modal_x1, modal_x2)`` of float32 tensors, each of shape
        (1, C, H, W), located on ``device``.
    """
    return tuple(
        _to_batched_tensor(array, norm_mean, norm_std)
        for array in (image, modal_x1, modal_x2)
    )


In [None]:
# --- RGB panorama -------------------------------------------------------------
rgb_path = osp.join(config.root, 'figures/pano_asmasuxybohhcj.png')
# Decoded image is scaled by 255 and stored as 8-bit.
rgb = np.array(image_decoder(rgb_path, 'rgb') * 255.0, np.uint8)  # (H, W, 3)

# --- depth panorama -----------------------------------------------------------
depth_path = osp.join(config.root, 'figures/pano_asmasuxybohhcj.depth.png')
x = np.array(image_decoder(depth_path, 'i') * 255.0, np.uint8)
# Saturated readings (255 after scaling) are treated as "no depth" and zeroed.
# NOTE(review): presumably 255 here corresponds to 65535 in the raw 16-bit file — confirm.
x[x == 255] = 0
# Replicate the single depth channel three times so it has the same layout as RGB.
depth = cv2.merge([x, x, x])  # (H, W, 3)

# The same depth map is supplied for both extra-modality slots.
image, modal_x1, modal_x2 = process_eval_image_rgbX(
    rgb, depth, depth, config.norm_mean, config.norm_std)

# --- label metadata -----------------------------------------------------------
valid_labels = list(range(config.num_classes)) + [config.ignore_index]
seg_colors = np.load(osp.join(config.dataset_path, 'assets/colors.npy'))
with open(osp.join(config.dataset_path, 'assets/name2label.json'), 'r') as f:
    name2id = json.load(f)


In [None]:
# visualize input
# Panel layout follows the configured backbone: the RGB-only model gets a single
# panel, the RGB + depth model gets a second panel for the depth map.
if config.backbone == 'mit_b2':
    fig, axs = plt.subplot_mosaic(
        [['RGB']], figsize=(15, 10), layout='constrained')
elif config.backbone == 'dual_mit_b2' and config.modality_x[0] == 'camera-depth-1K':
    fig, axs = plt.subplot_mosaic(
        [['RGB'], ['Depth']], figsize=(15, 15), layout='constrained')
else:
    raise NotImplementedError

# Undo the mean/std normalization on the network input tensor: (1, C, H, W) -> (H, W, C).
img = image.squeeze().permute(1, 2, 0).cpu().numpy() * config.norm_std + config.norm_mean
# Round (and clip) before the integer cast: `astype` truncates toward zero, so the
# tiny float32 round-trip error would otherwise turn e.g. 199.99997 into 199.
# `img` is reused by the prediction cell to locate black pixels, so exact values matter.
img = np.clip(np.rint(img * 255.0), 0, 255).astype('uint8')
plt_img1 = axs['RGB'].imshow(img)
axs['RGB'].set_axis_off()
axs['RGB'].set_title('RGB Input')

if config.backbone != 'mit_b2' and 'camera-depth-1K' in config.modality_x:
    # Same de-normalization for the depth tensor; only the first channel is shown
    # because the three channels are copies of one another.
    dep = modal_x1.squeeze().permute(1, 2, 0).cpu().numpy() * config.norm_std + config.norm_mean
    # Values equal to 1.0 (saturated depth) are blanked; the result is scaled by 128.0.
    # NOTE(review): 128.0 presumably converts the unit-range depth to metres — confirm.
    dep = np.where(dep[..., 0] == 1.0, 0.0, dep[..., 0]) * 128.0
    plt_img2 = axs['Depth'].imshow(dep, vmin=0, vmax=10, cmap='jet')
    axs['Depth'].set_axis_off()
    axs['Depth'].set_title('Depth Input')

plt.show()


In [None]:
# Build the segmentation network (no loss criterion is attached for inference)
# and restore the weights stored at the configured checkpoint path.
network = EncoderDecoder(norm_layer=nn.BatchNorm2d, criterion=None, cfg=config)
model = load_model(network, config.checkpoint_pth)
model = model.to(device)

# Put the model in evaluation mode; anything written to stdout while doing so
# is captured by the `trap` buffer rather than printed in the cell output.
with redirect_stdout(trap):
    model.eval()


In [None]:
# predict
# The network is evaluated on full frames, so every input must already have the
# configured evaluation size ([height, width]).
assert list(image.shape[-2:]) == config.eval_crop_size, \
    f'image has spatial size {list(image.shape[-2:])}, expected {config.eval_crop_size}'
assert list(modal_x1.shape[-2:]) == config.eval_crop_size, \
    f'modal_x1 has spatial size {list(modal_x1.shape[-2:])}, expected {config.eval_crop_size}'
assert list(modal_x2.shape[-2:]) == config.eval_crop_size, \
    f'modal_x2 has spatial size {list(modal_x2.shape[-2:])}, expected {config.eval_crop_size}'

# Inference only: no autograd graph is needed.
with torch.no_grad():
    # Call the module itself rather than `.forward(...)` so that registered
    # forward hooks (if any) are executed as well.
    if config.backbone == 'mit_b2' or len(config.modality_x) == 1:
        score = model(image, modal_x1)
    elif len(config.modality_x) == 2:
        score = model(image, modal_x1, modal_x2)
    else:
        raise NotImplementedError

# Exponentiate the raw network output.
# NOTE(review): presumably `score` holds log-probabilities / logits per class — confirm.
output = torch.exp(score)


In [None]:
# visualize prediction

fig, axs = plt.subplot_mosaic([['Pred']], figsize=(15, 10), layout='constrained')

# Per-pixel class = index of the highest score along the class axis.
# The argmax is taken on the float scores: casting them to integers first
# truncates every value below 1 to 0 and collapses distinct scores into ties,
# which makes the argmax return the wrong (typically the first) class.
# The +1 shifts class indices up by one so that id 0 stays free for "unknown".
predict = torch.argmax(output, dim=1) + 1
pred = predict.squeeze().cpu().numpy().astype('uint8')

# ignore_index + 1 wraps around to 0 when cast to uint8 -> the "unknown" id.
unlabeled = np.array(config.ignore_index + 1).astype(np.uint8)
if config.backbone == 'mit_b2':
    # For the RGB-only model, pixels that are pure black in the displayed RGB
    # input (`img` from the input-visualization cell) are marked as unknown.
    pred[img.sum(-1) == 0] = unlabeled  # mask as unknown id: 0

# Colorize the label map through the per-id color table.
axs['Pred'].imshow(seg_colors[pred])
axs['Pred'].set_axis_off()
axs['Pred'].set_title('Semantic Prediction')

# One legend entry per class name, colored like the label map (colors scaled to [0, 1]).
patches = [
    mpatches.Patch(color=seg_colors[seg_val]/255., label=seg_lbl)
    for seg_lbl, seg_val in name2id.items()
]

axs['Pred'].legend(handles=patches, loc='lower center', ncol=7)

plt.show()
