In [1]:
import os
from pathlib import Path
from typing import cast

import matplotlib
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Trainer, seed_everything
from shimmer.modules.global_workspace import GlobalWorkspace2Domains
from shimmer_ssd import DEBUG_MODE, PROJECT_DIR
from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig, load_config
from shimmer_ssd.dataset.pre_process import TokenizeCaptions
from shimmer_ssd.logging import batch_to_device
from shimmer_ssd.modules.domains import load_pretrained_domains
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule
from tokenizers.implementations.byte_level_bpe import ByteLevelBPETokenizer
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.optimizer import Optimizer
from torchvision.utils import make_grid

from simple_shapes_dataset import SimpleShapesDataModule, get_default_domains

# Use the ipympl backend so that figures are rendered as interactive widgets.
matplotlib.use("ipympl")

We will first download the dataset from Zenodo if it does not already exist:

In [2]:
# Fetch the dataset with the `shapesd` CLI; when a copy is already present the
# command skips the download unless `--force` is given.
!shapesd download

Dataset already exists. Skipping download. Use `--force` to download anyway.


In [3]:
# Add a domain alignment to the dataset with the `shapesd` CLI.
# NOTE(review): the seed (0), the dataset path, the per-group proportions
# (attr 1.0, v 1.0, attr,v 0.5) and `--ms 500000` repeat the values of `seed`,
# `dataset_path`, `domain_proportion` and `max_train_size` used when the data
# module is built further down; presumably they have to be kept in sync — confirm.
!shapesd alignment add --dataset_path "simple_shapes_dataset" --seed 0 --domain_alignment attr 1.0 --domain_alignment v 1.0 --domain_alignment attr,v 0.5 --ms 500000

In [4]:
# Directory holding the pretrained unimodal checkpoints (`domain_v.ckpt`,
# `domain_attr.ckpt`). Set the SHIMMER_CHECKPOINT_PATH environment variable to
# point at your own copy; the fallback is the original author's local path and
# will not exist on other machines.
checkpoint_path = Path(
    os.environ.get(
        "SHIMMER_CHECKPOINT_PATH",
        "/home/bdevillers/projects/shimmer-ssd/checkpoints/pretrained/",
    )
)

In [5]:
# Seed shared by `seed_everything` and the data module.
seed = 0
# Dataset directory, given relative to the current working directory.
dataset_path = Path("simple_shapes_dataset")

seed_everything(seed, workers=True)

# Let's define the proportion of each group of domains.
# Here, we say that we will use 100% of the unimodal attribute (`attr`) samples,
# 100% of the unimodal visual (`v`) samples,
# and 50% of the paired attribute-visual samples.
domain_proportion = {
    frozenset(["attr"]): 1.0,
    frozenset(["v"]): 1.0,
    frozenset(["attr", "v"]): 0.5,
}

# Then we get the data domains that will load the data for our two different domain
# formats:
# - `attr` will load the attribute domain (presumably the attribute description of
#   each shape — TODO confirm),
# - `v_latents` will load pre-saved latent representations of the visual VAE.
# You could also use `v` to load the images directly. Using `v_latents` will provide
# faster training as we can skip the VAE encoding.
# NOTE(review): `domain_proportion` above is keyed by `v` while the domain requested
# here is `v_latents`; presumably `v_latents` counts as a format of the `v` domain —
# confirm.
domain_classes = get_default_domains(["v_latents", "attr"])

# here are some additional arguments used by the data domain classes
domain_data_args = {
    # where to load the pre-saved latent representations
    "v_latents": {"presaved_path": "vae_v_shimmer.npy"},
}

batch_size = 1024
num_workers = 16

# initialize the lightning data module
data_module = SimpleShapesDataModule(
    dataset_path,
    domain_classes,
    domain_proportion,
    batch_size,
    num_workers=num_workers,
    seed=seed,
    domain_args=domain_data_args,
    max_train_size=500_000,
)

Seed set to 0


In [6]:
# We need to load pretrained unimodal domains.
# Here we define where the checkpoint of each domain is stored.
domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
    ),
]

# Dimension of the global workspace, followed by the hidden size and number of
# layers handed to `load_pretrained_domains` for the GW encoders and decoders.
global_workspace_dim = 16
encoders_hidden_dim = 32
encoders_n_layers = 3
decoders_hidden_dim = 32
decoders_n_layers = 3

# we load the pretrained domain modules and define the associated GW encoders and
# decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    domains,
    global_workspace_dim,
    encoders_hidden_dim,
    encoders_n_layers,
    decoders_hidden_dim,
    decoders_n_layers,
)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Weight of each loss term, handed to `GlobalWorkspace2Domains` when the global
# workspace is built. The contrastive term is weighted 100 times lower than the
# three others.
loss_coefficients = {
    "demi_cycles": 1.0,
    "cycles": 1.0,
    "translations": 1.0,
    "contrastives": 0.01,
}

# Optimisation hyper-parameters
learning_rate = 3e-3
weight_decay = 1e-6
max_steps = 1000


def get_scheduler(optimizer: Optimizer) -> OneCycleLR:
    """Build the one-cycle learning-rate schedule attached to `optimizer`.

    The schedule peaks at `learning_rate` and spans `max_steps` steps.
    """
    one_cycle = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=max_steps)
    return one_cycle


# Assemble the two-domain global workspace from the pretrained domain modules and
# their GW encoders/decoders. `scheduler` receives the `get_scheduler` factory
# itself (not a scheduler instance).
# NOTE(review): `learning_rate` and `weight_decay` are passed positionally;
# presumably they configure the optimizer — confirm against the signature.
global_workspace = GlobalWorkspace2Domains(
    domain_modules,
    gw_encoders,
    gw_decoders,
    global_workspace_dim,
    loss_coefficients,
    learning_rate,
    weight_decay,
    scheduler=get_scheduler,
)

In [8]:
# Train for `max_steps` optimisation steps; logs and checkpoints go under
# "checkpoints/".
# `accelerator="auto"` lets Lightning use the GPU when one is available and fall
# back to the CPU otherwise, in line with the `device` selection made earlier in
# the notebook (a hard-coded "gpu" fails on CPU-only machines).
trainer = Trainer(
    max_steps=max_steps,
    default_root_dir="checkpoints",
    accelerator="auto",
    devices=1,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [9]:
# Run the training (and validation) loops of the global workspace on the data module.
trainer.fit(global_workspace, data_module)

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                  | Params | Mode 
----------------------------------------------------------------
0 | gw_mod        | GWModule              | 5.9 M  | train
1 | selection_mod | SingleDomainSelection | 0      | train
2 | loss_mod      | GWLosses2Domains      | 5.9 M  | train
----------------------------------------------------------------
16.4 K    Trainable params
5.8 M     Non-trainable params
5.9 M     Total params
23.430    Total estimated model params size (MB)
47        Modules in train mode
57        Modules in eval mode


Sanity Checking: |                                                                                            …

/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=123` in the `DataLoader` to improve performance.
/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=123` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined