In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
import sys
from torchvision.transforms import CenterCrop
import torch.nn.functional as F
import numpy as np

# Add the parent directory to path so we can import the modules
sys.path.append('..')

# Import the model classes and data modules from the project
from training.mask2former_semantic import Mask2formerSemantic
from datasets.ade20k import ADE20K
from models.mask2former_decoder import  ModifiedMask2formerDecoder

In [None]:
# Restore the trained Lightning module from a saved checkpoint.

checkpoint_path = "/home/valentin/workspaces/benchmark-vfm-ss/data/lightning_logs/syop5eg0/checkpoints/epoch=31-step=40000.ckpt"

# The network is not stored in the checkpoint, so an instance has to be
# supplied again when loading (Mask2Former configuration).
decoder_network = ModifiedMask2formerDecoder(
    img_size=(512, 512),
    num_classes=150,  # ADE20K has 150 classes
    encoder_name="vit_base_patch14_dinov2",  # replace with your encoder if different
)

# strict=False tolerates minor mismatches between checkpoint and module.
model = Mask2formerSemantic.load_from_checkpoint(
    checkpoint_path,
    network=decoder_network,
    strict=False,
)

# Switch to evaluation mode before running inference.
model.eval()
print(f"Model loaded from {checkpoint_path}")
print(f"Model type: {type(model).__name__}")

In [None]:
def move_to_device(x, device):
    """Recursively move every tensor inside a nested structure to ``device``.

    Args:
        x: A tensor, a dict, a list, a tuple (including namedtuples), or any
            other object. Containers may be nested arbitrarily.
        device: Target device, forwarded to ``Tensor.to``.

    Returns:
        A structure mirroring ``x``: tensors are moved with
        ``non_blocking=True``, dicts come back as plain dicts with the same
        keys, lists/tuples/namedtuples keep their type, and every other
        object is returned unchanged.
    """
    if torch.is_tensor(x):
        return x.to(device, non_blocking=True)
    if isinstance(x, dict):
        return {k: move_to_device(v, device) for k, v in x.items()}
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple constructors take one positional argument per field, so
        # passing a single iterable (as done for plain tuples) would raise.
        return type(x)(*(move_to_device(v, device) for v in x))
    if isinstance(x, (list, tuple)):
        return type(x)(move_to_device(v, device) for v in x)
    return x

In [None]:
# Configure the ADE20K data module that supplies the evaluation samples.
ade20k_options = dict(
    root="../data",  # Adjust path to your data
    devices=1,
    num_workers=4,
    batch_size=1,
    img_size=(512, 512),
    num_classes=150,
    num_metrics=1,
)
data_module = ADE20K(**ade20k_options)

# Prepare the datasets.
# NOTE(review): the "test" stage is requested here while the next cell reads
# val_dataloader() — confirm ADE20K.setup builds the validation split for it.
data_module.setup("test")


In [None]:

# Fetch the first batch of the validation loader and place it on the model's device.
val_dataloader = data_module.val_dataloader()
sample_batch = move_to_device(next(iter(val_dataloader)), model.device)

# The batch unpacks into the images and their targets.
img, target = sample_batch


In [None]:
# Sanity check: shape of the first image in the sampled batch.
img[0].shape

In [None]:
# Sanity check: shape of the "masks" entry of the first target.
target[0]["masks"].shape

In [None]:
def to_per_pixel_labels(target, ignore_idx=0):
    """Collapse a stack of per-segment masks into a single per-pixel label map.

    Args:
        target: Mapping with a "masks" entry of shape (n, h, w) and a "labels"
            entry of shape (n,), holding one label per mask.
        ignore_idx: Value written to pixels that no mask covers. The default
            of 0 matches what the previous sum-based implementation produced.

    Returns:
        A ``torch.long`` tensor of shape (h, w) on the device of the masks.

    Note:
        Labels are assigned, not summed. When masks overlap, the label of the
        later mask wins; a sum would add class ids together and yield labels
        that belong to no mask.
    """
    masks = target["masks"].bool()
    labels = target["labels"].long()

    per_pixel = torch.full(
        tuple(masks.shape[-2:]),
        ignore_idx,
        dtype=torch.long,
        device=masks.device,
    )
    for mask, label in zip(masks, labels):
        per_pixel[mask] = label
    return per_pixel


