In [1]:
from flash_ansr.models import SetTransformer, PreEncoder
from flash_ansr import get_path, FlashANSRDataset, ExpressionSpace
from flash_ansr.train.loss import ContrastiveLoss

from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import wandb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
N_VARIABLES = 11
OUTPUT_SIZE = 512

In [3]:
pre_encoder = PreEncoder(
    input_size=N_VARIABLES + 1,
    mode="ieee-754",
    support_nan=False,
    exponent_scale=None)

In [4]:
model_config = {
    "hidden_size": 512,
    "n_enc_isab": 5,
    "n_dec_sab": 2,
    "n_induce": 64,
    "n_heads": 8,
    "layer_norm": False,
    "n_seeds": 64,
    "input_embedding_size": pre_encoder.encoding_size,
    "input_dimension_size": pre_encoder.input_size,
    "output_embedding_size": OUTPUT_SIZE,
}

In [5]:
print(f'{SetTransformer(**model_config).n_params:,}')

13,625,856


In [6]:
class SetTransformerWrapper(nn.Module):
    def __init__(self, expression_space: ExpressionSpace, set_transformer: SetTransformer, pre_encoder: PreEncoder):
        super().__init__()
        self.expression_space = expression_space
        self.pre_encoder = pre_encoder
        self.set_transformer = set_transformer
        set_transformer_output_size = set_transformer.output_embedding_size * set_transformer.n_seeds
        self.token_head = nn.Linear(set_transformer_output_size, len(expression_space.tokenizer.vocab))
        self.complexity_head = nn.Linear(set_transformer_output_size, 1)
        self.n_constants_head = nn.Linear(set_transformer_output_size, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, M, D = x.size()
        data_pre_encodings = self.pre_encoder(x)
        data_pre_encodings = data_pre_encodings.view(B, M, D, self.pre_encoder.encoding_size)
        x = self.set_transformer(data_pre_encodings)
        B, S, D = x.size()
        x = x.view(B, S * D)
        token_logits = self.token_head(x)
        complexity = self.complexity_head(x)
        n_constants = self.n_constants_head(x)
        return token_logits, complexity, n_constants
    
    @property
    def n_params(self):
        return sum(p.numel() for p in self.parameters())

In [7]:
dataset_train = FlashANSRDataset.from_config(get_path("configs", "v7.0", "dataset_train.yaml"))

Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 48839.12it/s]
Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 37985.48it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 29066.56it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 30367.23it/s]
Compiling Skeletons: 100%|██████████| 5000/5000 [00:00<00:00, 39489.70it/s]
Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 53220.45it/s]
Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 45201.77it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 28591.03it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 30463.01it/s]


In [8]:
dataset_val = FlashANSRDataset.from_config(get_path("configs", "v7.0", "dataset_val.yaml"))

Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 53261.00it/s]
Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 42416.53it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 30218.33it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 29678.92it/s]


In [9]:
def create_targets(batch: dict, device: torch.device) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Complexity targets by counting the number of non-padding tokens in each instance
    complexity_targets = (batch["input_ids"] != 0).sum(dim=1).unsqueeze(1).float() - 2  # Subtract 2 to account for the start and end tokens

    # Number of constants targets by counting the number of '<num>' tokens in each instance
    n_constants_targets = (batch["input_ids"] == dataset_train.expression_space.tokenizer["<num>"]).sum(dim=1).unsqueeze(1).float()

    # Binary token targets that depict for each token in the vocabulary whether it is present in the instance
    token_targets = torch.zeros(batch["input_ids"].shape[0], len(dataset_train.expression_space.tokenizer.vocab), device=device)
    for i, tokens in enumerate(batch["input_ids"]):
        token_targets[i, tokens] = 1

    return token_targets, complexity_targets, n_constants_targets

In [10]:
train_config = {
    "n_per_equation": 1,
    "batch_size": 128,
    "steps": 1_000_000,
    "val_steps": 100,
    "val_every_steps": 1_000,
    "lr": 1e-4,
}

loss_fn_token = nn.BCEWithLogitsLoss()
loss_fn_complexity = nn.MSELoss()
loss_fn_n_constants = nn.MSELoss()

model = SetTransformerWrapper(
    set_transformer=SetTransformer(**model_config),
    expression_space=dataset_train.expression_space,
    pre_encoder=pre_encoder).to(device)
optimizer = optim.AdamW(model.parameters(), lr=train_config['lr'], amsgrad=True)
n_params = model.n_params
flops_per_token = 6 * n_params

cumulative_training_pflops = 0

