# Installation

⚠ Install the required packages only if this notebook runs in Colab. Otherwise, install the required packages manually in your local Python environment.

In [None]:
import sys

# The session is treated as Colab when the `google.colab` module is already loaded.
IN_COLAB = "google.colab" in sys.modules.keys()
print(IN_COLAB)

In [None]:
# Print the version of the `python` executable that the shell resolves.
!python --version

In [None]:
# Colab-only switch: set it to True when the Colab runtime has a GPU attached.
# It is only read inside `if IN_COLAB:` branches, so it has no effect elsewhere.

COLAB_CAN_USE_GPU = True

Install torch 2.6.0 (with the matching torchvision and torchaudio builds) so that compatible pytorch-geometric packages can be installed afterwards.

In [None]:
if IN_COLAB:
  !pip uninstall -y torch torchvision torchaudio
  if COLAB_CAN_USE_GPU:
    !pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
  else:
    !pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu

In [None]:
import torch

if IN_COLAB:
  torch_version = torch.__version__.split('+')[0]
  if COLAB_CAN_USE_GPU:
    cuda_version = torch.version.cuda.replace('.', '')
    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-{torch_version}+cu{cuda_version}.html

  else:
    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-{torch_version}+cpu.html

  !pip install -q torch-geometric

In [None]:
if IN_COLAB:
    !pip install wandb

## Wandb login

In [None]:
# Authenticate the wandb CLI. This runs unconditionally (there is no IN_COLAB
# guard), so outside Colab the `wandb` package must already be installed.
!wandb login

## Inspect runtime default versions and settings

Check torch and torchvision default versions. For now we are just going to use them, we'll change them if we hit any conflict in the future.

In [None]:
import torch
import torchvision

# Report the library versions and CUDA availability, one item per line.
print(
    f"Torch version: {torch.__version__}",
    f"Torchvision version: {torchvision.__version__}",
    "",
    f"Torch cuda is available: {torch.cuda.is_available()}",
    sep="\n",
)

If cuda is not available, enable GPU in Colab by going to 'Runtime' > 'Change runtime type' > Select 'T4 GPU'.

This will restart the session and you'll need to rerun all the cells again. After restarting the session, verify that cuda is available.

### Nvidia version

The following command (nvidia-smi) will tell you which GPU you are using (if any).

In [None]:
# Query the NVIDIA driver for the attached GPU(s); the command is expected to
# fail on runtimes without an NVIDIA GPU/driver.
!nvidia-smi

## Enable cuda if available

In [None]:
import torch

# Compute on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Main imports

In [None]:
import os
from pathlib import Path
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
import numpy as np
import einops

import torch
import torch_geometric
import torchvision

# Dataset download

Download the dataset only if this notebook runs in Colab, otherwise you'll need to download it manually.

In [None]:
import os
# Export the flag as "1"/"0" so the %%bash cells can test it.
os.environ["IN_COLAB"] = ("0", "1")[bool(IN_COLAB)]

In [None]:
import os

# Dataset locations: relative to the working directory in Colab, one level up
# otherwise. The image folder lives under the split-file folder.
DATASET_TXT_PATH = "data/CelebA-HQ" if IN_COLAB else "../data/CelebA-HQ"
DATASET_IMG_PATH = DATASET_TXT_PATH + "/images"

# Export both paths so the %%bash cells can read them.
os.environ.update(
    DATASET_IMG_PATH=DATASET_IMG_PATH,
    DATASET_TXT_PATH=DATASET_TXT_PATH,
)

In [None]:
%%bash
# Download and unpack the CelebAMask-HQ archive into ${DATASET_IMG_PATH}.
# IN_COLAB and DATASET_IMG_PATH are read from the environment (exported by
# earlier cells). Nothing is downloaded outside Colab or when the target
# folder already exists.

# Abort on the first failing step so a failed download is never unzipped/moved.
set -e

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

echo 'Downloading dataset...'

OUTPUT_FILENAME='dataset.zip'
FOLDER_NAME='CelebAMask-HQ'

mkdir -p "${DATASET_IMG_PATH}"

if [ -d "${DATASET_IMG_PATH}/${FOLDER_NAME}" ]; then
  echo "Skipping the download since the folder ${DATASET_IMG_PATH}/${FOLDER_NAME} already exists"
  exit 0
fi

# Clean leftovers from a previous, interrupted run; -f keeps this silent and
# successful when there is nothing to remove.
rm -f "${OUTPUT_FILENAME}"
rm -rf "${FOLDER_NAME}"

# TLS certificate verification is left enabled (no --no-check-certificate).
wget 'https://huggingface.co/datasets/liusq/CelebAMask-HQ/resolve/main/CelebAMask-HQ.zip?download=true' -O "${OUTPUT_FILENAME}"
echo "${OUTPUT_FILENAME} downloaded. Unziping it..."
unzip "${OUTPUT_FILENAME}"
rm -f "${OUTPUT_FILENAME}"

mv "${FOLDER_NAME}" "${DATASET_IMG_PATH}"

echo "Done"

Preview an image

