In [None]:
import sys
import os
import torch

# Make modules that live directly under ./common_lib (resolved relative to the
# current working directory) importable by their bare names.
# NOTE(review): append() puts this directory *last* on sys.path, so an installed
# package with the same top-level name (e.g. Hugging Face `datasets`) would win
# over a local one — confirm where `datasets`/`trainers` live; sys.path.insert(0, ...)
# may be what is intended.
sys.path.append('./common_lib')

def set_up_seeds(config):
    """Seed the random number generators used by an experiment with ``config.seed``.

    ``torch.manual_seed`` covers torch's generators. Python's ``random`` module is
    seeded too, in case dataset or helper code draws from it. Otherwise such runs
    would not be reproducible even with a fixed seed.
    """
    import random

    random.seed(config.seed)
    torch.manual_seed(config.seed)

WANDB_ENTITY = "www-vickyzeu"

# Sweep definition. The i-th train length spec is paired with the i-th eval length
# spec; every pair is run for every model listed below.

# Length-generalisation experiments: train on length 32, evaluate on much longer sequences.
train_seq_lens = ["32,32"] * 3
eval_seq_lens = [f"{n},{n}" for n in (64, 128, 256)]

# Alternative sweep: evaluate on sequences 50% longer than those used for training.
# train_seq_lens = [f"{x},{x}" for x in [2 ** i for i in range(5, 9)]]
# eval_seq_lens = [f"{x},{x}" for x in [2 ** i + 2 ** (i - 1) for i in range(5, 9)]]

# Architectures to sweep over; each name must be handled by construct_model.
models = ["gpt", "lstm", "qlstm", "lin_transformer", "lru", "delta_net", "mamba"]

# Task to train on: one of "bit_parity", "dyck", "mqar".
DATASET = "bit_parity"

In [None]:
from munch import Munch  # Munch is a dictionary that supports attribute-style access

# Names of the presets that load_config knows how to build.
config_names = ["mini"]


def add_exp_name(config):
    """Set ``config.exp_name``, the log-folder name that identifies an experiment.

    The name encodes, in order:
      * model, block length and dataset;
      * sequence lengths;
      * model sizes (hidden/MLP dims, heads, head dim, layers);
      * seed;
      * an optional ``_<comment>`` suffix and an optional ``_debug`` suffix.

    Mutates ``config`` in place and returns None.

    The sweep varies ``train_seq_len``/``eval_seq_len`` and leaves ``seq_len`` at
    its default, so ``seq_len`` alone does not tell runs apart. Those lengths are
    therefore included when the config has them. Without them, runs that differ
    only in length would share a log folder.
    """
    c = config

    seq_part = f"sl{c.seq_len}"
    train_sl = getattr(c, "train_seq_len", None)
    eval_sl = getattr(c, "eval_seq_len", None)
    if train_sl is not None and eval_sl is not None:
        seq_part += f"_tsl{train_sl}_esl{eval_sl}"

    # Each component carries exactly one leading separator. The previous format
    # string produced a double underscore before the block length ("model__bl8").
    c.exp_name = (
        f"{c.model}_bl{c.block_length}_{c.dataset}_{seq_part}"
        f"_h{c.h_dim}_ff{c.mlp_dim}_nH{c.n_heads}_dH{c.head_dim}"
        f"_nl{c.n_layers}_seed{c.seed}"
        + (f"_{c.comment}" if c.comment else "")
        + ("_debug" if c.debug else "")
    )