try:
    with wandb.init(config=train_config | model_config, project="foundation_set_encoder", entity="psaegert", name=f'set-transformer-combined'):
        pbar = tqdm(total=train_config['steps'], smoothing=0)
        model.train()
        for b, batch in enumerate(dataset_train.iterate(steps=train_config['steps'], batch_size=train_config['batch_size'], n_per_equation=train_config['n_per_equation'])):
            optimizer.zero_grad()
            batch = dataset_train.collate(batch, device)

            # Pad the x_tensor with zeros to match the expected maximum input dimension of the set transformer
            pad_length = N_VARIABLES - batch["x_tensors"].shape[2]
            if pad_length > 0:
                x_tensor = nn.functional.pad(batch["x_tensors"], (0, pad_length, 0, 0, 0, 0), value=0)
            else:
                x_tensor = batch["x_tensors"]

            # Concatenate x and y tensors as input to the set transformer
            data_tensor = torch.cat([x_tensor, batch["y_tensors"]], dim=-1)

            # Targets
            token_targets, complexity_targets, n_constants_targets = create_targets(batch, device)

            # Forward pass
            logits, complexity, n_const = model(data_tensor)
            
            loss_token: torch.Tensor = loss_fn_token(logits, token_targets)
            loss_complexity: torch.Tensor = loss_fn_complexity(complexity, complexity_targets)
            loss_n_constants: torch.Tensor = loss_fn_n_constants(n_const, n_constants_targets)

            loss: torch.Tensor = loss_token + loss_complexity + loss_n_constants

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            cumulative_training_pflops += (flops_per_token * batch["x_tensors"].shape[0] * batch["x_tensors"].shape[1]) * 1e-15

            wandb.log({
                "train_loss": loss.item(),
                "train_token_loss": loss_token.item(),
                "train_complexity_loss": loss_complexity.item(),
                "train_n_constants_loss": loss_n_constants.item(),
                "cumulative_flops": cumulative_training_pflops,
            }, step=b)

            if (b + 1) % train_config["val_every_steps"] == 0:
                model.eval()
                val_losses = []
                val_token_losses = []
                val_complexity_losses = []
                val_n_constants_losses = []
                with torch.no_grad():
                    for val_batch in dataset_val.iterate(steps=train_config['val_steps'], batch_size=train_config["batch_size"], n_per_equation=train_config["n_per_equation"]):
                        val_batch = dataset_val.collate(val_batch, device)

                        pad_length = N_VARIABLES - val_batch["x_tensors"].shape[2]
                        if pad_length > 0:
                            x_tensor = nn.functional.pad(val_batch["x_tensors"], (0, pad_length, 0, 0, 0, 0), value=0)
                        else:
                            x_tensor = val_batch["x_tensors"]

                        data_tensor = torch.cat([x_tensor, val_batch["y_tensors"]], dim=-1)
                        
                        token_targets, complexity_targets, n_constants_targets = create_targets(val_batch, device)

                        logits, complexity, n_const = model(data_tensor)

                        loss_token: torch.Tensor = loss_fn_token(logits, token_targets)
                        loss_complexity: torch.Tensor = loss_fn_complexity(complexity, complexity_targets)
                        loss_n_constants: torch.Tensor = loss_fn_n_constants(n_const, n_constants_targets)

                        loss: torch.Tensor = loss_token + loss_complexity + loss_n_constants

                        val_losses.append(loss.item())
                        val_token_losses.append(loss_token.item())
                        val_complexity_losses.append(loss_complexity.item())
                        val_n_constants_losses.append(loss_n_constants.item())

                wandb.log({
                    "val_loss": np.mean(val_losses),
                    "val_token_loss": np.mean(val_token_losses),
                    "val_complexity_loss": np.mean(val_complexity_losses),
                    "val_n_constants_loss": np.mean(val_n_constants_losses),
                    }, step=b)

                model.train()
    
            pbar.update()

        pbar.close()
    
except KeyboardInterrupt:
    print("Interrupted training. Attempting to save model.")

