# LSeg Demo

First, we import the required libraries and define the utility functions used to visualize segmentation masks.

In [None]:
# Import Libraries
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from matplotlib.patches import Patch

from lseg_train import LSegModule
from Lseg.lseg_net import LSegNet
from Lseg.data.util import get_labels, get_dataset

# Utility Functions
def generate_color_map(num_classes, seed=0):
    """
    Generate a reproducible color map for visualizing segmentation masks.

    Colors are drawn from a private ``np.random.RandomState`` so that calling
    this function does not reseed NumPy's global random generator. With the
    default ``seed=0`` the colors match those obtained by seeding the global
    legacy generator with 0.

    Args:
        num_classes: Total number of classes. If <= 0, an empty map is returned.
        seed: Seed for the private random generator (default 0).
    Returns:
        color_map: A dictionary mapping class indices ``0..num_classes-1`` to
            length-3 integer RGB arrays. The last index is mapped to black.
            NOTE(review): the last class is assumed to be background; confirm
            against the label set.
    """
    if num_classes <= 0:
        # Without this guard, `num_classes - 1` would add a spurious key -1.
        return {}

    rng = np.random.RandomState(seed)
    # randint's upper bound is exclusive, so each channel lies in [0, 254].
    color_map = {i: rng.randint(0, 255, size=3) for i in range(num_classes)}
    color_map[num_classes - 1] = np.array([0, 0, 0])  # Black for background
    return color_map

def apply_color_map(label_tensor, color_map):
    """
    Map a label mask to an RGB image using the provided color map.

    Args:
        label_tensor: A torch tensor (or NumPy array) of shape [H, W] containing
            class labels. Tensors may live on any device and may require grad;
            they are detached and copied to the CPU before conversion.
        color_map: A dictionary mapping class indices to RGB tuples.
    Returns:
        rgb_image: A uint8 RGB image of shape [H, W, 3]. Pixels whose label is
            not a key of ``color_map`` stay black.
    """
    if isinstance(label_tensor, torch.Tensor):
        # Tensor.numpy() raises for CUDA tensors and for tensors requiring grad.
        label_array = label_tensor.detach().cpu().numpy()
    else:
        label_array = np.asarray(label_tensor)

    h, w = label_array.shape
    rgb_image = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        rgb_image[label_array == class_idx] = color
    return rgb_image

def visualize_predictions(predictions, color_map, labels, title="Segmentation", *, show_all_labels=False):
    """
    Visualize a segmentation mask with a legend mapping colors to class labels.

    By default the legend lists only the classes that actually occur in
    ``predictions``. With large label sets (e.g. ADE20K), listing every class
    produces an unreadable legend that overflows the figure.

    Args:
        predictions: A torch tensor (or NumPy array) of shape [H, W] with class
            indices. Tensors are detached and moved to the CPU as needed.
        color_map: A dictionary mapping class indices to RGB tuples.
        labels: A list of class label names, indexed by class index.
        title: Title of the plot.
        show_all_labels: If True, list every class in ``color_map`` in the
            legend (previous behavior), not only the classes present.
    """
    # Apply color map to the predictions
    rgb_image = apply_color_map(predictions, color_map)

    # Determine which class indices occur in the mask, for the legend.
    if isinstance(predictions, torch.Tensor):
        label_array = predictions.detach().cpu().numpy()
    else:
        label_array = np.asarray(predictions)
    present_classes = set(np.unique(label_array).tolist())

    # Create the figure
    plt.figure(figsize=(10, 6))
    plt.imshow(rgb_image)
    plt.title(title)
    plt.axis("off")

    # Add a legend
    legend_elements = [
        Patch(facecolor=np.asarray(color) / 255.0, label=labels[idx])
        for idx, color in color_map.items()
        if show_all_labels or idx in present_classes
    ]
    if legend_elements:
        plt.legend(handles=legend_elements, loc="upper right", bbox_to_anchor=(1.2, 1), title="Labels", fontsize="small")
    plt.show()

### Model Initialization

We build the test dataset and the LSeg network, then load the pre-trained weights from a PyTorch Lightning checkpoint.

In [None]:
# NOTE(review): confirm this dataset matches the label set returned by
# get_labels() below.
test_dataset = get_dataset(dataset_name="ade20k", get_train=False)

config = {
    "batch_size": 12,
    "base_lr": 0.04,
    "max_epochs": 50,
    "num_features": 512,
}

# Batched, unshuffled loader over the test split.
test_dataloaders = DataLoader(
    test_dataset, batch_size=config["batch_size"], shuffle=False, pin_memory=True
)

# Class names for the segmentation head; their count sets num_classes below.
labels = get_labels()

net = LSegNet(
    labels=labels,
    features=config["num_features"],
)

# Path to the Lightning checkpoint to evaluate; adjust to your own run.
load_checkpoint_path = r"checkpoints/checkpoint_epoch=3-val_loss=4.9235"

# map_location="cpu": without it, a checkpoint saved during GPU training
# restores its tensors onto CUDA, which fails on CPU-only machines. Move the
# module with `.to(device)` afterwards if GPU inference is wanted.
load_model = LSegModule.load_from_checkpoint(
    load_checkpoint_path,
    map_location="cpu",
    max_epochs=config["max_epochs"],
    model=net,
    num_classes=len(labels),
    batch_size=config["batch_size"],
    base_lr=config["base_lr"],
)
load_model.eval()

### Visualization

Here we visualize the ground-truth segmentation masks alongside the model's predictions for the first few test samples. Each mask is rendered with a fixed, reproducible color map, and the legend shows the class label associated with each color.

In [None]:
color_map = generate_color_map(len(labels))

# Run inference on whatever device holds the model's weights, so this cell
# works whether the checkpoint was loaded onto the CPU or a GPU.
model_device = next(load_model.parameters()).device

# Visualize the first few samples (fewer if the test split is smaller).
num_samples = min(3, len(test_dataset))
for i in range(num_samples):
    img, label = test_dataset[i]
    img_batch = img.unsqueeze(0).to(model_device)  # Add batch dimension
    # NOTE(review): assumes `label` holds per-class channels in the last of
    # three dims ([H, W, C]); adjust `dim` if the dataset uses another layout.
    label_batch = torch.argmax(label, dim=2)

    # Ground Truth
    visualize_predictions(label_batch, color_map, labels, title="Ground Truth")

    # Model Prediction
    with torch.no_grad():
        output = load_model(img_batch)
    # Assumes `output` is [1, C, H, W] class scores (TODO confirm): drop the
    # batch dim, take argmax over classes, and bring the mask back to the CPU
    # for NumPy/matplotlib.
    prediction = torch.argmax(output.squeeze(0), dim=0).cpu()
    visualize_predictions(prediction, color_map, labels, title="Model Prediction")