In [None]:
import os
import cv2
import torch
import random
import numpy as np
from tosem.io.image import read_rgb
from tosem.io import load_config
from tosem.transform import Transform
from tosem import create_model
from tosem.dataset import SegmentationDataset
from tosem.utils import plot_predictions

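### **Checkpoint and config from a training experiment**

Set `ckpt_path` and `config_path` in the next cell to the checkpoint and configuration file saved by the experiment you want to evaluate.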

In [None]:
# Placeholders: replace `...` with the checkpoint path and the config file path
# of the training experiment to evaluate.
ckpt_path = ...
config_path = ...

# Fail fast with a clear message instead of an obscure error inside load_config/create_model.
if ckpt_path is ... or config_path is ...:
    raise ValueError(
        "Set `ckpt_path` and `config_path` to the checkpoint and config of a training experiment."
    )

In [None]:
# Load the experiment configuration; later cells read its "model" and "transform" sections.
config = load_config(config_path)

In [None]:
# Build the model described by the config's "model" section; `ckpt_path` presumably
# restores the trained weights (NOTE(review): confirm against create_model).
model = create_model(**config["model"], ckpt_path=ckpt_path)

In [None]:
# Put the model in inference mode before predicting (assumes a torch.nn.Module, where this
# disables dropout and uses BatchNorm running statistics — confirm).
# The trailing `;` suppresses the notebook's display of the returned value.
model.eval();

### **Dataset**

In [None]:
# Root directory of the dataset split. Override it with the TOSEM_DATA_DIR environment
# variable rather than editing this machine-specific default path.
data_dir = os.environ.get(
    "TOSEM_DATA_DIR",
    "/Users/riccardomusmeci/Developer/data/github/smart-arrotino/lyft/split",
)

In [None]:
# Evaluation (non-training) transform built from the config's "transform" section.
eval_transform = Transform(train=False, **config["transform"])

# Non-training split of the dataset, preprocessed with the evaluation transform.
# NOTE(review): class_channel=0 presumably selects which mask channel is the target — confirm.
dataset = SegmentationDataset(
    data_dir=data_dir,
    train=False,
    transform=eval_transform,
    class_channel=0,
)

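## **Inference and plot on a random validation image**

Each run of the next cell samples a new random image from the dataset, predicts its mask, and prepares it for plotting.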

In [None]:
# Random selection
# Probability cut-off applied when the model outputs a single (binary) logit channel.
threshold = .5
# random.randint includes both endpoints and could return len(dataset), which is out of
# range; randrange excludes the upper bound.
index = random.randrange(len(dataset))
# NOTE(review): assumes dataset.images[index] is a file name inside <data_dir>/val/images — confirm.
img_path = os.path.join(data_dir, "val", "images", dataset.images[index])
print(f"Image index {index} - Image: {img_path}")

# Model Inference
with torch.no_grad():
    x, mask = dataset[index]
    x = x.unsqueeze(0)  # add a batch dimension
    logits = model(x)
    if logits.shape[1] == 1:
        # A single output channel cannot be argmax-ed (the result would always be 0),
        # so threshold the sigmoid probability instead.
        preds = (torch.sigmoid(logits) > threshold).long()
    else:
        # One channel per class: take the highest-scoring class per pixel. sigmoid/softmax
        # preserve ordering, so argmax over the raw logits gives the same labels without
        # float saturation creating spurious ties.
        preds = torch.argmax(logits, dim=1, keepdim=True)

# Open image and resize it to the model input size (presumably matching the transformed masks).
img = read_rgb(img_path)
input_size = config["transform"]["input_size"]
# cv2.resize needs dsize as a (width, height) tuple; also accept a bare int for square inputs.
# NOTE(review): if input_size is stored as (height, width), non-square sizes end up transposed — confirm.
dsize = (input_size, input_size) if isinstance(input_size, int) else tuple(input_size)
img = cv2.resize(img, dsize)
mask = mask.squeeze().numpy()

# Predicted label mask (batch and channel dimensions squeezed out)
pred_mask = preds.squeeze().numpy()


In [None]:
# Plot the resized image with its ground-truth (`mask`) and predicted (`pred_mask`) masks.
# NOTE(review): alpha=1 presumably draws the masks fully opaque — confirm against plot_predictions.
plot_predictions(image=img, gt_mask=mask, mask=pred_mask, alpha=1)