In [1]:
from afa_rl.datasets import CubeDataset, DataModuleFromDataset
from afa_rl.models import ShimEmbedderClassifier, ReadProcessEncoder, ShimEmbedder

from torch import nn
from torchrl.modules import MLP
import lightning as pl
from lightning.pytorch.loggers import WandbLogger
import torch

In [2]:
torch.set_float32_matmul_precision("medium")

In [3]:
n_features = 20
dataset = CubeDataset(n_features=n_features, data_points=100_000, sigma=0.01, seed=42)
datamodule = DataModuleFromDataset(dataset=dataset, batch_size=128, train_ratio=0.8)

In [4]:
# Check what the dataset looks like
dataset[1]

(tensor([ 0.1313,  0.9204,  0.4699,  0.2398, -0.0065,  0.0108,  0.9922,  0.6261,
          0.3093,  0.5771,  0.6224,  0.6519,  0.3125,  0.3840,  0.4065,  0.3761,
          0.8397,  0.8010,  0.8765,  0.9403]),
 tensor([0., 0., 0., 0., 1., 0., 0., 0.]))

In [5]:
from afa_rl.models import MLPClassifier


encoder = ReadProcessEncoder(
    feature_size=n_features + 1,  # state contains one value and one index
    output_size=16,
    reading_block_cells=[32, 32],
    writing_block_cells=[32, 32],
    memory_size=16,
    processing_steps=5,
)
embedder = ShimEmbedder(encoder)
classifier = MLPClassifier(16, 8, [32, 32])
model = ShimEmbedderClassifier(embedder=embedder, classifier=classifier, lr=1e-4)

In [6]:
logger = WandbLogger(project="pretrain-shim-embedder-classifier", save_dir="logs")
trainer = pl.Trainer(
    max_epochs=1,
    logger=logger,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.fit(model, datamodule)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvalterschutz[0m ([33mvalterschutz-chalmers-university-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type          | Params | Mode 
-----------------------------------------------------
0 | embedder   | ShimEmbedder  | 8.8 K  | train
1 | classifier | MLPClassifier | 1.9 K  | train
-----------------------------------------------------
10.7 K    Trainable params
0         Non-trainable params
10.7 K    Total params
0.043     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/valter/Documents/Projects/afa-rl/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


                                                                           

/home/valter/Documents/Projects/afa-rl/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=21` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 625/625 [00:12<00:00, 51.41it/s, v_num=ysyo]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 625/625 [00:12<00:00, 51.35it/s, v_num=ysyo]


In [7]:
# Do some predictions
test_dataset = CubeDataset(n_features=n_features, data_points=10, sigma=0.01, seed=42)
X_test, y_test = test_dataset[:]
_, y_pred = model(X_test, torch.ones_like(X_test, dtype=torch.bool))
print(y_test.argmax(dim=1))
print(y_pred.argmax(dim=1))

tensor([2, 4, 0, 1, 4, 6, 5, 6, 7, 7])
tensor([0, 0, 0, 0, 6, 7, 7, 7, 7, 7])