## Add experiment configs
def load_config(name=None):
    """Return the configuration for preset ``name`` as a Munch (attribute-access dict).

    A falsy ``name`` (e.g. None or "") or "default" selects "mini", the only
    preset. The result carries the task-specific fields of all three datasets
    (dyck, mqar, bit_parity).
    ``construct_dataset`` reads only the fields for ``dataset`` and may update
    num_input_classes / output_size / max_seq_len. The sweep in ``train``
    overrides model, sequence lengths, max_seq_len and block_length per run.

    Raises:
        ValueError: if ``name`` is not a known preset.
    """

    c = Munch(
        # # data
        relative_log_path="logs",  # Relative path to the log folder within the project folder
        debug=False,  # simply adds a "_debug" suffix so logs are easily distinguishable
        # # optimiser
        seed=41,
        # gradient_accumulation_steps = 1,    # number of batches before doing a gradient step
        train_batch_size=32,  # make sure batch sizes are an integer multiple of the number of workers
        eval_batch_size=32,
        test_batch_size=32,
        # seq_len = 512,
        # max_eval_steps = 512,
        # max_train_steps = 500_000,          # total number of training steps
        # decay_steps = 500_000,              # number of steps over which we will decay the learning rate
        # max_lr = 0.0006,                    # starting learning rate
        # min_lr = 0.000006,                  # final learning rate
        # grad_clip_norm = 0.0,               # gradient norm clipping
        # tokens_per_second = 0,              # tokens per second throughput of this config on the hardware run; used for logging over gpuhours
        # # perform certain tasks every N steps
        # eval_every = 1_000,                 # perform a fast evaluation (validation data)
        # test_every = -1,                    # perform a thorough evaluation (test data)
        # log_terminal_every = 100,           # print the current loss to terminal
        # log_metrics_every = 100,            # log accuracy and loss metrics
        # log_grads_every = 1_000,            # log gradients and step sizes
        # log_activations_every = -1,         # log gradients and step sizes
        log_ckpt_every=1_000,  # save model checkpoint to disk
        # logging
        comment="",
        logger_type="wandb",  # can be 'tb', 'wandb' or 'all'
        wandb_project_name="associative_rnns",
        dataset=DATASET, #["bit_parity", "dyck", "mqar"]
        project_name="associative_rnns",
        project_path="./projects/lstm/", # Hack, just used to save the source code
        exp_name="",
        max_steps = 40000,
        use_flash = False,
        device = "cuda"  # run_wrapper replaces this with a torch.device after checking CUDA availability
    )
    # default model
    if not name or name == "default":
        name = "mini"

    # model
    if name == "mini":
        c.n_layers = 2
        c.h_dim = 4
        c.mlp_dim = 8
        c.head_dim = 4
        c.n_heads = 4
        c.block_length = 8 # keep it equal to seq len for faster convergence (train() resets it to min(8, max_seq_len))

        # Mamba
        c.d_model = 128
        c.d_state = 8
        c.d_conv = 3
        c.expand = 2
        c.dt_rank = 1

        # Dataset config (construct_dataset overrides both for mqar and num_input_classes for dyck)
        c.output_size = 2
        c.num_input_classes = 2

        # Dyck specific
        c.depth = 6
        c.num_parentheses = 3
        c.seq_len = 8

        # MQAR specific
        c.n_keys = 3
        c.n_values = 6
        c.train_num_pairs = "3,3"
        c.eval_num_pairs = "3,3"
        c.max_num_pairs = 3
        c.unique_keys = True
        c.all_queries_for_input = False

        # Bit parity specific
        c.train_seq_len = "8,8"
        c.eval_seq_len = "8,8"
        # NOTE(review): 12 ones cannot fit in the default 8-long sequences below; this is
        # only consistent once the sweep raises the lengths (>= 32) — confirm intended.
        c.num_ones = 12 # Fixed number of 1s in the training sequences. Use None to disable this
        c.max_seq_len = 8
    else:
        raise ValueError(f"Config name {name} is an invalid name. ")

    return c

In [None]:
from datasets.bit_parity_dataset import BitParityDatasetIterator
from datasets.dyck_dataset import DyckDatasetIterator
from datasets.mqar_dataset import MQARDatasetIterator

def construct_dataset(config):
    """Build the (train, eval) batch iterators for ``config.dataset``.

    The train and eval iterators differ only in two settings:
      * batch size;
      * length spec (``*_seq_len`` for bit_parity/dyck, ``*_num_pairs`` for mqar).
    Everything else is shared between them.

    For dyck and mqar, size fields on ``config`` are updated in place after the
    iterators are built; mqar may also raise ``config.max_seq_len``.

    Raises:
        RuntimeError: if ``config.dataset`` has no iterator configured here.
    """
    dataset = config.dataset

    if dataset == "bit_parity":
        shared = dict(num_ones=config.num_ones, device=config.device)
        train_ds = BitParityDatasetIterator(
            batch_size=config.train_batch_size, sequence_length=config.train_seq_len, **shared
        )
        eval_ds = BitParityDatasetIterator(
            batch_size=config.eval_batch_size, sequence_length=config.eval_seq_len, **shared
        )
        return train_ds, eval_ds

    if dataset == "dyck":
        shared = dict(
            device=config.device,
            depth=config.depth,
            num_parentheses=config.num_parentheses,
        )
        train_ds = DyckDatasetIterator(
            batch_size=config.train_batch_size, sequence_length=config.train_seq_len, **shared
        )
        eval_ds = DyckDatasetIterator(
            batch_size=config.eval_batch_size, sequence_length=config.eval_seq_len, **shared
        )
        # Two tokens per parenthesis type plus two more; presumably special tokens — confirm.
        config.num_input_classes = config.num_parentheses * 2 + 2
        return train_ds, eval_ds

    if dataset == "mqar":
        shared = dict(
            n_keys=config.n_keys,
            n_values=config.n_values,
            unique_keys=config.unique_keys,
            all_queries_for_input=config.all_queries_for_input,
            device=config.device,
        )
        train_ds = MQARDatasetIterator(
            batch_size=config.train_batch_size, num_pairs=config.train_num_pairs, **shared
        )
        eval_ds = MQARDatasetIterator(
            batch_size=config.eval_batch_size, num_pairs=config.eval_num_pairs, **shared
        )
        config.num_input_classes = max(config.n_keys, config.n_values + 1) + 1
        config.output_size = config.n_values + 1
        config.max_seq_len = max(config.max_num_pairs * 3, config.max_seq_len)
        return train_ds, eval_ds

    raise RuntimeError(
        f"Dataset {config.dataset} not supported. Please add the configuration for this dataset."
    )