In [None]:
# Open one sample straight from the unpacked archive and show it without axes.
img = Image.open(f"{DATASET_IMG_PATH}/CelebAMask-HQ/CelebA-HQ-img/1000.jpg")
plt.imshow(X=img)
plt.axis("off")

Download the txt files from the DiffAssemble repository that define the data split between training and testing

In [None]:
%%bash
# Fetch the train/test split lists from the DiffAssemble repository and place
# them in $DATASET_TXT_PATH (Colab only).

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

SPLITS_URL='https://raw.githubusercontent.com/IIT-PAVIS/DiffAssemble/refs/heads/release/datasets/data_splits'

mkdir -p $DATASET_TXT_PATH

for split_file in CelebA-HQ_test.txt CelebA-HQ_train.txt; do
  # Drop a stale local copy first so wget does not save under a suffixed name.
  if [ -f ${split_file} ]; then
    rm ${split_file}
  fi
  wget -q ${SPLITS_URL}/${split_file}
  mv ${split_file} $DATASET_TXT_PATH
done

ls $DATASET_TXT_PATH

# DataSet implementation

Define a basic Dataset class that solely loads images from disk.


In [None]:
class CelebA_DataSet(torch.utils.data.Dataset):
    """
    Image-only dataset: it reads the file names of one split and opens the
    corresponding images. No patches, no graphs, no diffusion logic.

    Args:
      train: read ``CelebA-HQ_train.txt`` when true, ``CelebA-HQ_test.txt`` otherwise.
      transform: optional callable applied to every opened image.
    """

    def __init__(self, train=True, transform=None):
        split_file = "/CelebA-HQ_train.txt" if train else "/CelebA-HQ_test.txt"

        self.images_path = DATASET_IMG_PATH + "/CelebAMask-HQ/CelebA-HQ-img/"
        self.transform = transform

        # One image file name per line.
        with open(DATASET_TXT_PATH + split_file, "r", encoding="utf-8") as split:
            self.image_names = split.read().splitlines()

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_file = os.path.join(self.images_path, self.image_names[idx])
        image = Image.open(image_file)
        return self.transform(image) if self.transform else image

Test the Dataset

In [None]:
# Load the training split without any transform and display its first image.
dataset = CelebA_DataSet(train=True, transform=None)
img = dataset[0]
plt.imshow(X=img)
plt.axis("off")


Define a more complex Dataset on top of the previous one. This class splits each image into patches and returns a graph that is ready for training.

In [None]:
import math
import random
from typing import List, Tuple
from scipy.sparse.linalg import eigsh
from PIL.Image import Resampling

from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

def generate_random_expander(num_nodes, degree, rng=None, max_num_iters=5, exp_index=0):
    """Build the edge list of a random d-regular expander on ``num_nodes`` nodes.

    For graphs with at most 10 nodes the complete graph (without self-loops)
    is returned instead. Otherwise up to ``max_num_iters`` random regular
    graphs are drawn and the one with the largest eigenvalue score is kept;
    drawing stops early once the score reaches the spectral lower bound.

    Args:
      num_nodes: Number of nodes in the desired graph.
      degree: Desired degree, either a number or a percentage string such as
        "50%" (interpreted relative to ``num_nodes - 1``).
      rng: numpy random generator; a fresh default one is created when None.
      max_num_iters: maximum number of candidate graphs to draw.
      exp_index: accepted but not used by this function.
    Returns:
      Long tensor of shape (num_edges, 2); each row is (sender, receiver).
      The list is symmetric: if (x, y) is present so is (y, x).
    """
    if isinstance(degree, str):
        # "NN%" -> NN percent of the maximum possible degree.
        degree = round((int(degree[:-1]) * (num_nodes - 1)) / 100)

    if rng is None:
        rng = np.random.default_rng()

    # The bound is derived from the requested degree, before the clamp below.
    if degree > 0:
        eig_val_lower_bound = max(0, degree - 2 * math.sqrt(degree - 1) - 0.1)
    else:
        eig_val_lower_bound = 0  # allow the use of zero degree

    # (bave): This is a hack.  This should hopefully fix the bug
    if num_nodes <= degree:
        degree = num_nodes - 1

    best_eig_val = -1
    best_senders = []
    best_receivers = []

    if num_nodes <= 10:
        # (ali): with too few nodes random graph generation fails, so use the
        # whole graph.
        all_pairs = [
            (src, dst)
            for src in range(num_nodes)
            for dst in range(num_nodes)
            if src != dst
        ]
        best_senders = [src for src, _ in all_pairs]
        best_receivers = [dst for _, dst in all_pairs]
    else:
        eig_val = -1
        attempt = 1
        while eig_val < eig_val_lower_bound and attempt <= max_num_iters:
            senders, receivers = generate_random_regular_graph(num_nodes, degree, rng)

            spectrum = get_eigenvalue(senders, receivers, num_nodes=num_nodes)
            if len(spectrum) == 0:
                print(
                    "num_nodes = %d, degree = %d, cur_iter = %d, mmax_iters = %d, senders = %d, receivers = %d"
                    % (
                        num_nodes,
                        degree,
                        attempt,
                        max_num_iters,
                        len(senders),
                        len(receivers),
                    )
                )
                eig_val = 0
            else:
                eig_val = spectrum[0]

            # Remember the best candidate seen so far.
            if eig_val > best_eig_val:
                best_eig_val = eig_val
                best_senders = senders
                best_receivers = receivers

            attempt += 1

    sender_column = torch.tensor(best_senders, dtype=torch.long).view(-1, 1)
    receiver_column = torch.tensor(best_receivers, dtype=torch.long).view(-1, 1)
    return torch.cat([sender_column, receiver_column], dim=1)


