# Setup variables

Define the main variables

In [None]:
# Set this to True if this notebook runs in Colab and GPU is enabled.
# This variable will be ignored if you're not in Colab.
GPU_IS_ENABLED_IN_COLAB = True

# Enable/Disable Wandb: True disables it, False enables it.
# When enabled, you'll need to log in to Wandb through a terminal before running this notebook.
# To log in to Wandb, run the command 'wandb login' in your terminal.
WANDB_DISABLED = True

In [None]:
import os

# %%bash cells cannot read Python variables, so mirror the settings into
# environment variables (as strings) where the shell scripts can see them.
os.environ.update(
    {
        "GPU_IS_ENABLED_IN_COLAB": "1" if GPU_IS_ENABLED_IN_COLAB else "0",
        "WANDB_DISABLED": "true" if WANDB_DISABLED else "false",
    }
)

In [None]:
import os
import sys

# The notebook is considered to run in Colab when the `google.colab` module
# has been loaded into this interpreter.
IN_COLAB = "google.colab" in sys.modules
print(IN_COLAB)

# Expose the flag to the %%bash cells as "1" (Colab) or "0" (anything else).
os.environ["IN_COLAB"] = str(int(IN_COLAB))

# Installation

Verify your python version.
Note that this notebook has been tested with python version 3.12.3

In [None]:
# Print the version of the `python` executable that the shell resolves.
# NOTE(review): this may differ from the interpreter running this kernel — confirm if it matters.
!python --version

⚠️ Install the required packages only if this notebook runs in Colab. Otherwise, you should install the required packages manually in your local Python environment.

Install version 2.6.0 of torch so that compatible pytorch-geometric packages can be installed later.

In [None]:
if IN_COLAB:
  !pip uninstall -y torch torchvision torchaudio
  if GPU_IS_ENABLED_IN_COLAB:
    !pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
  else:
    !pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cpu

In [None]:
import torch

if IN_COLAB:
  torch_version = torch.__version__.split('+')[0]
  if GPU_IS_ENABLED_IN_COLAB:
    cuda_version = torch.version.cuda.replace('.', '')
    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-{torch_version}+cu{cuda_version}.html

  else:
    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-{torch_version}+cpu.html

  !pip install -q torch-geometric

In [None]:
!pip install wandb

## Wandb login

In [None]:
# Run the Wandb login command only when Wandb has not been disabled in the setup cell.
if not WANDB_DISABLED:
    !wandb login

## Inspect runtime default versions and settings

Check torch and torchvision default versions. For now we are just going to use them, we'll change them if we hit any conflict in the future.

In [None]:
import torch
import torchvision

# Report the installed versions and whether CUDA is usable, one item per line
# with a blank line before the CUDA status.
print(
    f"Torch version: {torch.__version__}",
    f"Torchvision version: {torchvision.__version__}",
    "",
    f"Torch cuda is available: {torch.cuda.is_available()}",
    sep="\n",
)

If cuda is not available, enable GPU in Colab by going to 'Runtime' > 'Change runtime type' > Select 'T4 GPU'.

This will restart the session and you'll need to rerun all the cells again. After restarting the session, verify that cuda is available.

### Nvidia version

The following command (nvidia-smi) will tell you which GPU you are using (if any).

In [None]:
# Ask the NVIDIA tooling which GPU (if any) is attached to this runtime.
!nvidia-smi

## Enable cuda if available

In [None]:
import torch

# Use the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Import our code

Import the classes of the deep-learning-puzzle-project repository

In [None]:
%%bash

# Clone a fresh copy of the project repository (Colab only).

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

REPO_DIR_NAME="deep-learning-puzzle-project"

# When this cell is re-run after the notebook has already moved into the clone,
# cloning again would create a nested copy inside the repository.
if [ "$(basename "$PWD")" == "${REPO_DIR_NAME}" ]; then
  echo "Already inside ${REPO_DIR_NAME}; skipping clone"
  exit 0
fi

# -f: stay silent when there is no previous clone to remove (e.g. on the first run).
rm -rf -- "${REPO_DIR_NAME}"

git clone https://github.com/silviasuhu/deep-learning-puzzle-project.git "${REPO_DIR_NAME}"


Move to the root directory of the repo to get consistency between Colab and local executions

In [None]:
import os
import sys

current_dir = os.path.basename(os.getcwd())
print(current_dir)

if current_dir == "deep-learning-puzzle-project":
    print("You are already on the root directory of the 'deep-learning-puzzle-project' repo.")

