# Testing the encoder-only model

In this notebook, we use the deterministic k-sparse parity dataset (no label noise, i.e. a bit-flip probability of 0) to test the encoder-only Transformer.
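
`make_datasets.sparse_parity_k_n` is the project's own generator and its internals are not shown here. As a rough illustration of what a k-sparse parity dataset usually looks like, the sketch below draws $n$ uniform random bits per sample, fixes $k$ hidden positions for the whole dataset, and appends the parity (XOR) of those bits as the last token. The function name `sparse_parity_sketch`, the label-at-the-end layout, and the way `p_bitflip` flips labels are assumptions, not the project's implementation.

```python
import numpy as np

def sparse_parity_sketch(n, k, n_data, p_bitflip=0.0, seed=None):
    """Hypothetical stand-in for make_datasets.sparse_parity_k_n.

    Returns (data, subset): data is an (n_data, n + 1) array of 0/1 ints,
    n random input bits followed by the parity of the k positions in subset,
    with each label flipped independently with probability p_bitflip.
    """
    rng = np.random.default_rng(seed)
    # the k relevant positions are fixed for the whole dataset
    subset = rng.choice(n, size=k, replace=False)
    bits = rng.integers(0, 2, size=(n_data, n))
    # parity of the selected bits = their sum mod 2
    label = bits[:, subset].sum(axis=1) % 2
    # optional label noise
    flips = rng.random(n_data) < p_bitflip
    label = np.where(flips, 1 - label, label)
    return np.concatenate([bits, label[:, None]], axis=1), subset

data, subset = sparse_parity_sketch(n=40, k=4, n_data=1000, seed=334)
print(data.shape, sorted(subset.tolist()))
```

With `p_bitflip=0.0` every label is an exact function of $k$ input bits, which is presumably what "deterministic" means here.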

**Report**: 

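- For $n=4$, $k=2$, the encoder reached $100\%$ accuracy with the parameters below; the decoder also reached $100\%$ with the same parameters.
- For $n=40$, $k=4$, with the same parameters, the encoder reached only $52\%$.
- New experiments: increase the dataset size, the model size, and the number of epochs. With the configuration used in this notebook ($n=40$, $k=4$, $n_{\text{train}}=20000$), the encoder reaches $100\%$ last-token accuracy on the validation set.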
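
For context on the $52\%$ figure: the parity of $k \ge 1$ independent uniform bits is itself uniform, so the labels are balanced and always predicting one class scores about $50\%$. Each individual input bit, even one of the $k$ relevant ones, also agrees with the label only about half the time, so there is no single-bit signal for the model to latch onto. The short check below illustrates this on synthetic data generated with NumPy (not with `make_datasets`); the seed and sample count are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, n_samples = 40, 4, 20_000

bits = rng.integers(0, 2, size=(n_samples, n))
subset = rng.choice(n, size=k, replace=False)   # hidden relevant positions
labels = bits[:, subset].sum(axis=1) % 2        # k-sparse parity labels

# Accuracy of always predicting the more frequent label
frac_ones = labels.mean()
majority_baseline = max(frac_ones, 1 - frac_ones)

# How often each single input bit equals the label (one value per position)
agreement = (bits == labels[:, None]).mean(axis=0)

print(f"majority baseline: {majority_baseline:.3f}")
print(f"per-bit agreement range: {agreement.min():.3f} to {agreement.max():.3f}")
```

Both numbers stay close to $0.5$, so an accuracy near $50\%$ is what a model that has not found the $k$ relevant positions would be expected to score.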



In [1]:
import torch
from ray import tune

from mindreadingautobots.sequence_generators import make_datasets
from mindreadingautobots.models import decoder_transformer, hyperparameters

# DATA LOADING
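seed = 334
torch.manual_seed(seed)  # seeds the torch RNG only; make_datasets below is not given this seed
n_train = 20000
n_data = int(n_train * 5/4)  # downstream there is an 80/20 train/val split: 25000 -> 20000 train / 5000 val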
n = 40
k = 4
p_bitflip = 0.0
raw_data = make_datasets.sparse_parity_k_n(n, k, n_data, p_bitflip)

config = {"epochs": 40,
        "batch_size": 32,
        # prefer CUDA (NVIDIA GPUs), then MPS (Apple Silicon), otherwise fall back to the CPU
        "device": torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"),
        "lr": 1e-3,
        "context_size": 500,
        "vocab_size": 2,
        "n_layer": 4,
        "n_head": 4,
        "d_model": 16,
        "dropout": 0.0,
        "d_ff": 128,
        "activation": "relu",
        "standard_positional_encoding": False,
        "loss_type": "cross_entropy",
        "bias": True,
        "tie_weights": False,
        "embedding": "embedding",
        "mode": "encoder"
        }


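To reason about increasing the architecture, it helps to know roughly how many parameters this configuration has. The function below is a back-of-the-envelope count for a generic Transformer stack, not the exact count for `decoder_transformer`: it assumes learned token and positional embeddings (a table of `context_size` positions), four $d_{\text{model}} \times d_{\text{model}}$ attention projections and a two-layer feed-forward network per block, two LayerNorms per block plus a final one, and an untied output head, all with biases when `bias=True`. The number of heads does not change the total because the heads split $d_{\text{model}}$.

```python
def approx_param_count(n_layer, d_model, d_ff, vocab_size, context_size, bias=True):
    """Rough parameter count for a generic Transformer stack (assumed layout)."""
    b = 1 if bias else 0
    # Q, K, V and output projections, each d_model x d_model (+ bias)
    attn = 4 * (d_model * d_model + b * d_model)
    # two-layer feed-forward network: d_model -> d_ff -> d_model
    ffn = (d_model * d_ff + b * d_ff) + (d_ff * d_model + b * d_model)
    # two LayerNorms per block: weight (+ bias)
    norms = 2 * (d_model + b * d_model)
    per_block = attn + ffn + norms
    # learned token embeddings and an assumed learned positional table
    embeddings = vocab_size * d_model + context_size * d_model
    final_norm = d_model + b * d_model
    # untied output head projecting back to the vocabulary
    head = d_model * vocab_size + b * vocab_size
    return n_layer * per_block + embeddings + final_norm + head

print(approx_param_count(n_layer=4, d_model=16, d_ff=128, vocab_size=2, context_size=500))  # → 29666
```

Under these assumptions a learned positional table would account for $500 \times 16 = 8000$ of those parameters even though each input is only $n = 40$ tokens long. The exact figure can be read off the trained `model` with `sum(p.numel() for p in model.parameters())`.
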

In [2]:
model, train_dataloader, val_dataloader = decoder_transformer.train_loop(config, raw_data, verbose=True, 
                                                                         return_model=True, return_data=True)



Epoch 0, train loss: 0.17602567699849606, val loss: 0.02089956253292454
Epoch 1, train loss: 0.01883792527318001, val loss: 0.018057114652292743
Epoch 2, train loss: 0.017991875100135803, val loss: 0.01799545001689416
Epoch 3, train loss: 0.01791250951886177, val loss: 0.017914207533571373
Epoch 4, train loss: 0.020131711181998253, val loss: 0.01804883519460441
Epoch 5, train loss: 0.017911969393491746, val loss: 0.017978901518093553
Epoch 6, train loss: 0.017857039558887483, val loss: 0.017970495270031275
Epoch 7, train loss: 0.01785421078503132, val loss: 0.01788157969713211
Epoch 8, train loss: 0.017914129328727723, val loss: 0.017958788141892973
Epoch 9, train loss: 0.01784638808965683, val loss: 0.018296154429483565
Epoch 10, train loss: 0.018843798154592514, val loss: 0.01786307674969078
Epoch 11, train loss: 0.017843603318929674, val loss: 0.017902180825354188
Epoch 12, train loss: 0.017826000770926477, val loss: 0.017802598951443746
Epoch 13, train loss: 0.01784628832936287, va

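`train_loop` handles batching and the train/validation split internally. A minimal sketch of the same kind of preparation, assuming each sample is an $(n+1)$-token sequence whose last token is the label and that inputs and targets are the sample shifted by one position, could look like this. `make_next_token_splits` is a hypothetical helper, and `random_split` is only one way to do an 80/20 split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

def make_next_token_splits(samples, train_frac=0.8, seed=0):
    """Hypothetical helper: next-token (X, y) pairs plus a train/val split.

    samples: integer tensor of shape (N, n + 1). X drops the last token and
    y drops the first, so y[:, -1] is the label the model must predict
    from the full input X.
    """
    X = samples[:, :-1]
    y = samples[:, 1:]
    dataset = TensorDataset(X, y)
    n_train = int(train_frac * len(dataset))
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    return random_split(dataset, [n_train, len(dataset) - n_train], generator=generator)

samples = torch.randint(0, 2, (25_000, 41))   # placeholder bits: n = 40 inputs plus one label token
train_ds, val_ds = make_next_token_splits(samples)
print(len(train_ds), len(val_ds))  # → 20000 5000
```

With `n_data = int(20000 * 5/4) = 25000` samples, an 80/20 split gives 20,000 training and 5,000 validation samples, which matches the validation set size printed at the end of this notebook.
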
In [3]:
def generate_next_token(model, token_seq, config, max_new_tokens=1):
    """Greedily append max_new_tokens tokens to token_seq (shape: batch x seq_len)."""
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # if the sequence grows longer than context_size, keep only the last context_size tokens
            token_seq_cond = token_seq if token_seq.size(1) <= config["context_size"] else token_seq[:, -config["context_size"]:]
            # forward the model; logits has shape (batch, seq_len or 1, vocab_size)
            logits, _ = model(token_seq_cond)
            # greedy decoding: pick the most likely token at the last position, shape (batch, 1)
            token_seq_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            # append the predicted token to the running sequence and continue
            token_seq = torch.cat((token_seq, token_seq_next), dim=1)

    return token_seq
    

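`generate_next_token` performs greedy decoding: it takes the argmax over the logits instead of sampling from them. If the model returns logits for every position, with shape `(batch, seq_len, vocab_size)`, only the last position should be used to pick the next token. The sketch below uses a hypothetical `ToyModel` (an embedding followed by a linear layer, standing in for the trained Transformer and returning `(logits, loss)` in the same way) to show the shapes involved.

```python
import torch

class ToyModel(torch.nn.Module):
    """Hypothetical stand-in returning (logits, None) with logits of shape (batch, seq_len, vocab)."""
    def __init__(self, vocab_size=2, d=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, d)
        self.head = torch.nn.Linear(d, vocab_size)

    def forward(self, idx):
        return self.head(self.emb(idx)), None

torch.manual_seed(0)
model = ToyModel()
tokens = torch.randint(0, 2, (3, 40))   # batch of 3 sequences, 40 tokens each
with torch.no_grad():
    logits, _ = model(tokens)           # (3, 40, 2)

# Greedy choice for the next token: keep only the last position's logits
next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (3, 1)
extended = torch.cat((tokens, next_tok), dim=1)            # (3, 41)

# Taking the argmax over every position instead gives one token per input position
all_positions = logits.argmax(dim=-1)                      # (3, 40)
print(extended.shape, all_positions.shape)  # → torch.Size([3, 41]) torch.Size([3, 40])
```

Selecting `logits[:, -1, :]` with `keepdim=True` yields one token per sequence, so the concatenated sequence grows by exactly one position.
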
In [4]:
correct, total = 0, 0
for batch in val_dataloader:
    X, y = batch
    X = X.to(config["device"])
    y = y.to(config["device"])
    y_pred = generate_next_token(model, X, config, max_new_tokens=1)
    correct += (y_pred[:,-1] == y[:,-1]).sum().item()
    total += y.shape[0]
    
print("Number of samples on the validation dataset:", len(val_dataloader.dataset))
print(f"Last token accuracy: {round(100*correct/total, 2)} %")

Number of samples on the validation dataset: 5000
Last token accuracy: 100.0 %
