# Inference and visualization for DiffAssemble


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image, ImageFile

In [None]:
from src.model.full_models import *
from src.full_dataset import *
from src.gnn_diffusion import GNN_Diffusion

### 0. Define paths and diffusion steps

In [None]:
# Number of diffusion steps for the inference diffusion process. It is expected
# to match the trained checkpoint (whose file name below says "steps_2").
steps = 2

# Root folder of the CelebA-HQ images, resolved against the current working directory.
dataset_path = os.path.join(os.getcwd(), "data/CelebA-HQ")
print(dataset_path)

### 1. Load the test dataset


In [None]:
# Load the base CelebA-HQ test split (CelebA_HQ comes from the src star imports)
# and show its first element as a quick sanity check.
test_dataset_base = CelebA_HQ(dataset_path, train=False)
print(f"Test dataset length: {len(test_dataset_base)}")
# Plain string: the original f-string had no placeholders (ruff F541); output is unchanged.
print("Sample image: ")
# Assumes indexing the base dataset yields something plt.imshow can render — TODO confirm.
plt.imshow(test_dataset_base[0])
plt.axis("off")

In [None]:
# Wrap the base images in the puzzle dataset (patch_per_dim=(6, 6), no
# augmentation, no random dropout) and inspect its first sample.
test_puzzle_dt = Puzzle_Dataset_ROT(
    dataset=test_dataset_base,
    patch_per_dim=[(6, 6)],
    augment=False,
    degree=-1,
    unique_graph=None,
    all_equivariant=False,
    random_dropout=False,
)

elem = test_puzzle_dt[0]

# Print the whole sample object first, then each field of interest on its own line.
print(elem)
# Node features; per the original note the columns are x, y, rot1, rot2.
print(f"X: {elem.x}")
print(f"EDGE_INDEX: {elem.edge_index}")
print(f"INDEXES: {elem.indexes}")
print(f"ROT: {elem.rot}")
print(f"ROT_INDEX: {elem.rot_index}")
print(f"IND_NAME: {elem.ind_name}")

### 2. Load the model from the training checkpoint

In [None]:
# Build the Eff_GAT network. These hyper-parameters must match the ones used in
# training: load_state_dict (strict by default) raises on any key/shape mismatch.
model = Eff_GAT(
    steps=steps,  # reuse the notebook-level value instead of a second hard-coded 2
    input_channels=4,
    output_channels=4,
    n_layers=4,
    model="resnet18equiv",
    architecture="transformer",
)

# The checkpoint file name encodes its training configuration; derive the steps
# part from `steps` so the two cannot silently drift apart.
checkpoint_path = f"checkpoints/eff_gat_epoch_30_steps_{steps}_batchsize_10_puzzdim_6_6.pt"
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only load
# checkpoints from a trusted source.
checkpoint = torch.load(
    checkpoint_path,
    weights_only=False,
    map_location=torch.device("cpu"),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Compact summary per parameter: dumping every full weight tensor floods the output.
print("Model parameters after loading checkpoint:")
for name, param in model.named_parameters():
    print(name, tuple(param.shape), f"norm={param.detach().norm().item():.4f}")


### 3. Run inference on a single sample

In [None]:
# Add test step to the GNN Diffusion class
# Modify to return noise and predicted noise AND work with batches of size = 1 in a controlled manner
class GNN_Diffusion_Inference(GNN_Diffusion):
    
    def inference_step(self, sample, model, criterion):
        sample = sample.to(self.device)

        if not sample.batch:
            sample.batch = torch.zeros(sample.x.size(0), dtype=torch.long, device=self.device)
        
        print(sample.batch)

        num_graphs = int(sample.batch.max().item()) + 1
        t_graph = torch.randint(0, self.steps, (num_graphs,), device=self.device).long()
        t = torch.gather(t_graph, 0, sample.batch)

        x_start = sample.x
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        patch_feats = model.visual_features(sample.patches)
        prediction, _ = model.forward_with_feats(
            x_noisy, t, sample.patches, sample.edge_index, patch_feats, sample.batch
        )
        return criterion(noise, prediction)


In [None]:
# DataLoader over the puzzle dataset. torch's default collate function cannot
# batch graph-sample objects (presumably torch_geometric Data), so iterating a
# plain DataLoader raises TypeError. With batch_size=1 the single sample is
# unwrapped instead, leaving it without a `batch` vector (single-sample form).
def _unwrap_single_sample(items):
    """Return the only element of a size-1 batch unchanged (no collation)."""
    if len(items) != 1:
        raise ValueError(
            f"_unwrap_single_sample expects batches of size 1, got {len(items)}"
        )
    return items[0]


test_loader = torch.utils.data.DataLoader(
    test_puzzle_dt,
    batch_size=1,
    shuffle=False,
    collate_fn=_unwrap_single_sample,
)
sample = test_puzzle_dt[0]
print(f"X: {sample.x}") # This contains all the node features: x, y, rot1, rot2
print(f"EDGE_INDEX: {sample.edge_index}")
print(f"INDEXES: {sample.indexes}")
print(f"ROT: {sample.rot}")
print(f"ROT_INDEX: {sample.rot_index}")
print(f"IND_NAME: {sample.ind_name}")


In [None]:
# Loss comparing the injected noise with the model's noise prediction.
criterion = torch.nn.functional.smooth_l1_loss
gnn_diffusion = GNN_Diffusion_Inference(steps=steps)

# inference_step moves the sample to gnn_diffusion.device, but the model was loaded
# with map_location="cpu". Put the model on the same device so the forward pass
# does not mix devices (a no-op when that device is the CPU).
model = model.to(gnn_diffusion.device)

with torch.no_grad():
    # NOTE: called this way inference_step returns only the loss; the noise and
    # predicted noise are still needed for visualization.
    loss = gnn_diffusion.inference_step(sample, model, criterion)


In [None]:
# Display the loss returned by the inference step (shown as the notebook cell output).
loss