def to_per_pixel_logit(mask_logits, class_logits, query_idx=None):
    """Fuse per-query mask and class logits into per-pixel class scores.

    Args:
        mask_logits: Tensor indexed as (b, q, h, w).
        class_logits: Tensor indexed as (b, q, c); the last class channel is
            dropped after the softmax.
        query_idx: Optional index selecting a subset of queries along the
            query axis of both inputs. ``None`` keeps all queries.

    Returns:
        Tensor of shape (b, c - 1, h, w): for each class, the sum over queries
        of sigmoid(mask) times softmax(class).
    """
    restrict_queries = query_idx is not None
    selected_masks = mask_logits[:, query_idx, :, :] if restrict_queries else mask_logits
    selected_classes = class_logits[:, query_idx, :] if restrict_queries else class_logits

    mask_probs = selected_masks.sigmoid()
    # Softmax over every class channel first, then discard the final one.
    class_probs = selected_classes.softmax(dim=-1)[..., :-1]
    return torch.einsum("bqhw,bqc->bchw", mask_probs, class_probs)

def to_query_per_pixels(query_embeddings, pixel_embeddings):
    """Contract query embeddings with pixel embeddings over the channel axis.

    Args:
        query_embeddings: Tensor indexed as (q, c).
        pixel_embeddings: Tensor indexed as (b, c, h, w).

    Returns:
        Tensor of shape (b, q, h, w) holding, for every query, its dot product
        with the embedding at each pixel.
    """
    contraction = "qc,bchw->bqhw"
    response_maps = torch.einsum(contraction, query_embeddings, pixel_embeddings)
    return response_maps

In [None]:
# Forward the sampled batch through the model without tracking gradients.
with torch.no_grad():
    # eval_step is the model's own evaluation entry point.
    logits = model.eval_step(sample_batch, batch_idx=0, is_notebook=True)

# Prediction for the first image: index of the largest logit along dim 0.
first_logits = logits[0]
predictions = torch.argmax(first_logits, dim=0)

print(f"Logits shape: {logits[0].shape}")
print(f"Predictions shape: {predictions.shape}")
print(f"Predicted classes: {torch.unique(predictions)}")

# NumPy copies for visualization; the image is permuted from dims (0, 1, 2) to (1, 2, 0).
img_np = img[0].permute(1, 2, 0).cpu().numpy()
target_np = to_per_pixel_labels(target[0]).cpu().numpy()
pred_np = predictions.cpu().numpy()

In [None]:
# Render image, ground truth and prediction with the model's own plotting
# helper (the original notebook notes it is the one used during training).
first_image = img[0]
first_target_labels = to_per_pixel_labels(target[0])
plot_image = model.plot_semantic(
    first_image,            # Original image
    first_target_labels,    # Ground truth segmentation
    logits=logits[0],       # Model predictions
)

# Show the rendered comparison.
plt.figure(figsize=(15, 5))
plt.imshow(plot_image)
plt.axis('off')
plt.title('Left: Original Image | Middle: Ground Truth | Right: Prediction')
plt.tight_layout()
plt.show()

In [None]:
# Batch the images and their per-pixel label maps, then center-crop both to 512x512.
center_crop = CenterCrop((512, 512))

test_image = center_crop(torch.stack(img, dim=0))

per_pixel_targets = [to_per_pixel_labels(t) for t in target]
test_target = center_crop(torch.stack(per_pixel_targets, dim=0))


In [None]:
# Sanity check: shape of the first stacked, center-cropped image.
test_image[0, ...].shape

In [None]:
# Call the network directly to obtain its dictionary of intermediate outputs.
with torch.inference_mode():
    # Pixel values are divided by 255 before the forward pass.
    scaled_image = test_image / 255.0
    test_output = model.network.forward_dict(scaled_image)

# Move every tensor in the output structure to the CPU for inspection.
test_output = move_to_device(test_output, "cpu")

In [None]:
# Pull the four entries used below out of the network's output dictionary.
output_keys = (
    "mask_logits_per_layer",
    "class_logits_per_layer",
    "mask_embeddings_per_layer",
    "per_pixel_embeddings",
)
(
    mask_logits_per_layer,
    class_logits_per_layer,
    mask_embeddings_per_layer,
    per_pixel_embeddings,
) = (test_output[key] for key in output_keys)

In [None]:
# Shape of the last entry of mask_embeddings_per_layer after dropping size-1 dimensions.
mask_embeddings_per_layer[-1].squeeze().shape

In [None]:
# Shape of the last entry of mask_logits_per_layer.
mask_logits_per_layer[-1].shape