def get_eigenvalue(senders, receivers, num_nodes):
    """Return two eigenvalues of the graph Laplacian built from the edge list.

    The Laplacian is requested from torch_geometric with ``normalization=None``
    and handed to scipy's ``eigsh`` with ``k=2, which="SM"``; only eigenvalues
    are returned (no eigenvectors).
    """
    stacked_edges = torch.tensor(np.stack([senders, receivers]))
    lap_index, lap_weight = torch_geometric.utils.get_laplacian(
        stacked_edges, None, normalization=None, num_nodes=num_nodes
    )
    laplacian = torch_geometric.utils.to_scipy_sparse_matrix(
        lap_index, lap_weight, num_nodes
    )
    return eigsh(laplacian, k=2, which="SM", return_eigenvectors=False)


def generate_random_regular_graph(num_nodes, degree, rng=None):
    """Generates a random d-regular graph with n nodes.

    The nodes are randomly permuted; ``degree // 2`` shifted copies of the
    permutation contribute two edges per node each, and for an odd degree the
    first half of the permutation is additionally matched with the second half.

    Returns the list of edges. This list is symmetric; i.e., if
    (x, y) is an edge so is (y,x).
    Args:
      num_nodes: Number of nodes in the desired graph.
      degree: Desired degree.
      rng: random number generator
    Returns:
      senders: tail of each edge.
      receivers: head of each edge.
    Raises:
      TypeError: if ``num_nodes * degree`` is odd (no such graph exists).
    """
    if (num_nodes * degree) % 2 != 0:
        raise TypeError("nodes * degree must be even")
    if rng is None:
        rng = np.random.default_rng()
    if degree == 0:
        # Integer dtype so the empty edge list can be used as an index array.
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy()

    nodes = rng.permutation(np.arange(num_nodes))
    num_reps = degree // 2

    # Circulant part: node i of the permutation is linked to the node `shift`
    # positions before it. With degree == 1 there is no circulant part, and
    # np.hstack([]) would raise, so an empty index row is used instead.
    if num_reps > 0:
        shifted = np.hstack([np.roll(nodes, shift + 1) for shift in range(num_reps)])
    else:
        shifted = np.empty(0, dtype=nodes.dtype)
    edge_index = np.vstack((np.tile(nodes, num_reps), shifted))

    if degree % 2 == 1:
        # Odd degree: add a perfect matching between the two halves. The parity
        # check above guarantees num_nodes is even here, so the halves match.
        half = num_nodes // 2
        matching = np.vstack((nodes[:half], nodes[half:]))
        edge_index = np.hstack((edge_index, matching))

    # Emit every undirected edge in both directions.
    senders = np.concatenate([edge_index[0], edge_index[1]])
    receivers = np.concatenate([edge_index[1], edge_index[0]])
    return senders, receivers

class RandomCropAndResizedToOriginal(torchvision.transforms.RandomResizedCrop):
    """Random resized crop that scales the crop back to the input's own size.

    The ``size`` given to the constructor is ignored; the output always has
    the height and width of the image passed to ``forward``.
    """

    def forward(self, img):
        # resized_crop expects the target size as (height, width), whereas a
        # PIL image's `.size` is (width, height); using `.size` directly would
        # transpose non-square images. get_dimensions returns
        # [channels, height, width] for both PIL images and tensors.
        _, height, width = F.get_dimensions(img)
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, [height, width], self.interpolation)

def _get_augmentation(augmentation_type: str = "none"):
    """Return the list of augmentation transforms for the given preset.

    "weak" is a random horizontal flip; "hard" adds a random crop (80-100% of
    the area) resized back to the original size. Any other value yields an
    empty list.
    """
    if augmentation_type not in ("weak", "hard"):
        return []

    pipeline = [torchvision.transforms.RandomHorizontalFlip(p=0.5)]
    if augmentation_type == "hard":
        pipeline.append(
            RandomCropAndResizedToOriginal(
                size=(1, 1), scale=(0.8, 1), interpolation=InterpolationMode.BICUBIC
            )
        )
    return pipeline

def divide_images_into_patches(
    img, patch_per_dim: List[int], patch_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cut a (C, H, W) image into non-overlapping square patches.

    Args:
      img: image tensor laid out as (channels, height, width).
      patch_per_dim: (rows, columns) of the patch grid.
      patch_size: side length of each patch in pixels.
    Returns:
      xy: (rows, columns, 2) tensor with the (x, y) coordinate of every grid
        cell, each axis spanning [-1, 1].
      patches: (rows, columns, C, patch_size, patch_size) tensor of patches.
    """
    rows, cols = patch_per_dim[0], patch_per_dim[1]

    # Slide a patch_size window with stride patch_size along height, then width.
    channels_last = img.permute(1, 2, 0)
    row_strips = channels_last.unfold(0, patch_size, patch_size)
    patches = row_strips.unfold(1, patch_size, patch_size)

    # Coordinate grid: x varies along columns, y along rows.
    grid_x, grid_y = torch.meshgrid(
        torch.linspace(-1, 1, cols), torch.linspace(-1, 1, rows), indexing="xy"
    )
    xy = torch.stack((grid_x, grid_y), -1)

    return xy, patches


# generation of a unique graph for each number of nodes
def create_graph(patch_per_dim, degree, unique_graph):
    """Precompute one edge index per patch-grid size.

    Args:
      patch_per_dim: iterable of (rows, columns) tuples; each is used as a key.
      degree: -1 for a fully connected graph (self-loops included), otherwise
        the degree forwarded to ``generate_random_expander``.
      unique_graph: random generator forwarded to ``generate_random_expander``.
    Returns:
      Dict mapping every (rows, columns) tuple to a (2, num_edges) long tensor.
    """
    patch_edge_index_dict = {}
    for patch_dim in patch_per_dim:
        num_patches = patch_dim[0] * patch_dim[1]
        if degree == -1:
            # nonzero() lists the (row, col) pairs as an (E, 2) tensor; the
            # transpose is the (2, E) edge index. It must be kept whole:
            # unpacking it into two names would drop the target row.
            adj_mat = torch.ones(num_patches, num_patches)
            edge_index = adj_mat.nonzero().t().contiguous()
        else:
            edge_index = (
                generate_random_expander(
                    num_nodes=num_patches, degree=degree, rng=unique_graph
                )
                .t()
                .contiguous()
            )
        patch_edge_index_dict[patch_dim] = edge_index
    return patch_edge_index_dict

class Puzzle_Dataset(torch_geometric.data.Dataset):
    """Turns every image of a base dataset into a patch graph.

    Each sample is resized to a randomly chosen patch grid, cut into
    ``patch_size`` x ``patch_size`` patches and returned as a
    ``torch_geometric.data.Data`` object whose nodes are the patches.

    Args:
      dataset: indexable image dataset; items must support ``.resize`` (PIL images).
      patch_per_dim: list of (rows, columns) grids to sample from.
      patch_size: side length of a patch in pixels.
      augment: augmentation preset name passed to ``_get_augmentation``.
      degree: -1 for a fully connected graph, otherwise the expander degree.
      unique_graph: random generator; when given, one graph per grid size is
        built once in the constructor and reused for every sample.
      random: when true, the patches are shuffled (node features are not).
    """

    def __init__(
        self,
        dataset=None,
        patch_per_dim=[(7, 6)],
        patch_size=32,
        augment="",
        degree=-1,
        unique_graph=None,
        random=False,
    ) -> None:
        super().__init__()

        self.dataset = dataset
        self.patch_per_dim = patch_per_dim  # default list is only read, never mutated
        self.unique_graph = unique_graph
        self.augment = augment
        self.random = random

        self.transforms = torchvision.transforms.Compose(
            [
                *_get_augmentation(augment),
                torchvision.transforms.ToTensor(),
            ]
        )
        self.patch_size = patch_size
        self.degree = degree

        if self.unique_graph is not None:
            self.edge_index = create_graph(
                self.patch_per_dim, self.degree, self.unique_graph
            )

    def len(self) -> int:
        if self.dataset is not None:
            return len(self.dataset)
        else:
            raise Exception("Dataset not provided")

    def get(self, idx):
        # Fail with the same explicit error as len() instead of an
        # UnboundLocalError on `img` further down.
        if self.dataset is None:
            raise Exception("Dataset not provided")
        img = self.dataset[idx]

        # Pick one of the configured grid sizes for this sample.
        rdim = torch.randint(len(self.patch_per_dim), size=(1,)).item()
        patch_per_dim = self.patch_per_dim[rdim]
        num_patches = patch_per_dim[0] * patch_per_dim[1]

        height = patch_per_dim[0] * self.patch_size
        width = patch_per_dim[1] * self.patch_size
        img = img.resize((width, height))  # PIL takes (width, height)
        img = self.transforms(img)

        xy, patches = divide_images_into_patches(img, patch_per_dim, self.patch_size)

        # Flatten the grid: one row per patch.
        xy = einops.rearrange(xy, "x y c -> (x y) c")
        indexes = torch.arange(num_patches).reshape(xy.shape[:-1])
        patches = einops.rearrange(patches, "x y c k1 k2 -> (x y) c k1 k2")
        if self.random:
            patches = patches[torch.randperm(len(patches))]

        # Same `is not None` test as the constructor, so the precomputed graph
        # is used exactly when it was built.
        if self.unique_graph is not None:
            edge_index = self.edge_index[patch_per_dim]
        elif self.degree == -1:
            # all connected to all
            adj_mat = torch.ones(num_patches, num_patches)
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj_mat)
        else:
            edge_index = generate_random_expander(num_patches, self.degree).T

        data = torch_geometric.data.Data(
            x=xy,
            indexes=indexes,
            patches=patches,
            edge_index=edge_index,
            ind_name=torch.tensor([idx]).long(),
            patches_dim=torch.tensor([patch_per_dim]),
        )
        return data

class Puzzle_Dataset_ROT(Puzzle_Dataset):
    """Patch-graph dataset whose patches are additionally rotated at random.

    Every patch is rotated by a random multiple of 90 degrees. The rotation is
    exposed both as a class index (``rot_index``) and as a 2-vector (``rot``).

    Args:
      dataset, patch_per_dim, patch_size, augment, degree, unique_graph:
        forwarded to ``Puzzle_Dataset``. ``degree`` may also be "100%" (fully
        connected).
      concat_rot: append the rotation vector to the node features ``x``.
      all_equivariant: store all four 90-degree rotations of every (already
        randomly rotated) patch, giving ``patches`` an extra dimension of size 4.
      random_dropout: build the graph by keeping a random subset of the edges
        of the fully connected graph, sized from ``degree``.
    """

    def __init__(
        self,
        dataset=None,
        patch_per_dim=[(7, 6)],
        patch_size=32,
        augment=False,
        concat_rot=True,
        degree=-1,
        unique_graph=None,
        all_equivariant=False,
        random_dropout=False,
    ) -> None:
        # The parent constructor stores degree / unique_graph and, when
        # unique_graph is given, builds self.edge_index; doing that again here
        # would only generate (and discard) a second random graph.
        super().__init__(
            dataset=dataset,
            patch_per_dim=patch_per_dim,
            patch_size=patch_size,
            augment=augment,
            degree=degree,
            unique_graph=unique_graph,
        )
        self.concat_rot = concat_rot
        self.all_equivariant = all_equivariant
        self.random_dropout = random_dropout

    def _sample_edge_index(self, patch_per_dim):
        """Return the (2, num_edges) edge index for one sample."""
        if self.unique_graph is not None:
            return self.edge_index[patch_per_dim]

        num_patches = patch_per_dim[0] * patch_per_dim[1]

        if self.degree == -1 or self.degree == "100%" or self.random_dropout:
            adj_mat = torch.ones(num_patches, num_patches)
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj_mat)
            if self.degree == -1 or self.degree == "100%":
                return edge_index

            # random_dropout: keep num_patches * degree randomly chosen edges.
            if isinstance(self.degree, str):
                # "NN%" -> NN percent of the maximum possible degree.
                degree = round((int(self.degree[:-1]) * (num_patches - 1)) / 100)
            else:
                degree = self.degree
            n_connections = int(num_patches * degree)
            shuffled = edge_index[:, torch.randperm(edge_index.shape[1])]
            return shuffled[:, :n_connections]

        return generate_random_expander(num_patches, self.degree).T

    def get(self, idx):
        # Fail explicitly instead of with an UnboundLocalError on `img`.
        if self.dataset is None:
            raise Exception("Dataset not provided")
        img = self.dataset[idx]

        # Pick one of the configured grid sizes for this sample.
        rdim = torch.randint(len(self.patch_per_dim), size=(1,)).item()
        patch_per_dim = self.patch_per_dim[rdim]

        height = patch_per_dim[0] * self.patch_size
        width = patch_per_dim[1] * self.patch_size

        img = img.resize((width, height), resample=Resampling.LANCZOS)

        img = self.transforms(img)
        xy, patches = divide_images_into_patches(img, patch_per_dim, self.patch_size)

        # Flatten the grid: one row per patch.
        xy = einops.rearrange(xy, "x y c -> (x y) c")
        patches = einops.rearrange(patches, "x y c k1 k2 -> (x y) c k1 k2")

        patches_num = patches.shape[0]

        # Back to 8-bit HWC images so PIL can rotate them. Round before the
        # integer cast: truncating k - epsilon (float error of k / 255 * 255)
        # would shift a pixel value down by one.
        patches_numpy = np.ascontiguousarray(
            (patches * 255)
            .round()
            .clamp(0, 255)
            .to(torch.uint8)
            .numpy()
            .transpose(0, 2, 3, 1)
        )
        patches_im = [Image.fromarray(patches_numpy[x]) for x in range(patches_num)]
        random_rot = torch.randint(low=0, high=4, size=(patches_num,))
        random_rot_one_hot = torch.nn.functional.one_hot(random_rot, 4)

        edge_index = self._sample_edge_index(patch_per_dim)

        # rotation classes : 0 -> no rotation
        #                   1 -> 90 degrees
        #                   2 -> 180 degrees
        #                   3 -> 270 degrees

        indexes = torch.arange(patch_per_dim[0] * patch_per_dim[1]).reshape(
            xy.shape[:-1]
        )

        # Row k is the (cos, sin) pair of k * 90 degrees.
        rots = torch.tensor(
            [
                [1, 0],
                [0, 1],
                [-1, 0],
                [0, -1],
            ]
        )

        rots_tensor = random_rot_one_hot @ rots

        # Rotate every patch by its randomly drawn multiple of 90 degrees (PIL).
        rotated_patch = [
            x.rotate(int(rot) * 90) for (x, rot) in zip(patches_im, random_rot)
        ]

        if self.all_equivariant:
            # Four further rotations (0/90/180/270) of each rotated patch.
            rotated_variants = [
                [x.rotate(rot * 90) for rot in range(4)] for x in rotated_patch
            ]
            patches = torch.stack(
                [
                    torch.stack(
                        [
                            torch.tensor(np.array(patch)).permute(2, 0, 1).float() / 255
                            for patch in variants
                        ]
                    )
                    for variants in rotated_variants
                ]
            )
        else:
            patches = torch.stack(
                [
                    torch.tensor(np.array(patch)).permute(2, 0, 1).float() / 255
                    for patch in rotated_patch
                ]
            )

        if self.concat_rot:
            xy = torch.cat([xy, rots_tensor], 1)

        data = torch_geometric.data.Data(
            x=xy,
            indexes=indexes,
            rot=rots_tensor,
            rot_index=random_rot,
            patches=patches,
            edge_index=edge_index,
            ind_name=torch.tensor([idx]).long(),
            patches_dim=torch.tensor([patch_per_dim]),
        )
        return data

Inspect the output of Puzzle_Dataset_ROT

### Interesting points

- **Number of patches**. The image is split according to the value assigned to the 'patch_per_dim' parameter. For instance, if patch_per_dim is [(6,6)], we'll see 36 patches per image.

- **Rotation**.
TODO

In [None]:
# Base image dataset (training split) wrapped by the rotated-puzzle dataset:
# a fixed 6x6 grid, fully connected graph, no augmentation.
train_dt = CelebA_DataSet(train=True)

puzzle_dt = Puzzle_Dataset_ROT(
    dataset=train_dt,
    patch_per_dim=[(6, 6)],
    augment=False,
    degree=-1,
    unique_graph=None,
    all_equivariant=False,
    random_dropout=False,
)

In [None]:
# Fetch one sample and dump the graph object followed by its main attributes.
elem = puzzle_dt[0]

print(
    elem,
    f"X: {elem.x}",
    f"EDGE_INDEX: {elem.edge_index}",
    f"INDEXES: {elem.indexes}",
    f"ROT: {elem.rot}",
    f"ROT_INDEX: {elem.rot_index}",
    f"IND_NAME: {elem.ind_name}",
    sep="\n",
)

In [None]:
# Show the untouched source image behind sample `idx` (taken from the wrapped
# base dataset, so no patching or rotation is applied).
idx = 0

plt.imshow(X=puzzle_dt.dataset[idx])
plt.axis("off")

In [None]:
from torchvision.utils import make_grid

# Draw the sample's patches as a 6-column grid. Each access to puzzle_dt[idx]
# draws fresh random rotations.
graph = puzzle_dt[idx]

# make_grid returns CHW; matplotlib wants HWC.
grid = make_grid(graph.patches, nrow=6, padding=2).permute(1, 2, 0)

plt.figure(figsize=(12, 12))
plt.imshow(X=grid)
plt.axis("off")
plt.show()

In [None]:
from torchvision.utils import make_grid

# Same image again: puzzle_dt[idx] re-samples the rotations on every access,
# so this grid generally differs from the previous one.
graph = puzzle_dt[idx]

# make_grid returns CHW; matplotlib wants HWC.
grid = make_grid(graph.patches, nrow=6, padding=2).permute(1, 2, 0)

plt.figure(figsize=(12, 12))
plt.imshow(X=grid)
plt.axis("off")
plt.show()

# Backbone model

If we are in Colab, let's import the backbone model directly from our repository.
The important elements of the imported code are:
- Eff_GAT class. It's the entire nn.Module that will be used to train the model.
- ResNet18 class. It is the inner module that Eff_GAT uses when it is instantiated with a "resnet18equiv" model.

If we are not in Colab, we'll work with the src/model folder of your current repository.

In [None]:
%%bash
# Fetch the backbone `model` package from the project repository at a pinned
# commit and place it in ./model (Colab only).

# Abort on the first failing step so a failed clone/checkout never leaves a
# wrong or partial `model` folder behind.
set -e

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

REPO_DIR="deep-learning-puzzle-project"
MODEL_DIR="model"

# -f: succeed silently on the first run, when nothing exists yet. A clone
# directory left over from an interrupted run would make `git clone` fail.
rm -rf "${MODEL_DIR}" "${REPO_DIR}"

git clone https://github.com/silviasuhu/deep-learning-puzzle-project.git
git -C "${REPO_DIR}" checkout b586dd709a8ce46465aa0284bd8d5eac812a8c94

mv "${REPO_DIR}/src/model" "${MODEL_DIR}"
rm -rf "${REPO_DIR}"

In [None]:
# Make the `model` package importable: in Colab it sits in the current
# directory ('/content'), otherwise under '../src'.
import sys

_model_parent_dir = '/content' if IN_COLAB else '../src'

# Only append once, so re-running the cell does not pile up duplicates.
if _model_parent_dir not in sys.path:
  sys.path.append(_model_parent_dir)

# Training

Let's inspect first the Dataloader

In [None]:
dataset = Puzzle_Dataset_ROT(
    dataset=train_dt,
    patch_per_dim=[(6, 6)],
    augment=False,
    degree=-1,
    unique_graph=None,
    all_equivariant=True,
    random_dropout=False,
)

BATCH_SIZE = 10
dataloader = torch_geometric.loader.DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True
)

first_batch = next(iter(dataloader))

# Let's compare the dataset structure with the dataloader batch structure
print(dataset[0])
print(first_batch)

# As you'll see, the per-node attributes of the batch are the per-sample ones
# stacked for batch_size graphs.

# x contains one row per patch: its (x, y) grid coordinate in [-1, 1] followed by its 2-d rotation vector.
# edge_index contains the graph connectivity (here every patch linked to every patch).
# indexes contains the position of each patch inside its own image (0 .. patches-1).
# rot contains the rotation of each patch as a 2-d vector: (1,0), (0,1), (-1,0) or (0,-1).
# rot_index contains the rotation of each patch as a class index: 0, 1, 2 or 3 (times 90 degrees).
# patches contains the image patches rotated 0,90,180 or 270 degrees; with all_equivariant=True
#   each patch additionally comes with its four 90-degree rotations.
# ind_name contains the dataset index of the image each graph was built from.
# patches_dim contains the patch grid size as (rows, columns).
# batch contains, for every node, the number of the graph it belongs to within the batch.


Let's start defining a function that will run at every training iteration:

In [None]:
# Compute on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def extract(a, t):
    """Look up ``a`` at the indices ``t`` (along the last dimension) and return
    the values with an extra axis inserted at position 1."""
    return a.gather(-1, t).unsqueeze(1)


def linear_beta_schedule(timesteps):
    """Return ``timesteps`` noise variances evenly spaced from 1e-4 to 0.02."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(start=first_beta, end=last_beta, steps=timesteps)


class GNN_Diffusion:
    """Forward diffusion process plus a single optimisation step.

    Args:
      steps: number of diffusion timesteps.
      device: device on which the schedule tensors and timesteps are created.
    """

    def __init__(self, steps, device):
        self.steps = steps
        self.device = device

        # Cumulative signal-retention factor of the linear beta schedule,
        # kept on `device`.
        alphas_cumprod = torch.cumprod(
            1.0 - linear_beta_schedule(steps).to(device), dim=0
        )

        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t`` (one timestep per row)."""
        if noise is None:
            noise = torch.randn_like(x_start)

        signal_scale = extract(self.sqrt_alphas_cumprod, t)
        noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t)

        return signal_scale * x_start + noise_scale * noise

    def training_step(self, batch, model, criterion, optimizer):
        """Run one optimisation step on ``batch`` and return the loss value."""
        optimizer.zero_grad()

        # Number of graphs in the batch, from the node-to-graph assignment.
        num_graphs = batch.batch.max().item() + 1

        # One timestep per graph (on the same device), then one per node.
        graph_t = torch.randint(
            0, self.steps, (num_graphs,), device=self.device
        ).long()
        t = torch.gather(graph_t, 0, batch.batch)

        clean_x = batch.x
        noise = torch.randn_like(clean_x)
        noisy_x = self.q_sample(clean_x, t, noise)

        patch_feats = model.visual_features(batch.patches)
        prediction, _ = model.forward_with_feats(
            noisy_x, t, batch.patches, batch.edge_index, patch_feats, batch.batch
        )

        # The model is trained to predict the injected noise.
        loss = criterion(noise, prediction)
        loss.backward()
        optimizer.step()

        return loss.item()


In [None]:
from model.efficient_gat import Eff_GAT
from transformers.optimization import Adafactor
from torch.utils.data import random_split
from pathlib import Path
import wandb
import os


def _validation_loss(model, val_loader, gnn_diffusion, criterion, steps, device):
    """Return the mean noise-prediction loss over ``val_loader`` (nan if empty)."""
    model.eval()
    val_losses = []

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            batch_size_graphs = batch.batch.max().item() + 1
            # One timestep per graph, expanded so each node gets its graph's timestep.
            t = torch.randint(0, steps, (batch_size_graphs, ), device=device).long()
            t = torch.gather(t, 0, batch.batch)
            # Clean node features and their noised version.
            x_start = batch.x
            noise = torch.randn_like(x_start)
            x_noisy = gnn_diffusion.q_sample(x_start=x_start, t=t, noise=noise)
            # CNN features from the image patches, then noise prediction.
            patch_feats = model.visual_features(batch.patches)
            prediction, _ = model.forward_with_feats(
                x_noisy, t, batch.patches, batch.edge_index, patch_feats,
                batch.batch)
            val_losses.append(criterion(noise, prediction).item())

    return float(np.mean(val_losses)) if val_losses else float("nan")


def train_model(batch_size,
                steps,
                epochs,
                patch_per_dim=[(6, 6)],
                use_scheduler=False,
                learning_rate=1e-3):
    """Train Eff_GAT as a diffusion denoiser on rotated puzzle graphs.

    Uses the notebook-level ``train_dt`` image dataset, holds out 10% of it for
    validation, logs to Weights & Biases and checkpoints into ``checkpoints/``
    (``latest.pt`` every epoch, a snapshot every 5 epochs and at the end).
    Training resumes from ``checkpoints/latest.pt`` when that file exists.

    Args:
      batch_size: number of graphs per batch.
      steps: number of diffusion timesteps.
      epochs: total number of epochs (including already completed ones when resuming).
      patch_per_dim: list of (rows, columns) patch grids.
      use_scheduler: decay the learning rate by 0.9 every 10 epochs (StepLR).
      learning_rate: explicit learning rate, used only with ``use_scheduler``.
    """

    # Start a new wandb run to track this script.
    run = wandb.init(
        entity="postgraduate-project-puzzle-upc",
        project="training-10-Feb-01",
        # Track hyperparameters and run metadata.
        config={
            "batch_size": batch_size,
            "steps": steps,
            "epochs": epochs,
            "patch_per_dim": patch_per_dim,
            "model": "Eff_gat",
            "optimizer": "Adafactor",
            "loss": "smooth_l1",
            "use_scheduler": use_scheduler,
        })

    # Reported to wandb on exit; only set to 0 once training completed.
    exit_code = 1
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 4 input/output channels: node features are (x, y) plus the 2-d rotation vector.
        model = Eff_GAT(steps=steps,
                        input_channels=4,
                        output_channels=4,
                        n_layers=4,
                        model="resnet18equiv")
        model.to(device)

        # track gradient norms, parameters updates, detect exploding gradients early
        wandb.watch(model, log="gradients", log_freq=100)

        criterion = torch.nn.functional.smooth_l1_loss

        scheduler = None
        if use_scheduler:
            # With its defaults Adafactor derives the step size itself and keeps
            # lr=None, which an external scheduler cannot scale. Switch to an
            # explicit learning rate so StepLR has a number to decay.
            optimizer = Adafactor(model.parameters(),
                                  lr=learning_rate,
                                  scale_parameter=False,
                                  relative_step=False,
                                  warmup_init=False)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=10,
                                                        gamma=0.9)
        else:
            optimizer = Adafactor(model.parameters())

        # `train_dt` is the image dataset created in an earlier notebook cell.
        full_dataset = Puzzle_Dataset_ROT(dataset=train_dt,
                                          patch_per_dim=patch_per_dim,
                                          augment=False,
                                          degree=-1,
                                          unique_graph=None,
                                          all_equivariant=False,
                                          random_dropout=False)

        # Split into training and validation (fixed seed, so the split is stable
        # across resumed runs).
        val_ratio = 0.1
        val_size = int(len(full_dataset) * val_ratio)
        train_size = len(full_dataset) - val_size

        train_dataset, val_dataset = random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42))

        train_loader = torch_geometric.loader.DataLoader(train_dataset,
                                                         batch_size=batch_size,
                                                         shuffle=True)
        val_loader = torch_geometric.loader.DataLoader(val_dataset,
                                                       batch_size=batch_size)

        gnn_diffusion = GNN_Diffusion(steps=steps, device=device)

        checkpoint_dir = Path("checkpoints")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        latest_ckpt = checkpoint_dir / "latest.pt"
        start_epoch = 0

        # Resume from checkpoint if exists
        if latest_ckpt.exists():
            print(f"Resuming from checkpoint: {latest_ckpt}")
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints written by this notebook. It is kept so
            # that existing checkpoints containing numpy scalars still load.
            checkpoint = torch.load(latest_ckpt,
                                    map_location=device,
                                    weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            if scheduler is not None and "scheduler_state_dict" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            start_epoch = checkpoint["epoch"]
            print(f"Resuming at epoch {start_epoch + 1}")

        for epoch in range(start_epoch, epochs):

            # TRAINING
            model.train()
            train_losses = []
            for batch in train_loader:
                batch = batch.to(device)
                loss = gnn_diffusion.training_step(batch, model, criterion,
                                                   optimizer)
                train_losses.append(loss)
            train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

            # VALIDATION
            val_loss = _validation_loss(model, val_loader, gnn_diffusion,
                                        criterion, steps, device)

            # -------- LOGGING --------
            run.log({
                "epoch": epoch + 1,
                "train/loss": train_loss,
                "val/loss": val_loss
            })

            print(
                f"Epoch [{epoch+1}/{epochs}] Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f}"
            )

            # Step scheduler
            if scheduler is not None:
                scheduler.step()

            # Checkpoint save; losses are stored as plain Python floats.
            checkpoint = {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_loss": train_loss,
                "val_loss": val_loss
            }

            if scheduler is not None:
                checkpoint["scheduler_state_dict"] = scheduler.state_dict()

            # Save latest checkpoint (resume)
            torch.save(checkpoint, latest_ckpt)

            # Save every 5 epochs and the last one
            if (epoch + 1) % 5 == 0 or (epoch + 1) == epochs:
                snapshot_path = checkpoint_dir / f"epoch_{epoch+1}.pt"
                torch.save(checkpoint, snapshot_path)
                print(f"Saved checkpoint snapshot: {snapshot_path}")

        exit_code = 0
    finally:
        # Always close the W&B run, also when training fails or is interrupted.
        run.finish(exit_code=exit_code)


In [None]:
# Launch training on the default 6x6 grid, without a learning-rate scheduler.
train_model(
    batch_size=10,
    steps=2,
    epochs=100,
    patch_per_dim=[(6, 6)],
    use_scheduler=False,
)