# How to train your Global Workspace
Benjamin Devillers

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruflab/shimmer-tutorials/blob/main/simple-shapes-dataset-inference.ipynb)


In this notebook, we will see how to use `shimmer` to build and train a Global Workspace from scratch on the Simple Shapes Dataset. We train a model that can translate visual images of shapes from the [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) to their proto-language (attributes).

For this tutorial, we will need to install the [shimmer-ssd](https://github.com/ruflab/shimmer-ssd) package.

In [None]:
!pip install "git+https://github.com/ruflab/shimmer-ssd.git"

In [8]:
!pip install tensorboard

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


This package depends on [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset), so all of its commands (such as `shapesd`) are available once it is installed.

For instance, we can download the dataset directly with:

In [None]:
!shapesd download

Note that `shapesd download` automatically migrates the dataset so that it is correctly formatted. If you downloaded the dataset manually, use `shapesd migrate -p PATH_TO_DATASET` to migrate manually.

In [1]:
from collections.abc import Mapping
from pathlib import Path
from typing import Any, cast

import matplotlib
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Callback, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from shimmer import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace2Domains
from shimmer.modules.vae import VAE, gaussian_nll, kl_divergence_loss
from shimmer_ssd import DEBUG_MODE, LOGGER, PROJECT_DIR
from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig, load_config
from shimmer_ssd.dataset.pre_process import TokenizeCaptions
from shimmer_ssd.logging import LogVisualCallback, batch_to_device
from shimmer_ssd.modules.domains import load_pretrained_domains
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule
from shimmer_ssd.modules.vae import RAEDecoder, RAEEncoder
from tokenizers.implementations.byte_level_bpe import ByteLevelBPETokenizer
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.optimizer import Optimizer
from torchvision.utils import make_grid

from simple_shapes_dataset import SimpleShapesDataModule, get_default_domains

%matplotlib inline

## Config

Let's first generate the config folder for the rest of the scripts.
This will create a `config` folder with different yaml files used by the different scripts and in the notebook.

In [2]:
!ssd config create

Config folder already exists. Skipping.


The folder contains many files, but in this tutorial only `main.yaml` will interest us.

You can start by taking a look at the default values, which should be mostly set correctly for this tutorial. You can also try making some changes to see the outcome.

<div class="alert alert-info">
Anytime you make a change to the config, don't forget to reload it with the following cell!
</div>

In [3]:
# We don't use cli in the notebook, but consider using it in normal scripts.
config = load_config("./config", use_cli=False)

## Data format

The dataloader provides the data in a specific format:

```python
domain_group = {
    "domain": domain_data
}
batch = {
    frozenset(["domain"]): domain_group
}
```
* The **batch** is a dict that has frozensets of domains as keys, and a domain group as values.
* The **domain group** is a dict that has domains (string) as keys, and the domain data as values. The data samples of every domain in a domain group are matched. This means that for a domain group that has 2 domains d1 and d2, `domain_group["d1"][k]` is paired with `domain_group["d2"][k]` for all `k`.

This allows a batch to have several groups (of different domains) of paired data. For example, a batch with unpaired visual (domain "v"), unpaired attribute (domain "attr"), and paired visual and attribute will look like:
```python
batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}
```

This is useful to train the global workspace later. But this is also the format used to train the unimodal domains.
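
To make the format concrete, here is a small dependency-free sketch that builds such a batch and checks that the data inside each group is paired. Plain lists stand in for tensors, and `group_sizes` is a hypothetical helper written for this illustration, not part of `shimmer`.

```python
# Toy stand-ins for real tensors: plain lists keep the example dependency-free.
unpaired_visual_data = ["img_0", "img_1", "img_2"]
unpaired_attribute_data = ["attr_7", "attr_8"]
paired_attr_data = ["attr_3", "attr_4"]
paired_visual_data = ["img_3", "img_4"]

batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}


def group_sizes(batch):
    """Hypothetical helper: number of samples per domain group.

    Also checks that every domain in a group has the same number of
    samples, which is what "paired" means in this format.
    """
    sizes = {}
    for domains, group in batch.items():
        # A group must contain exactly the domains named by its key.
        assert set(group) == set(domains)
        lengths = {len(data) for data in group.values()}
        assert len(lengths) == 1, f"unpaired data in group {set(domains)}"
        sizes[domains] = lengths.pop()
    return sizes


sizes = group_sizes(batch)

# frozensets ignore ordering, so either spelling retrieves the paired group.
paired = batch[frozenset(["v", "attr"])]
pairs = list(zip(paired["attr"], paired["v"]))
```

Using `frozenset` keys means the lookup does not depend on the order in which the domains are listed.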

Note that because all the data is paired in validation and test steps, the dataloader only returns one domain group with all paired domains:
```python
val_batch = {"attr": paired_attr_data, "v": paired_v_data}
```
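
Because the training and validation layouts differ, code that reads a single domain has to index them differently. The sketch below shows one possible way to handle both; `get_domain_data` is a hypothetical helper for illustration only, and lists again stand in for tensors.

```python
def get_domain_data(batch, domain):
    """Hypothetical helper: return the unimodal data of `domain`.

    Works with the training layout {frozenset: {domain: data}} and with
    the validation/test layout {domain: data}.
    """
    key = frozenset([domain])
    if key in batch:  # training layout: look up the unimodal group first
        return batch[key][domain]
    return batch[domain]  # validation/test layout: a single paired group


train_batch = {
    frozenset(["v"]): {"v": ["img_0", "img_1"]},
    frozenset(["attr", "v"]): {"attr": ["attr_2"], "v": ["img_2"]},
}
val_batch = {"attr": ["attr_5"], "v": ["img_5"]}

train_v = get_domain_data(train_batch, "v")
val_v = get_domain_data(val_batch, "v")
```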

## Unimodal Domains

Instead of using the domains already defined in [`shimmer-ssd`](https://ruflab.github.io/shimmer-ssd/latest/), we will redefine them from scratch to learn how to make them.

We will train 2 domain modules:
* vision: a beta-VAE that will encode the shape images into a small latent representation vector;
* attribute: a beta-VAE that will encode the object category and other attributes into a latent representation.


### Vision domain

Let's start by defining the data module to train the vision module.

We will load the config with an extra argument that will load the `train_v.yaml` file. This file contains specific configuration related to training the visual side like learning rate, weight decay, and the number of steps to train for.

In [4]:
config = load_config("./config", use_cli=False, load_files=["train_v.yaml"])

Note that with `load_files=["train_v.yaml"]`, the `main.yaml` file is still loaded and you don't have to repeat it. If you repeat some config values, the settings set in `train_v.yaml` will override the values in `main.yaml`.
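
The override behaviour can be pictured as a dictionary merge in which later files win. The following is only an illustrative sketch: the keys and values are made up, and the recursive merge of nested sections is an assumption, not a description of how `load_config` is actually implemented.

```python
from copy import deepcopy


def merge_configs(base, override):
    """Return a new dict where values from `override` replace those in `base`.

    Nested dicts are merged key by key (an assumption made for this sketch).
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical contents standing in for `main.yaml` and `train_v.yaml`.
main_yaml = {"seed": 0, "training": {"batch_size": 32, "max_steps": 1000}}
train_v_yaml = {"training": {"batch_size": 64}}

config_sketch = merge_configs(main_yaml, train_v_yaml)
```

Here `batch_size` comes from the second file, while `seed` and `max_steps` keep the values from the first.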

In [5]:
seed_everything(config.seed, workers=True)

data_module = SimpleShapesDataModule(
    config.dataset.path,
    get_default_domains(["v"]),
    {frozenset(["v"]): 1.0},
    batch_size=config.training.batch_size,  # set in `config/train_v.yaml`
    max_train_size=config.dataset.max_train_size,
    num_workers=config.training.num_workers,
)

Seed set to 0


All domain modules must inherit from the class `DomainModule` from `shimmer`. This class requires the creation of several methods.
`DomainModule`s inherit from a `LightningModule`, so you can also define the relevant methods to train it.

In particular, you will need to override the following methods:
* `def encode(self, x: Any) -> torch.Tensor`: encodes the raw domain data into the unimodal latent representation
* `def decode(self, z: torch.Tensor) -> Any`: decodes the unimodal latent representation into the raw input data
* `def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput | None`: how to compute the loss used for cycle consistency or translation of this domain. `pred` is the predicted unimodal latent vector, `target` is the target latent vector, and `raw_target` is the original input before being encoded.

You can see more details in [the docs](https://ruflab.github.io/shimmer/latest/shimmer/modules/domain.html#DomainModule).

In addition, we will also define `configure_optimizers`, `training_step`, and `validation_step`, to define the optimizer, and what happens during the training and validation steps. See the [lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) for more details.

In [6]:
class VisualDomainModule(DomainModule):
    def __init__(self, latent_dim: int, ae_dim: int, beta: float = 1.0, total_steps: int = 1):
        """
        Visual domain module. This defines shimmer's `DomainModule` for the vision
        side with a VAE.

        Args:
            latent_dim (`int`): latent dimension of the vision domain
            ae_dim (`int`): internal auto-encoder dimension of the VAE
            beta (`float`): beta value if beta-VAE. (Defaults to 1.0)
            total_steps (`int`): total number of training steps. This is useful for the OneCycle Scheduler.
        """

        # The parent class requires to know the latent dimension of the module.
        super().__init__(latent_dim)
        self.save_hyperparameters()

        num_channels = 3  # RGB channels
        vae_encoder = RAEEncoder(num_channels, ae_dim, latent_dim, use_batchnorm=True)
        vae_decoder = RAEDecoder(num_channels, latent_dim, ae_dim)

        # This is a useful helper to make a VAE. You can learn more on this here if interested:
        # https://ruflab.github.io/shimmer-ssd/latest/shimmer_ssd/modules/vae.html
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.total_steps = total_steps

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode from the image to the latent representation.
        Here we can just use the encode function from the VAE.
        """
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode the unimodal latent into the original domain.
        We can use the decode function from the VAE.
        """
        return self.vae.decode(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput:
        """
        Compute MSE loss in the latent domain. This will be useful when training the global workspace.
        This defines the loss to use for the cycle consistency losses and the translation losses.
        """
        loss = mse_loss(pred, target, reduction="mean")
        return LossOutput(loss)

    # Pytorch Lightning related methods

    def training_step(self, batch: Mapping[frozenset[str], Mapping[str, torch.Tensor]], batch_idx: int) -> torch.Tensor:
        """
        Args:
            batch (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`): batch of domain groups. When training the visual domain,
                this will only contain one group with unpaired visual data
            batch_idx (`int`): batch index
        Returns:
            `torch.Tensor`: the total loss of the step
        """
        x = batch[frozenset(["v"])]["v"]  # extracts the visual information
        return self.generic_step(x, "train")  # use a generic step method that handles both train and validation.

    def validation_step(self, batch: Mapping[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Validation step

        Args:
            batch (`Mapping[str, torch.Tensor]`): group with only paired data. The validation step does not receive several groups.
            batch_idx (`int`): batch index

        Returns:
            `torch.Tensor`: the total loss
        """
        x = batch["v"]  # extract the visual data of the group
        return self.generic_step(x, "val")  # use `generic_step` method to compute the val loss.

    def generic_step(self, x: torch.Tensor, mode: str = "train") -> torch.Tensor:
        """
        Computes the loss given image data

        Args:
            x (`torch.Tensor`): tensors of images
            mode (`str`): mode of step ("train", "val") used for logging
        Returns:
            `torch.Tensor`: The computed loss
        """
        # Get the latent mean, log-variance, and reconstructed images
        # from the VAE
        (mean, logvar), reconstruction = self.vae(x)

        reconstruction_loss = gaussian_nll(reconstruction, torch.tensor(0), x).sum()
        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        # Log losses
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)

        return total_loss

    def configure_optimizers(self) -> dict[str, Any]:
        """
        Defines optimizer and learning rate scheduler
        """

        learning_rate = 1e-3
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=1e-6)
        lr_scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=self.total_steps)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",  # OneCycleLR requires update every step and not every epoch
            },
        }
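
The loss in `generic_step` is a reconstruction term plus a KL term weighted by `beta`. As a rough sketch of what the KL term measures, here is the standard closed-form KL divergence between a diagonal Gaussian and a standard normal, in plain Python. This assumes the usual beta-VAE formulation; shimmer's `kl_divergence_loss` may reduce over the batch differently, and both function names below are hypothetical.

```python
import math


def kl_to_standard_normal(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m**2 - math.exp(lv) for m, lv in zip(mean, logvar))


def beta_vae_loss(reconstruction_loss, mean, logvar, beta=1.0):
    """Reconstruction term plus the beta-weighted KL term."""
    return reconstruction_loss + beta * kl_to_standard_normal(mean, logvar)


# When the posterior equals the prior, the KL term vanishes.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Shifting one mean by 1 (unit variance) costs 0.5 nats.
kl_shifted = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

# A larger beta puts more weight on staying close to the prior.
total = beta_vae_loss(2.0, [1.0, 0.0], [0.0, 0.0], beta=4.0)
```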

Now let's instantiate the domain module:

In [7]:
v_domain_module = VisualDomainModule(
    config.domain_modules.visual.latent_dim,
    config.domain_modules.visual.ae_dim,
    config.domain_modules.visual.beta,
    config.training.max_steps,
)

In [8]:
logger = TensorBoardLogger("logs", name="vision_model")

val_samples = data_module.get_samples("val", 32)[frozenset(["v"])]["v"]
train_samples = data_module.get_samples("train", 32)[frozenset(["v"])]["v"]

callbacks: list[Callback] = [
    LogVisualCallback(
        val_samples,
        log_key="val_images",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        ncols=8,
    ),
    LogVisualCallback(
        train_samples,
        log_key="train_images",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        ncols=8,
    ),
    ModelCheckpoint(
        dirpath=config.default_root_dir / f"{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_top_k=1,
    ),
]

In [16]:
print(f"./logs/vision_model/version_{logger.version}")

./logs/vision_model/version_3


In [9]:
%load_ext tensorboard

In [12]:
%tensorboard --logdir "./logs/vision_model/version_2" --reload_interval 30 --reload_task 'auto'

Reusing TensorBoard on port 6006 (pid 3428984), started 0:00:55 ago. (Use '!kill 3428984' to kill it.)

In [9]:
trainer = Trainer(
    logger=logger,
    fast_dev_run=config.training.fast_dev_run,
    max_steps=config.training.max_steps,
    enable_progress_bar=config.training.enable_progress_bar,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(v_domain_module, data_module)
trainer.validate(v_domain_module, data_module, "best")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type | Params | Mode 
--------------------------------------
0 | vae  | VAE  | 1.5 M  | train
--------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.990     Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode



In [3]:
!shapesd alignment add --dataset_path "simple_shapes_dataset" --seed 0 --domain_alignment attr 1.0 --domain_alignment v 1.0 --domain_alignment attr,v 0.5 --ms 500000

In [4]:
checkpoint_path = Path("/home/bdevillers/projects/shimmer-ssd/checkpoints/pretrained/")

In [5]:
seed = 0
dataset_path = Path("simple_shapes_dataset")

seed_everything(seed, workers=True)

# Let's define the proportion of each group of domains.
# Here, we say that we will use 100% of the unimodal attribute samples,
# 100% of the unimodal visual samples,
# and 50% of the paired attribute-visual samples.
domain_proportion = {
    frozenset(["attr"]): 1.0,
    frozenset(["v"]): 1.0,
    frozenset(["attr", "v"]): 0.5,
}

# Then we get the data domains that will load the data for our two domain
# formats:
# - `attr` will load the attributes of the shapes.
# - `v_latents` will load pre-saved latent representations of the visual VAE.
# You could also use `v` to load the images directly. Using `v_latents` will provide
# faster training as we can skip the VAE encoding.
domain_classes = get_default_domains(["v_latents", "attr"])

# here are some additional arguments used by the data domain classes
domain_data_args = {
    # where to load the pre-saved latent representations from
    "v_latents": {"presaved_path": "vae_v_shimmer.npy"},
}

batch_size = 1024
num_workers = 16

# initialize the lightning data module
data_module = SimpleShapesDataModule(
    dataset_path,
    domain_classes,
    domain_proportion,
    batch_size,
    num_workers=num_workers,
    seed=seed,
    domain_args=domain_data_args,
    max_train_size=500_000,
)

Seed set to 0
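
To get a feel for what these proportions mean in terms of sample counts, here is a back-of-the-envelope sketch. It assumes that each proportion is a fraction of `max_train_size`, which is an interpretation rather than something verified against the data module; `samples_per_group` is a hypothetical helper.

```python
def samples_per_group(proportions, max_train_size):
    """Approximate number of training samples available to each domain group,
    assuming each proportion is a fraction of `max_train_size`."""
    return {
        group: int(round(proportion * max_train_size))
        for group, proportion in proportions.items()
    }


proportions = {
    frozenset(["attr"]): 1.0,
    frozenset(["v"]): 1.0,
    frozenset(["attr", "v"]): 0.5,
}

counts = samples_per_group(proportions, max_train_size=500_000)
n_paired = counts[frozenset(["attr", "v"])]
```

Under that assumption, only half of the training set contributes paired samples, while every sample is available to the unimodal groups.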


In [6]:
# We need to load pretrained unimodal domains.
# Here we define where the domain checkpoints are stored.
domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
    ),
]

global_workspace_dim = 16
encoders_hidden_dim = 32
encoders_n_layers = 3
decoders_hidden_dim = 32
decoders_n_layers = 3

# we load the pretrained domain modules and define the associated GW encoders and
# decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    domains,
    global_workspace_dim,
    encoders_hidden_dim,
    encoders_n_layers,
    decoders_hidden_dim,
    decoders_n_layers,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
loss_coefficients = {
    "demi_cycles": 1.0,
    "cycles": 1.0,
    "translations": 1.0,
    "contrastives": 0.01,
}

# Training options
learning_rate = 3e-3
weight_decay = 1e-6
max_steps = 1000


def get_scheduler(optimizer: Optimizer) -> OneCycleLR:
    return OneCycleLR(optimizer, learning_rate, max_steps)


global_workspace = GlobalWorkspace2Domains(
    domain_modules,
    gw_encoders,
    gw_decoders,
    global_workspace_dim,
    loss_coefficients,
    learning_rate,
    weight_decay,
    scheduler=get_scheduler,
)
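
The `loss_coefficients` dict weights the different loss terms relative to each other. A minimal sketch of the idea, assuming the terms are combined as a weighted sum (shimmer may additionally average or normalise them), with made-up loss values:

```python
def weighted_total_loss(losses, coefficients):
    """Combine named loss terms as a weighted sum (an assumption of this sketch)."""
    return sum(coefficients[name] * value for name, value in losses.items())


coefficients = {
    "demi_cycles": 1.0,
    "cycles": 1.0,
    "translations": 1.0,
    "contrastives": 0.01,
}

# Hypothetical loss values for a single training step.
step_losses = {
    "demi_cycles": 0.2,
    "cycles": 0.3,
    "translations": 0.5,
    "contrastives": 2.0,
}

total = weighted_total_loss(step_losses, coefficients)
```

With a coefficient of 0.01, the contrastive term contributes only 0.02 to the total here, even though its raw value is the largest.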

In [8]:
trainer = Trainer(
    max_steps=max_steps,
    default_root_dir="checkpoints",
    accelerator="gpu",
    devices="1",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [9]:
trainer.fit(global_workspace, data_module)

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                  | Params | Mode 
----------------------------------------------------------------
0 | gw_mod        | GWModule              | 5.9 M  | train
1 | selection_mod | SingleDomainSelection | 0      | train
2 | loss_mod      | GWLosses2Domains      | 5.9 M  | train
----------------------------------------------------------------
16.4 K    Trainable params
5.8 M     Non-trainable params
5.9 M     Total params
23.430    Total estimated model params size (MB)
47        Modules in train mode
57        Modules in eval mode


/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=123` in the `DataLoader` to improve performance.
/home/bdevillers/.cache/pypoetry/virtualenvs/shimmer-tutorials-RRvxJ_Ue-py3.11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=123` in the `DataLoader` to improve performance.

