# How to train your Global Workspace
Benjamin Devillers

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruflab/shimmer-tutorials/blob/main/simple-shapes-dataset-training.ipynb)


In this notebook, we will see how to use `shimmer` to build and train a Global Workspace from scratch on the Simple Shapes Dataset. We train a model that can translate visual images of shapes from the [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) to their proto-language (attributes).

For this tutorial, we will need to install the [shimmer-ssd](https://github.com/ruflab/shimmer-ssd) package.

In [None]:
!pip install "git+https://github.com/ruflab/shimmer-ssd.git"

In [8]:
!pip install tensorboard



This package depends on [simple-shapes-dataset](https://github.com/ruflab/simple-shapes-dataset) and exposes all of its commands, so you can use them directly.

For instance, we can download the dataset directly with:

In [None]:
!shapesd download

Note that `shapesd download` automatically migrates the dataset so that it is correctly formatted. If you downloaded the dataset manually, use `shapesd migrate -p PATH_TO_DATASET` to migrate manually.

In [1]:
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any, cast

import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from lightning.pytorch import Callback, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from shimmer import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace2Domains, SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
)
from shimmer_ssd import DEBUG_MODE, LOGGER, PROJECT_DIR
from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig, load_config
from shimmer_ssd.dataset.pre_process import TokenizeCaptions
from shimmer_ssd.logging import (
    LogAttributesCallback,
    LogGWImagesCallback,
    LogVisualCallback,
    batch_to_device,
)
from shimmer_ssd.modules.domains import load_pretrained_domains
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule
from shimmer_ssd.modules.vae import RAEDecoder, RAEEncoder
from tokenizers.implementations.byte_level_bpe import ByteLevelBPETokenizer
from torch import nn
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.optimizer import Optimizer
from torchvision.utils import make_grid

from simple_shapes_dataset import SimpleShapesDataModule, get_default_domains


%matplotlib inline

## Config

Let's first generate the config folder for the rest of the scripts.
This will create a `config` folder with different yaml files used by the different scripts and in the notebook.

In [2]:
!ssd config create

Config folder already exists. Skipping.


The `config` folder contains many files, but in this tutorial, only `main.yaml` will interest us.

You can start by looking at the default values, which should mostly be set correctly for this tutorial. Feel free to change some of them to see the outcome.

<div class="alert alert-info">
Anytime you make a change to the config, don't forget to reload it with the following cell!
</div>

In [3]:
# We don't use cli in the notebook, but consider using it in normal scripts.
config = load_config("./config", use_cli=False)

## Data format

The dataloader provides the data in a specific format:

```python
domain_group = {
    "domain": domain_data
}
batch = {
    frozenset(["domain"]): domain_group
}
```
* The **batch** is a dict that has frozensets of domains as keys, and a domain group as values.
* The **domain group** is a dict that has domains (string) as keys, and the domain data as values. The data samples of every domain in a domain group are matched. This means that for a domain group that has 2 domains d1 and d2, `domain_group["d1"][k]` is paired with `domain_group["d2"][k]` for all `k`.

This allows a batch to have several groups (of different domains) of paired data. For example, a batch with unpaired visual (domain "v"), unpaired attribute (domain "attr"), and paired visual and attribute will look like:
```python
batch = {
    frozenset(["v"]): {"v": unpaired_visual_data},
    frozenset(["attr"]): {"attr": unpaired_attribute_data},
    frozenset(["attr", "v"]): {"attr": paired_attr_data, "v": paired_visual_data},
}
```

This is useful to train the global workspace later. But this is also the format used to train the unimodal domains.

Note that because all the data is paired in validation and test steps, the dataloader only returns one domain group with all paired domains:
```python
val_batch = {"attr": paired_attr_data, "v": paired_v_data}
```
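
To make the pairing rule concrete, here is a minimal sketch of a training-style batch. Plain Python lists stand in for tensors, the sample values are made up, and `paired_samples` is a hypothetical helper written for this illustration (it is not part of `shimmer` or `simple_shapes_dataset`):

```python
# Plain lists stand in for tensors; all values are illustrative placeholders.
batch = {
    frozenset(["v"]): {"v": ["img_a", "img_b"]},
    frozenset(["attr"]): {"attr": ["attr_c", "attr_d"]},
    frozenset(["attr", "v"]): {
        "attr": ["attr_e", "attr_f"],
        "v": ["img_e", "img_f"],
    },
}


def paired_samples(domain_group):
    """Yield one dict per index k, pairing item k of every domain in the group."""
    domains = sorted(domain_group)
    for items in zip(*(domain_group[d] for d in domains)):
        yield dict(zip(domains, items))


# frozensets are unordered, so the order of domains in the key does not matter.
pairs = list(paired_samples(batch[frozenset(["v", "attr"])]))
print(pairs[0])
```

Using frozensets as keys means a group can be looked up regardless of the order in which its domains are listed.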

## Unimodal Domains

Instead of using the domains already defined in [`shimmer-ssd`](https://ruflab.github.io/shimmer-ssd/latest/), we will redefine them from scratch to learn how to make them.

We will train 2 domain modules:
* vision: a beta-VAE that will encode the shape images into a small latent representation vector;
* attribute: a beta-VAE that will encode the object category and other attributes into a latent representation.


### Vision domain

Let's start by defining the data module to train the vision module.

We will load the config with an extra argument that will load the `train_v.yaml` file. This file contains specific configuration related to training the visual side like learning rate, weight decay, and the number of steps to train for.

In [4]:
config = load_config("./config", use_cli=False, load_files=["train_v.yaml"])

Note that with `load_files=["train_v.yaml"]`, the `main.yaml` file is still loaded and you don't have to repeat it. If you repeat some config values, the settings set in `train_v.yaml` will override the values in `main.yaml`.
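
The override behaviour can be pictured as a dictionary merge. The sketch below is only an illustration of that idea, not how `load_config` is implemented: `merge_config` is a hypothetical helper, the values are invented, and merging nested sections key by key is an assumption.

```python
from collections.abc import Mapping
from typing import Any


def merge_config(base: Mapping[str, Any], override: Mapping[str, Any]) -> dict[str, Any]:
    """Recursively merge `override` into `base`; values in `override` win."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical values, not the actual contents of the config files.
main_yaml = {"seed": 0, "training": {"batch_size": 1024, "max_steps": 1000}}
train_v_yaml = {"training": {"max_steps": 2000}}

merged = merge_config(main_yaml, train_v_yaml)
print(merged)
```

Here `seed` and `batch_size` come from the base file, while `max_steps` takes the value from the file loaded on top.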

In [5]:
seed_everything(config.seed, workers=True)

data_module = SimpleShapesDataModule(
    config.dataset.path,
    get_default_domains(["v"]),
    {frozenset(["v"]): 1.0},
    batch_size=config.training.batch_size,  # set in `config/train_v.yaml`
    num_workers=config.training.num_workers,
)

Seed set to 0


All domain modules must inherit from the class `DomainModule` from `shimmer`. This class requires you to implement several methods.
`DomainModule` inherits from `LightningModule`, so you can also define the relevant methods to train it.

In particular, you will need to override the following methods:
* `def encode(self, x: Any) -> torch.Tensor`: encodes the raw domain data into the unimodal latent representation
* `def decode(self, z: torch.Tensor) -> Any`: decodes the unimodal latent representation into the raw input data
* `def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput | None`: how to compute the loss used for cycle consistency or translation of this domain. `pred` is the predicted unimodal latent vector, `target` is the target latent vector and `raw_target` is the original input before being encoded.

You can see more details in [the docs](https://ruflab.github.io/shimmer/latest/shimmer/modules/domain.html#DomainModule).

In addition, we will also define `configure_optimizers`, `training_step`, and `validation_step`, to define the optimizer, and what happens during the training and validation steps. See the [lightning docs](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) for more details.

In [6]:
class VisualDomainModule(DomainModule):
    def __init__(
        self,
        num_channels: int,
        latent_dim: int,
        ae_dim: int,
        beta: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: Mapping[str, Any] | None = None,
    ):
        """
        Visual domain module. This defines shimmer's `DomainModule` for the vision
        side with a VAE.

        Args:
            num_channels (`int`): number of input channels (for RGB image, use 3)
            latent_dim (`int`): latent dimension of the vision domain
            ae_dim (`int`): internal auto-encoder dimension of the VAE
            beta (`float`): beta value if beta-VAE. (Defaults to 1.0)
            optim_lr (`float`): training learning rate
            optim_weight_decay (`float`): training weight decay
            scheduler_args (`Mapping[str, Any] | None`): Args for the scheduler.
        """

        # The parent class requires to know the latent dimension of the module.
        super().__init__(latent_dim)
        self.save_hyperparameters()

        vae_encoder = RAEEncoder(num_channels, ae_dim, latent_dim, use_batchnorm=True)
        vae_decoder = RAEDecoder(num_channels, latent_dim, ae_dim)

        # This is a useful helper to make a VAE. You can learn more on this here if interested:
        # https://ruflab.github.io/shimmer-ssd/latest/shimmer_ssd/modules/vae.html
        self.vae = VAE(vae_encoder, vae_decoder, beta)
        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay
        self.scheduler_args: dict[str, Any] = {
            "max_lr": optim_lr,
            "total_steps": 1,
        }
        self.scheduler_args.update(scheduler_args or {})

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode from the image to the latent representation.
        Here we can just use the encode function from the VAE.
        """
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode the unimodal latent into the original domain.
        We can use the decode function from the VAE.
        """
        return self.vae.decode(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput:
        """
        Compute MSE loss in the latent domain. This will be useful when training the global workspace.
        This defines the loss to use for the cycle consistency losses and the translation losses.
        """
        loss = mse_loss(pred, target, reduction="mean")
        return LossOutput(loss)

    # Pytorch Lightning related methods

    def training_step(self, batch: Mapping[frozenset[str], Mapping[str, torch.Tensor]], batch_idx: int) -> torch.Tensor:
        """
        Args:
            batch (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`): batch of domain groups. When training the visual domain,
                this will only contain one group with unpaired visual data
            batch_idx (`int`): batch index
        Returns:
            `torch.Tensor`: the total loss of the step
        """
        x = batch[frozenset(["v"])]["v"]  # extracts the visual information
        return self.generic_step(x, "train")  # use a generic step method that handles both train and validation.

    def validation_step(self, batch: Mapping[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Validation step

        Args:
            batch (`Mapping[str, torch.Tensor]`): group with only paired data. The validation step does not receive several groups.
            batch_idx (`int`): batch index

        Returns:
            `torch.Tensor`: the total loss
        """
        x = batch["v"]  # extract the visual data of the group
        return self.generic_step(x, "val")  # use `generic_step` method to compute the val loss.

    def generic_step(self, x: torch.Tensor, mode: str = "train") -> torch.Tensor:
        """
        Computes the loss given image data

        Args:
            x (`torch.Tensor`): tensors of images
            mode (`str`): mode of step ("train", "val") used for logging
        Returns:
            `torch.Tensor`: The computed loss
        """
        # Get the latent mean, log-variance, and reconstructed images
        # from the VAE
        (mean, logvar), reconstruction = self.vae(x)

        reconstruction_loss = gaussian_nll(reconstruction, torch.tensor(0), x).sum()
        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        # Log losses
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)

        return total_loss

    def configure_optimizers(self) -> dict[str, Any]:
        """
        Defines optimizer and learning rate scheduler
        """

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
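
The total loss in `generic_step` has the usual beta-VAE form: a reconstruction term plus `beta` times a KL term. As a rough sketch of that KL term, the snippet below uses the standard closed-form KL divergence between a diagonal Gaussian and a standard normal, written in plain Python. It assumes a sum over latent dimensions; the exact reduction used by shimmer's `kl_divergence_loss` may differ, and both function names here are hypothetical.

```python
import math


def kl_to_standard_normal(mean: list[float], logvar: list[float]) -> float:
    """Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), summed over dimensions."""
    return 0.5 * sum(m**2 + math.exp(lv) - lv - 1 for m, lv in zip(mean, logvar))


def beta_vae_loss(reconstruction_loss: float, kl: float, beta: float) -> float:
    """Total loss in the same form as `generic_step`: reconstruction + beta * KL."""
    return reconstruction_loss + beta * kl


# A posterior equal to the prior has zero KL; moving the mean away increases it.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))
```

A larger `beta` puts more weight on keeping the latent distribution close to the prior, at the cost of reconstruction quality.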

Now let's instantiate the domain module:

In [7]:
v_domain_module = VisualDomainModule(
    num_channels=3,
    ae_dim=config.domain_modules.visual.ae_dim,
    latent_dim=config.domain_modules.visual.latent_dim,
    beta=config.domain_modules.visual.beta,
    optim_lr=config.training.optim.lr,
    optim_weight_decay=config.training.optim.weight_decay,
    scheduler_args={
        "max_lr": config.training.optim.max_lr,
        "total_steps": config.training.max_steps,
    },
)

We will use tensorboard to log the losses and reconstructed images. We can use `LogVisualCallback` from `shimmer_ssd.logging` to
log reconstructed images of some image samples.

You can update the `train_v.yaml` config file to change how often images will be updated on tensorboard. Here we will put 1 so that they are updated every epoch:
```yaml
logging:
    log_val_medias_every_n_epochs: 1
    log_train_medias_every_n_epochs: 1
```

In [8]:
logger = TensorBoardLogger("logs", name="vision_model")

# Get some image samples to log in tensorboard.
val_samples = data_module.get_samples("val", 32)[frozenset(["v"])]["v"]
train_samples = data_module.get_samples("train", 32)[frozenset(["v"])]["v"]

# Create the vision folder where we will save checkpoints
(config.default_root_dir / "vision").mkdir(exist_ok=True)

callbacks: list[Callback] = [
    # Will log the validation ground-truth and reconstructions during training
    LogVisualCallback(
        val_samples,
        log_key="val_images",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        ncols=8,
    ),
    # Will log the training ground-truth and reconstructions during training
    LogVisualCallback(
        train_samples,
        log_key="train_images",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        ncols=8,
    ),
    # Save the checkpoints
    ModelCheckpoint(
        dirpath=config.default_root_dir / "vision" / f"version_{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_last="link",
        save_top_k=1,
    ),
]

Let's keep track of the folder where the final model will be saved:

In [9]:
visual_checkpoint = config.default_root_dir / "vision" / f"version_{logger.version}"
print(visual_checkpoint)

checkpoints/vision/version_2


Load tensorboard. You can select the version associated with the previous path. It will appear once the training has started.

In [10]:
%load_ext tensorboard

In [11]:
%tensorboard --logdir "./logs/vision_model" --reload_interval 30 --reload_task 'auto'

Let's start the training!

In [12]:
trainer = Trainer(
    logger=logger,
    fast_dev_run=config.training.fast_dev_run,
    max_steps=config.training.max_steps,
    enable_progress_bar=config.training.enable_progress_bar,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(v_domain_module, data_module)
trainer.validate(v_domain_module, data_module, "best")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type | Params | Mode 
--------------------------------------
0 | vae  | VAE  | 1.5 M  | train
--------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.990     Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode


`Trainer.fit` stopped: `max_steps=2000` reached.
Restoring states from the checkpoint path at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/vision/version_2/epoch=7.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/vision/version_2/epoch=7.ckpt


[{'val/reconstruction_loss': 5758197.0,
  'val/kl_loss': 37061.0546875,
  'val/loss': 5761903.5}]

For faster training of the global workspace, we can pre-extract the visual latent vectors with the following cells. Don't forget to update the path to the actual checkpoint, which is:

In [13]:
print(visual_checkpoint)

checkpoints/vision/version_2
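
The extraction command expects the path to a checkpoint file inside that folder. As a small sketch, the path to `last.ckpt` can be assembled with `pathlib`; `last_checkpoint` is a hypothetical helper, and the root, folder name, and version number below simply mirror the values printed in this notebook run.

```python
from pathlib import Path


def last_checkpoint(root: Path, domain: str, version: int) -> Path:
    """Build `<root>/<domain>/version_<version>/last.ckpt`."""
    return root / domain / f"version_{version}" / "last.ckpt"


# Values taken from this run; adapt the version to your own training.
path = last_checkpoint(Path("checkpoints"), "vision", 2)
print(path.as_posix())
```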


In [15]:
!ssd extract v "checkpoints/vision/version_2/last.ckpt" -p "simple_shapes_dataset" --latent_name "domain_v_tuto.npy" --force

Latent file already exists. Overriding.
Saving train.
100%|█████████████████████████████████████████| 244/244 [00:11<00:00, 22.02it/s]
Saving in simple_shapes_dataset/saved_latents/train/domain_v_tuto.npy.
Saving val.
100%|███████████████████████████████████████████| 25/25 [00:01<00:00, 14.02it/s]
Saving in simple_shapes_dataset/saved_latents/val/domain_v_tuto.npy.
Saving test.
Traceback (most recent call last):                       | 0/25 [00:00<?, ?it/s]
  File "/home/bdevillers/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/util.py", line 345, in _run_finalizers
    finalizer()
  File "/home/bdevillers/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/util.py", line 266, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bdevillers/.pyenv/versions/3.11.4/lib/python3.11/multiprocessing/util.py", line 153, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/bdevillers/.pyenv/versions/3.11.

We now need to update the location of the presaved latent vectors in the config file. You can see that `domain_data_args` is defined as follows:

```yaml
domain_data_args:
    v_latents:
        presaved_path: domain_v_tuto.npy  # as defined in the `--latent_name`
```
Let's also change the selected domains:

```yaml
domains:
    - checkpoint_path: "./checkpoints/vision/version_0/last.ckpt"  # update to the actual version
      domain_type: v_latents
```

### Attribute domain

This will be very similar to the previous section. Here, we will focus on learning the attribute model.

We will load the config with an extra argument that will load the `train_attr.yaml` file.

In [2]:
config = load_config("./config", use_cli=False, load_files=["train_attr.yaml"])

In [3]:
seed_everything(config.seed, workers=True)

data_module = SimpleShapesDataModule(
    config.dataset.path,
    get_default_domains(["attr"]),
    {frozenset(["attr"]): 1.0},
    batch_size=config.training.batch_size,  # set in `config/train_attr.yaml`
    num_workers=config.training.num_workers,
)

Seed set to 0


Similarly to what has been done for the visual domain, we will create a `DomainModule`. This will also use a VAE to encode and decode the attribute vectors.

First, we will define the VAE encoder and decoder. They will inherit from `shimmer.modules.vae.VAEEncoder` and `shimmer.modules.vae.VAEDecoder`.

For the encoder:

In [4]:
class Encoder(VAEEncoder):
    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.encoder = nn.Sequential(
            # Input dim: 3 one-hot encoded shape category + 2 locations + 2 rotations (cos, sin space) + 1 size + 3 color (RGB)
            nn.Linear(11, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
        )

        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        out = torch.cat(list(x), dim=-1)
        out = self.encoder(out)
        return self.q_mean(out), self.q_logvar(out)
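
To see where the input size of 11 comes from, here is a sketch that assembles the two inputs for a single shape with plain Python lists. The function `attribute_inputs`, the ordering of the attributes, and the value ranges are assumptions for illustration only; the dataset may order or normalize them differently.

```python
import math


def attribute_inputs(
    category: int,
    location: tuple[float, float],
    rotation: float,
    size: float,
    color: tuple[float, float, float],
) -> tuple[list[float], list[float]]:
    """Return (one-hot category, attribute vector) as plain lists."""
    one_hot = [1.0 if i == category else 0.0 for i in range(3)]
    attributes = [
        *location,
        math.cos(rotation),  # rotation is stored as (cos, sin)
        math.sin(rotation),
        size,
        *color,
    ]
    return one_hot, attributes


categories, attributes = attribute_inputs(1, (0.2, -0.4), math.pi / 2, 0.5, (1.0, 0.0, 0.0))
flat = categories + attributes  # mirrors `torch.cat(list(x), dim=-1)` for one sample
print(len(categories), len(attributes), len(flat))
```

The split into 3 category values and 8 other attributes matches the two output heads of the decoder defined next.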

And for the decoder, we will decode back to two vectors: a class vector, and an attribute one:

In [5]:
class Decoder(VAEDecoder):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        self.decoder = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )

        # Decode the categories and other attributes separately
        self.decoder_categories = nn.Sequential(
            nn.Linear(self.hidden_dim, 3),
        )

        self.decoder_attributes = nn.Sequential(
            nn.Linear(self.hidden_dim, 8),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        out = self.decoder(x)
        return [self.decoder_categories(out), self.decoder_attributes(out)]

Now let's combine them in a `DomainModule`:

In [6]:
class AttributeDomainModule(DomainModule):
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        coef_categories: float = 1,
        coef_attributes: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ):
        """
        Defines the Attribute domain module.

        Args:
            latent_dim (`int`): the latent dimension of the module
            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
            beta (`float`): for beta-VAE
            coef_categories (`float`): loss coefficient attributed to the category
                (Defaults to 1.0)
            coef_attributes (`float`): loss coefficient attributed to the rest of the
                attributes (Defaults to 1.0)
            optim_lr (`float`): learning rate for the optimizer
            optim_weight_decay (`float`): weight decay for the optimizer
            scheduler_args (`SchedulerArgs | None`): Scheduler arguments
        """
        super().__init__(latent_dim)
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.coef_categories = coef_categories
        self.coef_attributes = coef_attributes

        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """
        Encodes the attributes into the latent representation.

        Args:
            x (`Sequence[torch.Tensor]`): list with 2 items: the shape category in a one-hot format,
                and the attribute vector.
        Returns:
            `torch.Tensor`: the encoded latent representation
        """
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """
        Decodes the latent representation to the shape category and attributes.

        Args:
            z (`torch.Tensor`): the latent representation
        Returns:
            `list[torch.Tensor]`: list with 2 items: the shape category in a one-hot format,
                and the attribute vector.
        """
        return self.vae.decode(z)

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
        return self.decode(self.encode(x))

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput:
        """
        The loss (in the latent space domain) is simply the MSE between predicted and target latent representations
        """
        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

    # Pytorch Lightning functions
    # This part is very similar to the visual VAE

    def training_step(
        self,
        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
        batch_idx: int,
    ) -> torch.Tensor:
        x = batch[frozenset(["attr"])]["attr"]
        return self.generic_step(x, "train")

    def validation_step(self, batch: Mapping[str, Sequence[torch.Tensor]], batch_idx: int) -> torch.Tensor:
        x = batch["attr"]
        return self.generic_step(x, "val")

    def generic_step(
        self,
        x: Sequence[torch.Tensor],
        mode: str = "train",
    ) -> torch.Tensor:
        x_categories, x_attributes = x[0], x[1]

        (mean, logvar), reconstruction = self.vae(x)
        reconstruction_categories = reconstruction[0]
        reconstruction_attributes = reconstruction[1]

        reconstruction_loss_categories = F.cross_entropy(
            reconstruction_categories,
            x_categories.argmax(dim=1),
            reduction="sum",
        )
        reconstruction_loss_attributes = gaussian_nll(reconstruction_attributes, torch.tensor(0), x_attributes).sum()

        reconstruction_loss = (
            self.coef_categories * reconstruction_loss_categories
            + self.coef_attributes * reconstruction_loss_attributes
        )
        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        self.log(
            f"{mode}/reconstruction_loss_categories",
            reconstruction_loss_categories,
        )
        self.log(
            f"{mode}/reconstruction_loss_attributes",
            reconstruction_loss_attributes,
        )
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def configure_optimizers(self) -> dict[str, Any]:
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
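
In `generic_step`, the category head is trained with `F.cross_entropy`, which takes raw logits as input and class indices as targets. That is why the decoder's category head has no softmax and why the one-hot target is converted with `argmax`. The following plain-Python sketch reproduces that computation for a single sample; `cross_entropy_from_one_hot` is a hypothetical helper, not a library function.

```python
import math


def cross_entropy_from_one_hot(logits: list[float], one_hot: list[float]) -> float:
    """Softmax cross-entropy for one sample, with the target given as a one-hot list."""
    target = max(range(len(one_hot)), key=one_hot.__getitem__)  # argmax of the one-hot
    shift = max(logits)  # subtract the max for numerical stability
    log_norm = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return log_norm - logits[target]


uniform = cross_entropy_from_one_hot([0.0, 0.0, 0.0], [0.0, 1.0, 0.0])
confident = cross_entropy_from_one_hot([0.0, 5.0, 0.0], [0.0, 1.0, 0.0])
print(uniform, confident)
```

With `reduction="sum"`, these per-sample values are added up over the batch, which keeps the category term on the same scale as the summed attribute and KL terms.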

Now let's instantiate the domain module:

In [7]:
attr_domain_module = AttributeDomainModule(
    latent_dim=config.domain_modules.attribute.latent_dim,
    hidden_dim=config.domain_modules.attribute.hidden_dim,
    beta=config.domain_modules.attribute.beta,
    coef_categories=config.domain_modules.attribute.coef_categories,
    coef_attributes=config.domain_modules.attribute.coef_attributes,
    optim_lr=config.training.optim.lr,
    optim_weight_decay=config.training.optim.weight_decay,
    scheduler_args={
        "max_lr": config.training.optim.max_lr,
        "total_steps": config.training.max_steps,
    },
)

We will use tensorboard to log the losses and reconstructions. We can use `LogAttributesCallback` from `shimmer_ssd.logging` to
log the reconstructed attributes of some samples, rendered as images.

You can update the `train_attr.yaml` config file to change how often images will be updated on tensorboard:
```yaml
logging:
    log_val_medias_every_n_epochs: 1
    log_train_medias_every_n_epochs: 1
```

In [8]:
logger = TensorBoardLogger("logs", name="attr_model")

# Get some attribute samples to log in tensorboard.
val_samples = data_module.get_samples("val", 32)[frozenset(["attr"])]["attr"]
train_samples = data_module.get_samples("train", 32)[frozenset(["attr"])]["attr"]

# Create attr folder where we will save checkpoints
(config.default_root_dir / "attr").mkdir(exist_ok=True)

callbacks: list[Callback] = [
    # Will log the validation ground-truth and reconstructions during training
    LogAttributesCallback(
        val_samples,
        log_key="images/val_attr",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        image_size=32,
        ncols=8,
    ),
    # Will log the training ground-truth and reconstructions during training
    LogAttributesCallback(
        train_samples,
        log_key="images/train_attr",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        image_size=32,
        ncols=8,
    ),
    # Save the checkpoints
    ModelCheckpoint(
        dirpath=config.default_root_dir / "attr" / f"version_{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_last="link",
        save_top_k=1,
    ),
]

Let's keep track of the folder where the final model will be saved:

In [9]:
attribute_checkpoint = config.default_root_dir / "attr" / f"version_{logger.version}"
print(attribute_checkpoint)

checkpoints/attr/version_1


Load TensorBoard. You can select the version associated with the previous path. It will appear once training has started.

In [10]:
%load_ext tensorboard

In [12]:
%tensorboard --logdir "./logs/attr_model" --reload_interval 30 --reload_task 'auto'

Reusing TensorBoard on port 6006 (pid 4137456), started 0:01:52 ago. (Use '!kill 4137456' to kill it.)

Let's start the training!

In [13]:
trainer = Trainer(
    logger=logger,
    fast_dev_run=config.training.fast_dev_run,
    max_steps=config.training.max_steps,
    enable_progress_bar=config.training.enable_progress_bar,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(attr_domain_module, data_module)
trainer.validate(attr_domain_module, data_module, "best")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type | Params | Mode 
--------------------------------------
0 | vae  | VAE  | 11.4 K | train
--------------------------------------
11.4 K    Trainable params
0         Non-trainable params
11.4 K    Total params
0.046     Total estimated model params size (MB)
22        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_steps=2000` reached.
Restoring states from the checkpoint path at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/attr/version_1/epoch=7.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/attr/version_1/epoch=7.ckpt


Validation: |                                                                                                 …

[{'val/reconstruction_loss_categories': 6.365734577178955,
  'val/reconstruction_loss_attributes': 15491.32421875,
  'val/reconstruction_loss': 15497.6904296875,
  'val/kl_loss': 16367.7392578125,
  'val/loss': 16316.0771484375}]
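
As a quick sanity check on the metrics above, the sketch below tests two relations between the logged values: that the two reconstruction terms add up to `val/reconstruction_loss`, and that the total follows `loss = reconstruction + beta * KL`. The second relation is an assumption about how the VAE loss is combined, inferred only from these numbers, not taken from the library's code.

```python
# Values copied from the validation output above.
metrics = {
    "val/reconstruction_loss_categories": 6.365734577178955,
    "val/reconstruction_loss_attributes": 15491.32421875,
    "val/reconstruction_loss": 15497.6904296875,
    "val/kl_loss": 16367.7392578125,
    "val/loss": 16316.0771484375,
}

# The two reconstruction terms should add up to the total reconstruction
# loss, up to float32 rounding.
recon_sum = (
    metrics["val/reconstruction_loss_categories"]
    + metrics["val/reconstruction_loss_attributes"]
)

# Assumption: loss = reconstruction + beta * KL. Solve for beta.
implied_beta = (
    metrics["val/loss"] - metrics["val/reconstruction_loss"]
) / metrics["val/kl_loss"]

print(f"sum of reconstruction terms: {recon_sum:.4f}")
print(f"implied beta: {implied_beta:.4f}")
```

Under that assumption, the logged values imply a KL weight very close to 0.05.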

## Train a Global Workspace

Now that we have trained our two unimodal modules, we will train the global workspace. For this training, we will use half of the 500,000 paired samples.
To this end, we need to create a split in the dataset. A dataset split depends on a seed and on the proportion of each group of domains.
We only need to generate this split once.

This can be done with the `shapesd alignment add` command. It needs the following arguments:
- `--dataset_path "DATASET_PATH"`: the location where the dataset is stored
- `--seed SEED`: the split seed
- `--domain_alignment DOMAIN_1,DOMAIN_2,...DOMAIN_N PROP`: the proportion for each domain group. This corresponds to what has been defined in `domain_proportions`

Running this command creates a file containing the indices of the items available in the train set. Update the arguments so that they match what is set in the config file.

In [14]:
!shapesd alignment add --dataset_path "simple_shapes_dataset" --seed 0 --domain_alignment attr 1.0 --domain_alignment v 1.0 --domain_alignment attr,v 0.5
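
To illustrate why the split depends only on a seed and a proportion, here is a minimal sketch of a seeded subset selection. The `sample_paired_indices` helper is hypothetical and is not how `shapesd` is implemented; it only shows that the same seed and proportion always produce the same set of paired indices.

```python
import random


def sample_paired_indices(n_items: int, proportion: float, seed: int) -> list[int]:
    """Pick a reproducible subset of item indices (hypothetical helper)."""
    if not 0.0 <= proportion <= 1.0:
        raise ValueError("proportion must be in [0, 1]")
    rng = random.Random(seed)  # local generator: the same seed gives the same split
    n_selected = int(n_items * proportion)
    return sorted(rng.sample(range(n_items), n_selected))


# Toy size for illustration; the tutorial uses 500,000 paired samples.
paired = sample_paired_indices(n_items=1_000, proportion=0.5, seed=0)
print(len(paired))  # 500
```

Since the result is deterministic, the split only has to be generated once and can then be reused across runs.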

This time, we will load the config from the extra file `train_gw.yaml`.

First, let's update `main.yaml` to use the same alignment split:
```yaml
domain_proportions: 
    -   domains: ["v"]  # unimodal visual passes use 100% of the available data 
        proportion: 1.0
    -   domains: ["attr"]  # unimodal attr passes use 100% of the available data
        proportion: 1.0
    -   domains: ["v", "attr"]  # paired passes uses 50% of the available data
        proportion: 0.5
```
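
Later in this notebook, the same proportions are written in Python as a dict keyed by `frozenset`s of domain names. The sketch below only illustrates the correspondence between the two forms; how the config loader itself converts the YAML is not shown here, and `yaml_like` is a hand-written stand-in for the parsed file.

```python
# Same content as the YAML above, written as plain Python data.
yaml_like = [
    {"domains": ["v"], "proportion": 1.0},
    {"domains": ["attr"], "proportion": 1.0},
    {"domains": ["v", "attr"], "proportion": 0.5},
]

# Key each proportion by the *set* of domains, so that the order in which
# the domains are listed does not matter.
domain_proportions = {
    frozenset(entry["domains"]): entry["proportion"] for entry in yaml_like
}

print(domain_proportions[frozenset(["attr", "v"])])  # 0.5
```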

Then, let's change the selected domains:

```yaml
domains:
    - checkpoint_path: "./checkpoints/visual/version_0/last.ckpt"  # update to the actual version
      domain_type: v_latents
    - checkpoint_path: "./checkpoints/attr/version_0/last.ckpt"  # update to the actual version
      domain_type: attr
```
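
If you are unsure which version number to put in `checkpoint_path`, a small helper can find the most recent `version_N` folder for you. This is a sketch with a hypothetical `latest_version_dir` function; it assumes the folder layout created by the checkpoint callbacks earlier in this notebook (`checkpoints/<domain>/version_N`).

```python
from pathlib import Path


def latest_version_dir(root: Path) -> Path:
    """Return the `version_N` sub-directory of `root` with the largest N."""
    versions = [
        p
        for p in root.glob("version_*")
        if p.is_dir() and p.name.removeprefix("version_").isdigit()
    ]
    if not versions:
        raise FileNotFoundError(f"no version_* directory in {root}")
    # Compare numerically so that version_10 ranks above version_2.
    return max(versions, key=lambda p: int(p.name.removeprefix("version_")))


# Example usage, once a visual module has been trained:
# print(latest_version_dir(Path("./checkpoints/visual")) / "last.ckpt")
```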

And let's set the global workspace dimension to 12:
```yaml
global_workspace:
    latent_dim: 12  
    
    loss_coefficients:
        cycles: 1.0
        contrastives: 0.1
        demi_cycles: 1.0
        translations: 1.0

    encoders:
        hidden_dim: 32
        n_layers: 3

    decoders:
        hidden_dim: 32
        n_layers: 3
```

Finally, let's load the config:

In [13]:
config = load_config("./config", use_cli=False, load_files=["train_gw.yaml"])

Skip the following cell if you have trained the unimodal modules yourself. It sets up pretrained modules.

### Run this if you didn't train the modules

In [7]:
# Download checkpoints
!ssd download checkpoints
!mv checkpoints/checkpoints/* checkpoints/
!rm -rf checkpoints/checkpoints

# Extract visual latent from pretrained visual domain
!ssd extract v "checkpoints/domain_v.ckpt" -p "simple_shapes_dataset"

Downloading in checkpoints.


In [14]:
# Update the config
checkpoint_path = Path("./checkpoints")

config.domain_proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 0.5,
}

config.domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
    ),
]

config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
config.global_workspace.latent_dim = 12

### Load the domains and train
We can now load the pretrained unimodal modules:

In [15]:
# we load the pretrained domain modules and define the associated GW encoders and decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    config.domains,
    config.global_workspace.latent_dim,
    config.global_workspace.encoders.hidden_dim,
    config.global_workspace.encoders.n_layers,
    config.global_workspace.decoders.hidden_dim,
    config.global_workspace.decoders.n_layers,
)

Instantiate the Global Workspace class:

In [16]:
def get_scheduler(optimizer: Optimizer) -> OneCycleLR:
    return OneCycleLR(optimizer, config.training.optim.max_lr, config.training.max_steps)


global_workspace = GlobalWorkspace2Domains(
    domain_modules,
    gw_encoders,
    gw_decoders,
    config.global_workspace.latent_dim,
    config.global_workspace.loss_coefficients,
    config.training.optim.lr,
    config.training.optim.weight_decay,
    scheduler=get_scheduler,
)
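
To get a feel for the learning-rate curve that `OneCycleLR` produces, here is a simplified pure-Python sketch of a two-phase cosine one-cycle schedule. It assumes PyTorch's documented defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing) and is not the PyTorch implementation itself. The `max_lr` value used below is a placeholder, not the value from the config.

```python
import math


def one_cycle_lr(
    step: int,
    max_lr: float,
    total_steps: int,
    pct_start: float = 0.3,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4,
) -> float:
    """Learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1
    if warmup_end <= 0 or warmup_end >= last_step:
        raise ValueError("total_steps is too small for this pct_start")

    def cos_anneal(start: float, end: float, pct: float) -> float:
        # Goes from `start` (pct=0) to `end` (pct=1) along half a cosine.
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (last_step - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)


# Placeholder max_lr; 8000 steps matches the training log of this notebook.
for step in (0, 2399, 7999):
    print(step, one_cycle_lr(step, max_lr=1e-3, total_steps=8000))
```

The rate starts at `max_lr / 25`, peaks at `max_lr` about 30% of the way through training, then decays to a value several orders of magnitude smaller.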

In [17]:
domain_classes = get_default_domains(["v_latents", "attr"])

data_module = SimpleShapesDataModule(
    config.dataset.path,
    domain_classes,
    config.domain_proportions,
    batch_size=config.training.batch_size,
    num_workers=config.training.num_workers,
    seed=config.seed,
    domain_args=config.domain_data_args,
)

Add a TensorBoard logger to follow the training:

In [18]:
logger = TensorBoardLogger("logs", name="gw")

# Get some image samples to log in tensorboard.
train_samples = data_module.get_samples("train", 32)
val_samples = data_module.get_samples("val", 32)

# split the unique group in validation into individual groups for logging
for domains in val_samples:
    for domain in domains:
        val_samples[frozenset([domain])] = {domain: val_samples[domains][domain]}
    break
# Create gw folder where we will save checkpoints
(config.default_root_dir / "gw").mkdir(exist_ok=True)

callbacks: list[Callback] = [
    # Will log the validation ground-truth and reconstructions during training
    LogGWImagesCallback(
        val_samples,
        log_key="images/val",
        mode="val",
        every_n_epochs=config.logging.log_val_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Will log the training ground-truth and reconstructions during training
    LogGWImagesCallback(
        train_samples,
        log_key="images/train",
        mode="train",
        every_n_epochs=config.logging.log_train_medias_every_n_epochs,
        filter=config.logging.filter_images,
    ),
    # Save the checkpoints
    ModelCheckpoint(
        dirpath=config.default_root_dir / "gw" / f"version_{logger.version}",
        filename="{epoch}",
        monitor="val/loss",
        mode="min",
        save_last="link",
        save_top_k=1,
    ),
]
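
The loop that splits the validation samples in the cell above can be hard to read. The toy example below reproduces its effect on a small dict, with strings standing in for the real tensors. It assumes, as the comment in the cell says, that the validation samples contain a single group holding all domains.

```python
# Toy stand-in for `data_module.get_samples("val", 32)`: a single group that
# contains both domains (strings replace the real tensors).
val_samples = {
    frozenset(["v_latents", "attr"]): {"v_latents": "V", "attr": "A"},
}

# Build one single-domain group per domain. We iterate over the frozenset key,
# not over the dict, so adding entries to the dict is safe.
paired_group = next(iter(val_samples))
for domain in paired_group:
    val_samples[frozenset([domain])] = {domain: val_samples[paired_group][domain]}

print(sorted(sorted(group) for group in val_samples))
# [['attr'], ['attr', 'v_latents'], ['v_latents']]
```

After the split, the logging callback receives the paired group as well as one group per individual domain.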

Let's keep track of the directory where the checkpoints will be saved, so that we can load the final model later:

In [19]:
gw_checkpoint = config.default_root_dir / "gw" / f"version_{logger.version}"
print(gw_checkpoint)

checkpoints/gw/version_2


Load TensorBoard. You can select the version associated with the previous path. It will appear once training has started.

In [8]:
%load_ext tensorboard

In [9]:
%tensorboard --logdir "./logs/gw" --reload_interval 30 --reload_task 'auto'

And train!

In [20]:
trainer = Trainer(
    logger=logger,
    max_steps=config.training.max_steps,
    default_root_dir=config.default_root_dir,
    callbacks=callbacks,
    precision=config.training.precision,
    accelerator=config.training.accelerator,
    devices=config.training.devices,
)

trainer.fit(global_workspace, data_module)
trainer.validate(global_workspace, data_module, "best")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                  | Params | Mode 
----------------------------------------------------------------
0 | gw_mod        | GWModule              | 5.9 M  | train
1 | selection_mod | SingleDomainSelection | 0      | train
2 | loss_mod      | GWLosses2Domains      | 5.9 M  | train
----------------------------------------------------------------
15.9 K    Trainable params
5.8 M     Non-trainable params
5.9 M     Total params
23.428    Total estimated model params size (MB)
47        Modules in train mode
57        Modules in eval mode


Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_steps=8000` reached.
Restoring states from the checkpoint path at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/gw/version_2/epoch=31.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from the checkpoint at /home/bdevillers/projects/shimmer-ssd-tutorials/checkpoints/gw/version_2/epoch=31.ckpt


Validation: |                                                                                                 …

[{'val/demi_cycle_v_latents': 0.16235917806625366,
  'val/demi_cycle_attr': 0.03258915990591049,
  'val/demi_cycles': 0.09747415781021118,
  'val/cycle_v_latents_through_attr': 0.21616032719612122,
  'val/cycle_attr_through_v_latents': 0.11297674477100372,
  'val/cycles': 0.16456852853298187,
  'val/translation_attr_to_v_latents': 0.25512421131134033,
  'val/translation_v_latents_to_attr': 0.17463333904743195,
  'val/translations': 0.21487878262996674,
  'val/contrastive_v_latents_and_attr': 1.4367294311523438,
  'val/contrastives': 1.4367294311523438,
  'val/contrastive_v_latents_and_attr_logit_scale': 14.285717964172363,
  'val/loss': 0.15514861047267914}]
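
The sketch below checks how the logged terms relate to `val/loss`. It assumes that the total is the coefficient-weighted mean of the four loss groups, using the `loss_coefficients` shown in the config earlier. This is an inference from the numbers above, not a statement about the library's code.

```python
# Values copied from the validation output above.
metrics = {
    "val/demi_cycles": 0.09747415781021118,
    "val/cycles": 0.16456852853298187,
    "val/translations": 0.21487878262996674,
    "val/contrastives": 1.4367294311523438,
    "val/loss": 0.15514861047267914,
}

# Coefficients from the `loss_coefficients` block of the config shown earlier.
coefficients = {
    "cycles": 1.0,
    "contrastives": 0.1,
    "demi_cycles": 1.0,
    "translations": 1.0,
}

# Assumption: the total is the coefficient-weighted mean of the four terms.
weighted = [coef * metrics[f"val/{name}"] for name, coef in coefficients.items()]
estimate = sum(weighted) / len(weighted)

print(f"estimated: {estimate:.6f}  logged: {metrics['val/loss']:.6f}")
```

The estimate agrees with the logged value to within float32 rounding, which is consistent with the assumed formula.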

### Run this if you didn't train the model

In [77]:
# Update the config
checkpoint_path = Path("./checkpoints")

config.domain_proportions = {
    frozenset(["v"]): 1.0,
    frozenset(["attr"]): 1.0,
    frozenset(["v", "attr"]): 0.5,
}

config.domains = [
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.v_latents,
        checkpoint_path=checkpoint_path / "domain_v.ckpt",
    ),
    LoadedDomainConfig(
        domain_type=DomainModuleVariant.attr,
        checkpoint_path=checkpoint_path / "domain_attr.ckpt",
    ),
]

config.domain_data_args["v_latents"]["presaved_path"] = "domain_v.npy"
config.global_workspace.latent_dim = 12
# And now we load the GW checkpoint
checkpoint = checkpoint_path / "gw-attr-v-half-paired-data.ckpt"

# we load the pretrained domain modules and define the associated GW encoders and decoders
domain_modules, gw_encoders, gw_decoders = load_pretrained_domains(
    config.domains,
    config.global_workspace.latent_dim,
    config.global_workspace.encoders.hidden_dim,
    config.global_workspace.encoders.n_layers,
    config.global_workspace.decoders.hidden_dim,
    config.global_workspace.decoders.n_layers,
)

global_workspace = GlobalWorkspace2Domains.load_from_checkpoint(
    checkpoint,
    domain_mods=domain_modules,
    gw_encoders=gw_encoders,
    gw_decoders=gw_decoders,
)

## Play with the global workspace

In [21]:
import io
import math

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, interact_manual
from PIL import Image
from shimmer_ssd.logging import attribute_image_grid
from torch.nn.functional import one_hot

from simple_shapes_dataset.cli import generate_image
%matplotlib widget

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_workspace.to(device)

cat2idx = {"Diamond": 0, "Egg": 1, "Triangle": 2}


def get_image(cat, x, y, size, rot, color_r, color_g, color_b):
    fig, ax = plt.subplots(figsize=(32, 32), dpi=1)
    # The dataset generation tool has a function to generate a matplotlib shape
    # from the attributes.
    generate_image(
        ax,
        cat2idx[cat],
        [int(x * 18 + 7), int(y * 18 + 7)],
        size * 7 + 7,
        rot * 2 * math.pi,
        np.array([color_r * 255, color_g * 255, color_b * 255]),
        imsize=32,
    )
    ax.set_facecolor("black")
    plt.tight_layout(pad=0)
    # Return this as a PIL Image.
    # This is to have the same dpi as saved images
    # otherwise matplotlib will render this in very high quality
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    image = Image.open(buf)
    plt.close(fig)
    return image


@interact(
    cat=["Triangle", "Egg", "Diamond"],
    x=(0, 1, 0.1),
    y=(0, 1, 0.1),
    rot=(0, 1, 0.1),
    size=(0, 1, 0.1),
    color_r=(0, 1, 0.1),
    color_g=(0, 1, 0.1),
    color_b=(0, 1, 0.1),
)
def play_with_gw(
    cat: str = "Triangle",
    x: float = 0.5,
    y: float = 0.5,
    rot: float = 0.5,
    size: float = 0.5,
    color_r: float = 1,
    color_g: float = 0,
    color_b: float = 0,
):
    fig, axes = plt.subplots(1, 2)
    image = get_image(cat, x, y, size, rot, color_r, color_g, color_b)
    axes[0].set_facecolor("black")
    axes[0].set_title("Original image from attributes")
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].imshow(image)

    # normalize the attribute for the global workspace.
    category = one_hot(torch.tensor([cat2idx[cat]]), 3)
    rotx = math.cos(rot * 2 * math.pi)
    roty = math.sin(rot * 2 * math.pi)
    attributes = torch.tensor(
        [[x * 2 - 1, y * 2 - 1, size * 2 - 1, rotx, roty, color_r * 2 - 1, color_g * 2 - 1, color_b * 2 - 1]]
    )
    samples = [category.to(device), attributes.to(device)]
    attr_gw_latent = global_workspace.gw_mod.encode({"attr": global_workspace.encode_domain(samples, "attr")})
    gw_latent = global_workspace.gw_mod.fuse(
        attr_gw_latent, {"attr": torch.ones(attr_gw_latent["attr"].size(0)).to(device)}
    )
    decoded_latents = global_workspace.gw_mod.decode(gw_latent)["v_latents"]
    decoded_images = (
        global_workspace.domain_mods["v_latents"]
        .decode_images(decoded_latents)[0]
        .permute(1, 2, 0)
        .detach()
        .cpu()
        .numpy()
    )
    axes[1].imshow(decoded_images)
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_title("Translated image through GW")
    plt.show()

interactive(children=(Dropdown(description='cat', options=('Triangle', 'Egg', 'Diamond'), value='Triangle'), F…
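
The attribute normalization done inside `play_with_gw` can be isolated into a small function, which makes it easier to inspect. The sketch below mirrors the expressions used in the cell above, without `torch`. The `normalize_attributes` name is introduced here for illustration only.

```python
import math


def normalize_attributes(x, y, size, rot, color_r, color_g, color_b):
    """Map slider values in [0, 1] to the 8-dimensional attribute vector."""
    angle = rot * 2 * math.pi
    return [
        x * 2 - 1,  # position, size and colors are rescaled to [-1, 1]
        y * 2 - 1,
        size * 2 - 1,
        math.cos(angle),  # the rotation is encoded as a (cos, sin) pair
        math.sin(angle),
        color_r * 2 - 1,
        color_g * 2 - 1,
        color_b * 2 - 1,
    ]


# Default slider values of `play_with_gw`.
vector = normalize_attributes(0.5, 0.5, 0.5, 0.5, 1, 0, 0)
print([round(v, 3) for v in vector])
```

Encoding the rotation as a cosine and sine pair avoids the discontinuity between `rot=0` and `rot=1`, which describe the same angle.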