else:
  if IN_COLAB:
    %cd "deep-learning-puzzle-project"
  else:
    %cd ".."

## Main imports

In [None]:
import os
from pathlib import Path
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
import numpy as np
import einops

import torch
import torch_geometric
import torchvision

# Dataset download

Download the dataset only if this notebook runs in Colab, otherwise you'll need to download it manually.

In [None]:
import os

# Export the Colab flag ("1"/"0") so the %%bash download cells can read it.
os.environ.update(IN_COLAB="1" if IN_COLAB else "0")

In [None]:
import os

# Dataset location, relative to the repository root (assumed to be the
# current working directory at this point).
DATASET_PATH = "data/CelebA-HQ"

# Export the path so the %%bash cells can read it.
os.environ.update(DATASET_PATH=DATASET_PATH)

In [None]:
%%bash

# Download the CelebAMask-HQ archive and extract it into ${DATASET_PATH} (Colab only).

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

# Stop at the first failing command so a failed download or extraction is
# reported as an error instead of being followed by "Done".
set -e

# DATASET_PATH is exported by an earlier cell; abort with a clear message if it is missing.
: "${DATASET_PATH:?DATASET_PATH is not set; run the cell that defines it first}"

echo 'Downloading dataset...'

OUTPUT_FILENAME='dataset.zip'
FOLDER_NAME='CelebAMask-HQ'

mkdir -p "${DATASET_PATH}"

if [ -d "${DATASET_PATH}/${FOLDER_NAME}" ]; then
  echo "Skipping the download since the folder ${DATASET_PATH}/${FOLDER_NAME} already exists"
  exit 0
fi

# Remove leftovers of a previous, interrupted run; -f keeps this silent when there are none.
rm -f "${OUTPUT_FILENAME}"
rm -rf "${FOLDER_NAME}"

# TLS certificate verification is deliberately left enabled for this download.
wget 'https://huggingface.co/datasets/liusq/CelebAMask-HQ/resolve/main/CelebAMask-HQ.zip?download=true' -O "${OUTPUT_FILENAME}"
echo "${OUTPUT_FILENAME} downloaded. Unziping it..."
# -q: do not list every extracted file in the cell output.
unzip -q "${OUTPUT_FILENAME}"
rm "${OUTPUT_FILENAME}"

mv "${FOLDER_NAME}" "${DATASET_PATH}"

echo "Done"

Preview an image

In [None]:
# Open one image from the extracted dataset folder and display it with the axes hidden.
img = Image.open(DATASET_PATH + "/CelebAMask-HQ/CelebA-HQ-img/1000.jpg")
plt.imshow(img)
plt.axis("off")

In [None]:
ls

Download the txt files from the DiffAssemble repository that define the data split between training and testing

In [None]:
%%bash

# Fetch the train/test split lists from the DiffAssemble repository into ${DATASET_PATH} (Colab only).

if [ "$IN_COLAB" == "0" ]; then
  echo "Skipping download (IN_COLAB is false)"
  exit 0
fi

# wget runs with -q, so without fail-fast a failed download would go unnoticed.
set -e

# DATASET_PATH is exported by an earlier cell; abort with a clear message if it is missing.
: "${DATASET_PATH:?DATASET_PATH is not set; run the cell that defines it first}"

SPLITS_URL='https://raw.githubusercontent.com/IIT-PAVIS/DiffAssemble/refs/heads/release/datasets/data_splits'

mkdir -p "$DATASET_PATH"

for split_file in CelebA-HQ_test.txt CelebA-HQ_train.txt; do
  # Drop any stale local copy first; otherwise wget would save the new download under a ".1" suffix.
  rm -f "$split_file"
  wget -q "${SPLITS_URL}/${split_file}"
  mv "$split_file" "$DATASET_PATH"
done

ls "$DATASET_PATH"

In [None]:
# Let's tell this notebook that we may need to import python packages from the 'src' folder
import sys

# Guard the append so that re-running this cell does not add duplicate entries.
if "src" not in sys.path:
    sys.path.append("src")

In [None]:
from dataset_celeb_rot import CelebA_DataSet
from puzzle_dataset import Puzzle_Dataset_ROT

# Dataset testing

In [None]:
# Build the dataset with train=True and display its first item with the axes hidden.
# NOTE(review): plt.imshow needs an image-like item (PIL image or array) — confirm what CelebA_DataSet returns.
dataset = CelebA_DataSet(path=DATASET_PATH, train=True)
img = dataset[0]
plt.imshow(img)
plt.axis("off")