In [None]:
from trainers.bit_parity_trainer import BitParityTrainer
from trainers.dyck_trainer import DyckTrainer
from trainers.mqar_trainer import MQARTrainer

def construct_trainerClass(config):
    """Return the trainer class (not an instance) registered for ``config.dataset``.

    Raises:
        RuntimeError: if no trainer is registered for the dataset.
    """
    registry = (
        ("bit_parity", BitParityTrainer),
        ("dyck", DyckTrainer),
        ("mqar", MQARTrainer),
    )
    for dataset_name, trainer_cls in registry:
        if config.dataset == dataset_name:
            return trainer_cls

    raise RuntimeError(
        f"Dataset {config.dataset} not supported. Please add the configuration for this dataset."
    )

In [None]:
from projects.gpt.modelgpt import ModelGPT
from projects.lstm.modellstm import ModelLSTM
from projects.qlstm.modelqlstm import ModelQLSTM
from projects.linearTransformer.modelLinTransformer import ModelLinTransformer
from projects.lru.modellru import ModelLRU
from projects.deltanet.modelDeltanet import ModelDeltaNet
from projects.mamba.modelmamba import ModelMamba

def construct_model(config):
    """Instantiate the architecture named by ``config.model``.

    Every model class is constructed as ``ModelX(config=config)``.

    Raises:
        RuntimeError: if ``config.model`` is not a known architecture name.
    """
    model_classes = {
        "gpt": ModelGPT,
        "lstm": ModelLSTM,
        "qlstm": ModelQLSTM,
        "lin_transformer": ModelLinTransformer,
        "lru": ModelLRU,
        "delta_net": ModelDeltaNet,
        "mamba": ModelMamba,
    }
    model_cls = model_classes.get(config.model)
    if model_cls is None:
        # The old message said "for this dataset" (copy-paste from the dataset
        # factories); it is the model that is missing here.
        raise RuntimeError(
            f"Model {config.model} not supported. Please add the configuration for this model. "
            f"Supported models: {', '.join(model_classes)}."
        )
    return model_cls(config=config)

In [None]:
def run(config):
    """Train one model on one dataset as described by ``config``.

    Steps:
      1. Seed the RNGs.
      2. Build the train/eval iterators; this may update size fields on ``config``.
      3. Pick the dataset's trainer class.
      4. Build the model on ``config.device``.
      5. Create an Adam optimiser and the experiment logger.
      6. Run ``trainer.train()``.

    Adam hyperparameters can be overridden with the optional config keys ``lr``,
    ``adam_betas`` and ``adam_eps``. When they are absent, the previously
    hard-coded values (1e-3, (0.9, 0.95), 1e-8) are used.
    """
    set_up_seeds(config)
    # Datasets are built before the model: construct_dataset may rewrite
    # num_input_classes / output_size / max_seq_len, which the model reads.
    train_ds, eval_ds = construct_dataset(config)
    trainer_cls = construct_trainerClass(config)
    model = construct_model(config).to(config.device)

    opt = torch.optim.Adam(
        model.parameters(),
        lr=getattr(config, "lr", 0.001),
        betas=tuple(getattr(config, "adam_betas", (0.9, 0.95))),
        eps=getattr(config, "adam_eps", 1e-08),
    )
    logger = experiment_utils.setup_experiment(config)

    trainer = trainer_cls(
        config=config,
        model=model,
        train_loader=train_ds,
        eval_loader=eval_ds,
        optimizer=opt,
        device=config.device,
        logger=logger,
    )

    trainer.train()

In [None]:
def get_config_run_name(config):
    """Return a human-readable W&B run name: ``<model>_<dataset>_<task details>``.

    The task details depend on the dataset:
      * dyck: parenthesis types, depth, and train/eval lengths;
      * bit_parity: optional fixed number of ones, then train/eval lengths;
      * mqar: key/value vocabulary sizes and train/eval pair counts.

    Raises:
        RuntimeError: if the dataset has no naming pattern.
    """
    kind = config.dataset

    if kind == "dyck":
        detail = (
            f"par_{config.num_parentheses}_depth_{config.depth}"
            f"_{config.train_seq_len}_{config.eval_seq_len}"
        )
    elif kind == "bit_parity":
        detail = f"ones_{config.num_ones}_" if config.num_ones else ""
        detail += f"{config.train_seq_len}_{config.eval_seq_len}"
    elif kind == "mqar":
        detail = (
            f"keys_{config.n_keys}_values_{config.n_values}"
            f"_pairs_{config.train_num_pairs}_{config.eval_num_pairs}"
        )
    else:
        raise RuntimeError(
            f"No run name pattern found for {config.dataset}, please add one."
        )

    return f"{config.model}_{kind}_{detail}"


