In [None]:
import os
import cv2
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, SequentialSampler
from multimae.utils.plot_utils import plot_predictions
from multimae.utils.train_utils import normalize_depth
from multimae.utils.datasets import build_multimae_pretraining_dataset
from multimae.utils.plot_utils import get_semseg_metadata
from multimae.tools.load_multimae import load_model
from multimae.models.multimae import pretrain_multimae_base

import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): both calls install a catch-all "ignore" filter, so one of them
# looks redundant; blanket suppression also hides deprecation notices from the
# libraries imported above — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

### load checkpoints

In [None]:
# Name of the run/checkpoint to load. The same string is later used as the
# sub-folder for saved predictions and is compared against
# "no-standard-depth" to decide whether depth inputs get normalized.
model_name = "semseg-clean"
# `load_model` returns the model together with the `args` namespace whose
# fields (num_workers, pin_mem, alphas, mask_type, in_domains, ...) are read
# by the cells below.
model, args = load_model(model_name)
# Show which output adapters the loaded model exposes.
print(model.output_adapters.keys())

In [None]:
# Seed every random number generator used in this notebook.
# Pick a different value to resample anything that is drawn randomly (e.g. masks).
seed = 7
random.seed(seed)        # Python's built-in RNG
np.random.seed(seed)     # NumPy's legacy global RNG
torch.manual_seed(seed)  # PyTorch RNG
# Let cuDNN auto-tune its kernels for the encountered input shapes.
cudnn.benchmark = True

In [None]:
# Resolve the evaluation data location and the folder where prediction figures
# are written. Everything is derived from the FLIGHTMARE_PATH environment
# variable, so a KeyError here means that variable is not set.
flightmare_path = Path(os.environ["FLIGHTMARE_PATH"])
multimae_path = flightmare_path.parent / "vision_backbones/MultiMAE"
eval_data_path = multimae_path / "datasets/new_env/val"

# One output folder per model name; created (with parents) if it does not exist yet.
pred_save_path = multimae_path / "results/predictions" / model_name
pred_save_path.mkdir(parents=True, exist_ok=True)

# Semantic-segmentation metadata for the evaluation set, passed to the plotting helper later.
metadata = get_semseg_metadata(eval_data_path)

In [None]:
# Build the validation dataset from `eval_data_path` and wrap it in a loader
# that yields one sample at a time, in dataset order.
args.eval_data_path = str(eval_data_path)
dataset_val = build_multimae_pretraining_dataset(args, args.eval_data_path)
sampler_val = SequentialSampler(dataset_val)

# Loader options; worker count and memory pinning come from the loaded `args`.
loader_options = dict(
    sampler=sampler_val,
    batch_size=1,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
)
data_loader_val = DataLoader(dataset_val, **loader_options)

# Running counter used to name the saved prediction images.
current_img_id = 0

In [None]:
# Fixed per-domain token masks handed to the model as `task_masks`.
# Each mask is a long tensor of shape (1, 196): a 14 x 14 token grid flattened,
# with a leading batch dimension of 1.
# NOTE(review): presumably 1 = token masked out and 0 = token visible, i.e. RGB
# and depth are fully masked while semseg is fully visible — confirm against
# the model's masking convention.
#
# The tensors are created directly in their final shape (the previous
# LongTensor/flatten/[None] round-trip produced the same values), and fall
# back to the CPU when CUDA is unavailable instead of raising.
mask_device = "cuda" if torch.cuda.is_available() else "cpu"
tokens_per_mask = 14 * 14

masks = {
    "rgb": torch.ones((1, tokens_per_mask), dtype=torch.long, device=mask_device),
    "depth": torch.ones((1, tokens_per_mask), dtype=torch.long, device=mask_device),
    "semseg": torch.zeros((1, tokens_per_mask), dtype=torch.long, device=mask_device),
}

In [None]:
# Run the model on up to NUM_PREDICTIONS validation samples and save one
# prediction figure per sample as "<pred_save_path>/<index>.png" (index starts at 1).
NUM_PREDICTIONS = 100

# Keep the inputs on the same device as the fixed task masks.
device = next(iter(masks.values())).device

# Iterate over the loader itself. Calling iter(data_loader_val) inside the loop
# restarts the sequential sampler on every pass, so each iteration would see
# the first sample again. With range() first, zip() stops after
# NUM_PREDICTIONS batches without fetching an extra one, and stops early when
# the dataset is smaller than that.
for current_img_id, batch in zip(range(1, NUM_PREDICTIONS + 1), data_loader_val):
    # First element of the batch; used below as a dict of per-domain tensors.
    inputs = batch[0]

    if model_name != "no-standard-depth" and "depth" in inputs:
        inputs["depth"] = normalize_depth(inputs["depth"])
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Depth is treated as optional above, so only report its range when present.
    if "depth" in inputs:
        print(inputs["depth"].min())
        print(inputs["depth"].max())

    # Inference only, so no autograd graph is needed. The returned masks are
    # bound to a separate name so the fixed `masks` passed as `task_masks`
    # is not overwritten between iterations.
    with torch.no_grad():
        preds, pred_masks = model(
            inputs,
            num_encoded_tokens=196,
            alphas=args.alphas,
            sample_tasks_uniformly=args.sample_tasks_uniformly,
            mask_type=args.mask_type,
            masked_rgb_gate_only=True,
            semseg_gt=inputs["semseg"],
            in_domains=args.in_domains,
            semseg_stride=4,
            mask_inputs=True,
            task_masks=masks,
        )

    preds = {domain: pred.detach().cpu() for domain, pred in preds.items()}

    fig = plot_predictions(inputs, preds, pred_masks, metadata=metadata, return_fig=True)

    out_file = f"{pred_save_path}/{current_img_id}.png"
    if hasattr(fig, "savefig"):
        # matplotlib figures are written with `savefig`; they have no `save` method.
        fig.savefig(out_file)
        # Close the figure so 100 open figures do not pile up in memory.
        plt.close(fig)
    else:
        # NOTE(review): fallback for a non-matplotlib return value that
        # exposes `.save` (e.g. an image object) — confirm what
        # `plot_predictions(..., return_fig=True)` actually returns.
        fig.save(out_file)

    if "depth" in preds:
        print(preds["depth"].min())
        print(preds["depth"].max())