In [None]:
# Shape of the last entry of class_logits_per_layer.
class_logits_per_layer[-1].shape

In [None]:
# Device of the per-pixel embeddings; expected to be the CPU after the
# move_to_device(test_output, "cpu") call above.
per_pixel_embeddings.device

In [None]:
# Shape of the weight of the network's `q` module.
# NOTE(review): presumably the learned query embeddings — confirm against
# ModifiedMask2formerDecoder.
model.network.q.weight.shape

In [None]:
def _response_maps_as_numpy(query_embeddings):
    """Project ``query_embeddings`` onto the per-pixel embeddings and return
    the result as a NumPy array with size-1 dimensions removed."""
    response = to_query_per_pixels(query_embeddings, per_pixel_embeddings)
    return response.detach().cpu().numpy().squeeze()


# Response maps for the `q` weights and for the mask embeddings of the first
# and last entries of mask_embeddings_per_layer.
query_per_pixels = _response_maps_as_numpy(model.network.q.weight.cpu())
mask_embeddings_per_pixels_l0 = _response_maps_as_numpy(mask_embeddings_per_layer[0].squeeze())
mask_embeddings_per_pixels_lf = _response_maps_as_numpy(mask_embeddings_per_layer[-1].squeeze())

In [None]:
# Number of query maps to visualise. Capped at the number of available maps,
# because sampling without replacement cannot draw more than the population.
n = min(20, query_per_pixels.shape[0])

# A fixed seed keeps the sampled queries, and therefore the figure below,
# identical after every Restart & Run All.
indices = np.random.default_rng(seed=0).choice(query_per_pixels.shape[0], size=n, replace=False)

In [None]:
# One row per sampled query, one column per embedding source. Each entry pairs
# the array of response maps with the title template of its column.
panel_columns = (
    (query_per_pixels, "Query {idx}"),
    (mask_embeddings_per_pixels_l0, "Mask embeddings L0 for query {idx}"),
    (mask_embeddings_per_pixels_lf, "Mask embeddings Lf for query {idx}"),
)

num_rows = len(indices)
# squeeze=False keeps `axes` two-dimensional even for a single row, so the
# axes[i, col] indexing below is always valid.
fig, axes = plt.subplots(
    nrows=num_rows,
    ncols=len(panel_columns),
    figsize=(6, 3 * num_rows),
    squeeze=False,
)

for i, idx in enumerate(indices):
    for col, (response_maps, title_template) in enumerate(panel_columns):
        ax = axes[i, col]
        ax.imshow(response_maps[idx], cmap="gray")
        ax.set_title(title_template.format(idx=idx))
        ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Shape of the NumPy array of query response maps.
query_per_pixels.shape

In [None]:
# Bilinearly resize the last entry of mask_logits_per_layer to 512x512.
final_layer_mask_logits = mask_logits_per_layer[-1]
mask_logits = F.interpolate(final_layer_mask_logits, size=(512, 512), mode="bilinear")

In [None]:
# Keep the last entry of class_logits_per_layer for the steps below.
class_logits = class_logits_per_layer[-1]

In [None]:
# Shape of the resized mask logits.
mask_logits.shape

In [None]:
# Shape of the selected class logits.
class_logits.shape

In [None]:
# Draw up to 100 distinct query indices along dim 1 of class_logits. The size
# is capped at the number of queries, since sampling without replacement raises
# when asked for more than the population. The fixed seed makes the subset,
# and the plot derived from it, reproducible across kernel restarts.
num_queries = class_logits.shape[1]
query_idx = np.random.default_rng(seed=0).choice(
    num_queries, size=min(100, num_queries), replace=False
)

In [None]:
# Combine mask and class logits into per-pixel scores, using only the sampled queries.
test_logits = to_per_pixel_logit(mask_logits, class_logits, query_idx=query_idx)

In [None]:
# Render the comparison again, this time with the logits rebuilt from the
# sampled queries.
cropped_image = test_image[0, ...]
cropped_target = test_target[0, ...]
subset_logits = test_logits[0, ...]
plot_image = model.plot_semantic(
    cropped_image,         # Original image
    cropped_target,        # Ground truth segmentation
    logits=subset_logits,  # Model predictions
)

# Show the rendered comparison.
plt.figure(figsize=(15, 5))
plt.imshow(plot_image)
plt.axis('off')
plt.title('Left: Original Image | Middle: Ground Truth | Right: Prediction')
plt.tight_layout()
plt.show()