In [None]:
def run_wrapper(config):
    """Prepare ``config`` inside a submitted job and run the experiment.

    Steps:
      1. Resolve the device, falling back to CPU when CUDA is missing.
      2. Set the per-model ``project_path``.
      3. Derive ``exp_name`` and ``run_name``.
      4. Log into Weights & Biases with WANDB_API_KEY from the environment/.env.
      5. Run the experiment.

    The W&B run is finished even if training raises. Returns whatever ``run``
    returns.
    """
    # Re-done here because this function is what executes inside each submitted
    # job; presumably the notebook's top-level setup is not replayed there.
    import sys, os
    if './common_lib' not in sys.path:  # avoid growing sys.path on repeated calls
        sys.path.append('./common_lib')

    if torch.cuda.is_available():
        config.device = torch.device("cuda")
    else:
        # `mprint` is not defined anywhere in this notebook, so the old call raised
        # NameError exactly on CPU-only nodes; plain print keeps the warning.
        print("Cuda is not available, using CPU instead.")
        config.device = torch.device("cpu")
    config.project_path = os.path.join(os.getcwd(), 'projects', config.model)
    add_exp_name(config)
    print(config.dataset, config.model, config.train_seq_len, config.eval_seq_len)
    config.run_name = get_config_run_name(config)

    from dotenv import load_dotenv
    load_dotenv()
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    wandb.init(project=config.project_name, entity=WANDB_ENTITY, config=config, name=config.run_name)
    try:
        return run(config)
    finally:
        # Finalise the W&B run even when training fails.
        wandb.finish()

In [None]:
import submitit
import wandb
from common_lib import experiment_utils

def _sweep_configs(base_config):
    """Yield an independent copy of ``base_config`` per (length pair, model) combination."""
    for train_seq_len, eval_seq_len in zip(train_seq_lens, eval_seq_lens):
        # The last comma-separated value of each length spec is taken as its maximum.
        max_seq_len = max(int(train_seq_len.split(',')[-1]), int(eval_seq_len.split(',')[-1]))
        for model in models:
            # Shallow copy (same mapping type) so per-run edits — here and inside
            # run()/run_wrapper — never leak into other runs or pending jobs.
            run_config = type(base_config)(base_config)
            run_config.train_seq_len = train_seq_len
            run_config.eval_seq_len = eval_seq_len
            run_config.max_seq_len = max_seq_len
            run_config.model = model
            run_config.block_length = min(8, max_seq_len)
            yield run_config


def train(local=False):
    """Run the full sweep: every length pair in train/eval_seq_lens times every model.

    Args:
        local: If True, run each configuration sequentially in this process, each
            logged as its own W&B run. Otherwise, submit one submitit job per
            configuration (executed by ``run_wrapper``) and return once all jobs
            are submitted.
    """
    config = load_config()
    if local:
        from dotenv import load_dotenv
        load_dotenv()  # expose WANDB_API_KEY etc. from .env to wandb

        if not torch.cuda.is_available():
            # The preset's default device is "cuda"; fall back so local runs work on CPU-only machines.
            print("Cuda is not available, using CPU instead.")
            config.device = "cpu"

        for run_config in _sweep_configs(config):
            run_config.exp_name = (
                f"{run_config.model}_sl{run_config.train_seq_len}"
                f"_esl{run_config.eval_seq_len}_bl{run_config.block_length}"
            )
            run_config.run_name = get_config_run_name(run_config)
            wandb.init(project=run_config.project_name, entity=WANDB_ENTITY, config=run_config, name=run_config.run_name)
            try:
                run(run_config)
            finally:
                # Finalise this W&B run even if training raised.
                wandb.finish()
    else:
        executor = submitit.AutoExecutor(folder="logs/slurm")
        executor.update_parameters(
            timeout_min=60*6,
            tasks_per_node=1,
            #cpus_per_task=4,
            account="pmlr_jobs",
            name="pmlr_training"
        )
        jobs = [executor.submit(run_wrapper, run_config) for run_config in _sweep_configs(config)]

        print(f"Submitted {len(jobs)} jobs.")

In [None]:
from dotenv import load_dotenv
import os
# Load secrets such as WANDB_API_KEY from a .env file, then run the whole sweep in
# this process. env_path is not passed on: load_dotenv() is called without a path
# and locates the .env file by itself.
env_path = os.path.join(os.curdir, ".env")
load_dotenv()
train(local=True)