model.save(get_path("models", "ansr-models", "set_transformer", f"v7.0", create=True))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpsaegert[0m. Use [1m`wandb login --relogin`[0m to force relogin


  4%|▎         | 35999/1000000 [2:40:39<71:42:00,  3.73it/s]Traceback (most recent call last):
  File "/tmp/ipykernel_20585/1588366544.py", line 75, in <module>
    for val_batch in dataset_val.iterate(steps=train_config['val_steps'], batch_size=train_config["batch_size"], n_per_equation=train_config["n_per_equation"]):
  File "/home/psaegert/Projects/flash-ansr/src/flash_ansr/data.py", line 315, in iterate
    yield from self.generate_batch(batch_size=batch_size, size=size, steps=steps, n_support=n_support, n_per_equation=n_per_equation, tqdm_total=tqdm_total, verbose=verbose, avoid_fragmentation=avoid_fragmentation)
  File "/home/psaegert/Projects/flash-ansr/src/flash_ansr/data.py", line 393, in generate_batch
    for instance in self.generate(
  File "/home/psaegert/Projects/flash-ansr/src/flash_ansr/data.py", line 458, in generate
    skeleton_hash, skeleton_code, skeleton_constants = self.skeleton_pool.sample_skeleton()
                                                       ^^^^^^

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
cumulative_flops,▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
train_complexity_loss,█▄▂▂▃▂▄▂▂▂▃▂▃▂▃▃▂▃▁▂▃▁▃▁▃▁▄▅▁▃▄▃▂▁▃▁▂▁▂▂
train_loss,▆▇▄▃▄▃▃▃▂▇▄▇▄▄▄▃▂▆▅▃█▄▆▁▃▂▄▂▄▇▃▂▃▄▃▄▆▄▃▄
train_n_constants_loss,▅▄▃▃▄▇▃█▅▇▂▄▇▆▇▃▇▃▃▅▆█▂▃▃▃▆▄▄▁▅▁▄▃▄▄▃▃▄▅
train_token_loss,█▅▅▅▇▄▄▄▄▂▃▃▅▄▄▄▂▃▁▅▂▁▁▂▂▂▃▄▄▅▄▄▅▃▂▁▃▃▄▃
val_complexity_loss,█▃▄▄▇▅▅▃▃▃▂▃▃▃▃▂▂▂▂▂▄▂▁▃▂▂▂▃▃▃▁▂▂▁▂
val_contrastive_loss,█▃▄▄▇▅▅▄▃▃▃▃▃▃▃▂▂▂▂▂▄▂▁▃▂▂▂▃▃▃▁▂▂▁▂
val_n_constants_loss,▇█▆█▆▅▆▇▅▅▆▄█▄▄█▃▄▄▄▅▄▁▄▃▂▄▃▃▄▁▂▄▂▃
val_token_loss,█▅▆▄▆▄▄▄▃▃▃▂▂▂▂▂▂▃▃▁▂▂▂▁▁▁▂▁▂▁▁▁▁▂▁

0,1
cumulative_flops,107.52222
train_complexity_loss,10.14577
train_loss,11.1266
train_n_constants_loss,0.76068
train_token_loss,0.22014
val_complexity_loss,10.53826
val_contrastive_loss,11.70493
val_n_constants_loss,0.94174
val_token_loss,0.22493


Interrupted training. Attempting to save model.


AttributeError: 'SetTransformerWrapper' object has no attribute 'save'

In [49]:
for b, batch in enumerate(dataset_train.iterate(steps=train_config['steps'], batch_size=train_config['batch_size'], n_per_equation=train_config['n_per_equation'])):
    batch = dataset_train.collate(batch, device)
    target_tokens, target_complexity, target_n_constants = create_targets(batch, device)
    break

In [50]:
logits, complexity, n_constants = model(data_tensor)

In [57]:
batch['input_ids'][1]

tensor([ 1, 10, 12, 15, 30,  8, 12, 11, 24, 31,  6,  6,  2,  0,  0,  0,  0,  0,
         0,  0,  0,  0], device='cuda:0')

In [61]:
for i, logit in enumerate(torch.softmax(logits[1], dim=-1)):
    print(f'{i} {logit.log10().item():.1f}')

0 -4.1
1 -0.3
2 -0.3
3 -12.7
4 -12.7
5 -12.7
6 -6.2
7 -6.2
8 -6.7
9 -7.4
10 -5.8
11 -6.3
12 -7.2
13 -7.3
14 -6.5
15 -7.0
16 -7.0
17 -7.1
18 -8.6
19 -8.9
20 -9.1
21 -9.2
22 -7.5
23 -7.6
24 -7.5
25 -8.9
26 -9.1
27 -7.7
28 -6.7
29 -8.5
30 -3.3
31 -3.1
32 -8.9


In [54]:
for target_complexity_item, predicted_complexity in zip(target_complexity[:5], complexity):
    print(f'{target_complexity_item.item()} {predicted_complexity.item()}')

15.0 12.875956535339355
11.0 10.489505767822266
6.0 14.327173233032227
14.0 14.180545806884766
10.0 14.659640312194824


In [55]:
for target_n_constants_item, predicted_n_constants in zip(target_n_constants[:5], n_constants):
    print(f'{target_n_constants_item.item()} {predicted_n_constants.item()}')

0.0 1.1526859998703003
2.0 0.9796116948127747
0.0 1.2337887287139893
2.0 0.9209513664245605
1.0 1.6110610961914062