Inspect the output of Puzzle_Dataset_ROT

### Interesting points

- **Number of patches**. We'll see that the image has been split according to the value assigned to the 'patch_per_dim' parameter. For instance, if patch_per_dim is [(6,6)], we'll see 36 patches per image.

- **Rotation**.
TODO

In [None]:
# Base image dataset, built with train=True.
train_dt = CelebA_DataSet(DATASET_PATH, train=True)

# Puzzle view of the same images, configured with a 6x6 patch grid per image.
puzzle_dt = Puzzle_Dataset_ROT(
    dataset=train_dt,
    patch_per_dim=[(6, 6)],
    augment=False,
    degree=-1,
    unique_graph=None,
    all_equivariant=False,
    random_dropout=False,
)

In [None]:
elem = puzzle_dt[0]

# Dump the element followed by each of its fields, one per line.
print(
    elem,
    f"X: {elem.x}",
    f"EDGE_INDEX: {elem.edge_index}",
    f"INDEXES: {elem.indexes}",
    f"ROT: {elem.rot}",
    f"ROT_INDEX: {elem.rot_index}",
    f"IND_NAME: {elem.ind_name}",
    sep="\n",
)

In [None]:
# Show the original image for sample `idx`, read from the dataset wrapped by puzzle_dt.
# `idx` is reused by the next cell to fetch the matching puzzle graph.
idx = 0

plt.imshow(puzzle_dt.dataset[idx])
plt.axis("off")

In [None]:
from torchvision.utils import make_grid

# Fetch the puzzle graph for the sample index `idx` set in the previous cell.
graph=puzzle_dt[idx]

# Disabled alternative: pick a single rotation index out of the patches tensor.
# rotIdx=3
# patches = graph.patches[:, rotIdx]

# Tile the patches into a single image, 6 per row with 2 pixels of padding.
# NOTE(review): nrow=6 presumably mirrors patch_per_dim=[(6,6)] — keep the two in sync.
grid = make_grid(graph.patches, nrow=6, padding=2)

# Convert CHW -> HWC for matplotlib
grid = grid.permute(1, 2, 0)

plt.figure(figsize=(12, 12))
plt.imshow(grid)
plt.axis("off")
plt.show()

Let's inspect the DataLoader too.

In [None]:
# Rebuild the puzzle dataset, this time with all_equivariant=True, and wrap it in a DataLoader.
# Note that this rebinds the name `dataset`, which an earlier cell used for the CelebA_DataSet.
dataset = Puzzle_Dataset_ROT(dataset=train_dt, patch_per_dim=[(6,6)], augment=False, degree=-1, unique_graph=None, all_equivariant=True, random_dropout=False)

BATCH_SIZE=10
# shuffle=True: the first batch can differ between runs unless a seed is fixed elsewhere.
dataloader = torch_geometric.loader.DataLoader(
  dataset, batch_size=BATCH_SIZE, shuffle=True
)

first_batch = next(iter(dataloader))

# Let's compare the dataset structure with the dataloader batch structure
print(dataset[0])
print(first_batch)

# As you'll see, the first dimension of each parameter has been multiplied by the batch_size.

# Field-by-field description of the batch (TODO: complete the entries left as "contains..."):
# x contains...
# edge_index contains...
# indexes contains...
# rot contains...
# rot_index contains...
# patches contains the image patches rotated 0, 90, 180 or 270 degrees
# ind_name contains...
# patches_dim contains the number of patches in the x and in the y axis.
# batch contains...


# Training

In [None]:
%%bash

# Launch the training script with the hyper-parameters defined below.

STEPS=10
BATCH_SIZE=10
EPOCHS=100
PUZZLE_SIZES="6"

args=()

args+=("-batch_size=$BATCH_SIZE")
args+=("-steps=$STEPS")
args+=("-epochs=$EPOCHS")
args+=("-puzzle_sizes=$PUZZLE_SIZES")

# WANDB_DISABLED is exported by the setup cell as "true" or "false".
if [ "$WANDB_DISABLED" == "true" ]; then
    args+=("-wandb_disabled")
fi

echo "ARGS: ${args[*]}"

# Fall back to the current directory when REPO_DIR is unset; an earlier cell
# already moves the notebook to the repository root. A plain `cd` is enough
# because this shell ends with the cell, and leaving `python` as the last
# command makes the cell report the training script's own exit status.
cd "${REPO_DIR:-.}"
python src/train_script.py "${args[@]}"