In [1]:
import sys
import os

# Put ./common_lib on the module search path. The entry is relative, so it is
# resolved against the current working directory of the kernel.
sys.path.append('./common_lib')

In [2]:
from munch import Munch  # Munch is a dictionary that supports attribute-style access

# Names of the model-size presets. Only "mini" is enabled; load_config() below
# defines no preset for the commented-out sizes.
# NOTE(review): config_names is not referenced by any other visible cell.
config_names = [
    "mini",
    # "tiny",
    # 'small',
    # 'medium',
    # 'large',
    # 'XL',
]


def add_exp_name(config):
    """Store a descriptive experiment name on ``config.exp_name``.

    The name is assembled from the model, block length, dataset, sequence
    length, layer sizes and seed, followed by an optional ``_<comment>`` and an
    optional ``_debug`` suffix. The block-length piece already starts with an
    underscore, so the result has a double underscore after the model name
    (e.g. ``lstm__bl8_dyck_...``).

    Args:
        config: Object with attribute access providing ``model``,
            ``block_length``, ``dataset``, ``seq_len``, ``h_dim``, ``mlp_dim``,
            ``n_heads``, ``head_dim``, ``n_layers``, ``seed``, ``comment`` and
            ``debug``. It is modified in place.

    Returns:
        None.
    """
    block_tag = f"_bl{config.block_length}"
    comment_tag = f"_{config.comment}" if config.comment else ""
    debug_tag = "_debug" if config.debug else ""

    template = "{}_{}_{}_sl{}_h{}_ff{}_nH{}_dH{}_nl{}_seed{}{}{}"
    fields = (
        config.model,
        block_tag,
        config.dataset,
        config.seq_len,
        config.h_dim,
        config.mlp_dim,
        config.n_heads,
        config.head_dim,
        config.n_layers,
        config.seed,
        comment_tag,
        debug_tag,
    )
    config.exp_name = template.format(*fields)


## Add experiment configs
def load_config(name=None):
    """Return the Munch of settings for the preset called ``name``.

    A falsy ``name`` (such as ``None``) and the alias ``"default"`` both select
    the ``"mini"`` preset, which is the only preset defined here.

    Args:
        name: Preset name.

    Returns:
        A Munch holding the shared defaults followed by the preset's model
        sizes and per-dataset settings.

    Raises:
        ValueError: If ``name`` resolves to anything other than ``"mini"``.
    """
    # default model
    name = "mini" if (not name or name == "default") else name
    if name != "mini":
        raise ValueError(f"Config name {name} is an invalid name. ")

    cfg = Munch(
        # # data
        # data_root = "data/books",
        relative_log_path="logs",  # Relative path to the log folder within the project folder
        # dataset = "books_16384",
        # vocab_size = 16384,
        debug=False,  # simply adds a "_debug" suffix so logs are easily distinguishable
        # # optimiser
        seed=41,
        # gradient_accumulation_steps = 1,    # number of batches before doing a gradient step
        train_batch_size=32,  # make sure batch sizes are an integer multiple of the number of workers
        eval_batch_size=32,
        test_batch_size=32,
        # seq_len = 512,
        # max_eval_steps = 512,
        # max_train_steps = 500_000,          # total number of training steps
        # decay_steps = 500_000,              # number of steps over which we will decay the learning rate
        # max_lr = 0.0006,                    # starting learning rate
        # min_lr = 0.000006,                  # final learning rate
        # grad_clip_norm = 0.0,               # gradient norm clipping
        # tokens_per_second = 0,              # tokens per second throughput of this config on the hardware run; used for logging over gpuhours
        # # perform certain tasks every N steps
        # eval_every = 1_000,                 # perform a fast evaluation (validation data)
        # test_every = -1,                    # perform a thorough evaluation (test data)
        # log_terminal_every = 100,           # print the current loss to terminal
        # log_metrics_every = 100,            # log accuracy and loss metrics
        # log_grads_every = 1_000,            # log gradients and step sizes
        # log_activations_every = -1,         # log gradients and step sizes
        log_ckpt_every=1_000,  # save model checkpoint to disk
        # logging
        comment="",
        logger_type="wandb",  # can be 'tb', 'wandb' or 'all'
        wandb_project_name="qlstm",
        dataset="dyck",  # ["bit_parity", "dyck", "mqar"]
        model="lstm",  # ["lstm", "qlstm", "gpt"]
        project_name="modern_rnns",
        max_steps=40000,
        use_flash=False,
    )

    # "mini" model sizes
    cfg.update(
        n_layers=2,
        h_dim=4,
        mlp_dim=8,
        head_dim=4,
        n_heads=4,
        block_length=8,  # keep it equal to seq len for faster convergence
    )

    # Dataset config
    cfg.update(output_size=2, num_input_classes=2)

    # Dyck specific
    cfg.update(depth=6, num_parentheses=3, seq_len=8)

    # MQAR specific
    cfg.update(
        n_keys=3,
        n_values=6,
        train_num_pairs="3,3",
        eval_num_pairs="3,3",
        max_num_pairs=3,
        unique_keys=True,
        all_queries_for_input=False,
    )

    # Bit parity specific
    cfg.update(train_seq_len="8,8", eval_seq_len="8,8", max_seq_len=8)

    return cfg

In [3]:
from bit_parity_dataset import BitParityDatasetIterator
from dyck_dataset import DyckDatasetIterator
from mqar_dataset import MQARDatasetIterator

def construct_dataset(config):
    """Build the training and evaluation dataset iterators for ``config.dataset``.

    Supported names are ``"bit_parity"``, ``"dyck"`` and ``"mqar"``. The train
    and eval iterators of a dataset differ only in batch size and in the
    sequence-length (or number-of-pairs) setting.

    Side effects on ``config`` after the iterators are created:
      * ``"dyck"``: ``num_input_classes`` becomes ``num_parentheses * 2 + 2``.
      * ``"mqar"``: ``num_input_classes`` becomes
        ``max(n_keys, n_values + 1) + 1``, ``output_size`` becomes
        ``n_values + 1`` and ``max_seq_len`` is raised to at least
        ``max_num_pairs * 3``.

    Args:
        config: Configuration object with attribute access; must already carry
            ``device`` and the fields of the selected dataset.

    Returns:
        Tuple ``(train_ds, eval_ds)``.

    Raises:
        RuntimeError: If ``config.dataset`` is not one of the supported names.
    """
    if config.dataset not in ("bit_parity", "dyck", "mqar"):
        raise RuntimeError(
            f"Dataset {config.dataset} not supported. Please add the configuration for this dataset."
        )

    if config.dataset == "bit_parity":
        # Arguments common to the train and eval iterators.
        common = dict(
            pad_sequence_length=config.max_seq_len,
            device=config.device,
        )
        train_ds = BitParityDatasetIterator(
            batch_size=config.train_batch_size,
            sequence_length=config.train_seq_len,
            **common,
        )
        eval_ds = BitParityDatasetIterator(
            batch_size=config.eval_batch_size,
            sequence_length=config.eval_seq_len,
            **common,
        )
    elif config.dataset == "dyck":
        common = dict(
            pad_sequence_length=config.max_seq_len,
            device=config.device,
            depth=config.depth,
            num_parentheses=config.num_parentheses,
        )
        train_ds = DyckDatasetIterator(
            batch_size=config.train_batch_size,
            sequence_length=config.train_seq_len,
            **common,
        )
        eval_ds = DyckDatasetIterator(
            batch_size=config.eval_batch_size,
            sequence_length=config.eval_seq_len,
            **common,
        )
        config.num_input_classes = config.num_parentheses * 2 + 2
    else:  # "mqar"
        common = dict(
            n_keys=config.n_keys,
            n_values=config.n_values,
            pad_num_pairs=config.max_num_pairs,
            unique_keys=config.unique_keys,
            all_queries_for_input=config.all_queries_for_input,
            device=config.device,
        )
        train_ds = MQARDatasetIterator(
            batch_size=config.train_batch_size,
            num_pairs=config.train_num_pairs,
            **common,
        )
        eval_ds = MQARDatasetIterator(
            batch_size=config.eval_batch_size,
            num_pairs=config.eval_num_pairs,
            **common,
        )
        config.num_input_classes = max(config.n_keys, config.n_values + 1) + 1
        config.output_size = config.n_values + 1
        config.max_seq_len = max(config.max_num_pairs * 3, config.max_seq_len)

    return train_ds, eval_ds

In [4]:
from trainers.bit_parity_trainer import BitParityTrainer
from trainers.dyck_trainer import DyckTrainer
from trainers.mqar_trainer import MQARTrainer

def construct_trainerClass(config):
    """Return the trainer class (not an instance) that matches ``config.dataset``.

    Args:
        config: Configuration object with attribute access; only
            ``config.dataset`` is read.

    Returns:
        ``BitParityTrainer`` for ``"bit_parity"``, ``DyckTrainer`` for
        ``"dyck"`` and ``MQARTrainer`` for ``"mqar"``.

    Raises:
        RuntimeError: If ``config.dataset`` is none of the names above.
    """
    if config.dataset == "bit_parity":
        return BitParityTrainer
    if config.dataset == "dyck":
        return DyckTrainer
    if config.dataset == "mqar":
        # This branch used to assign to `trainer.dataset` while `trainer` was
        # still unbound, which raised UnboundLocalError for every MQAR run.
        return MQARTrainer

    raise RuntimeError(
        f"Dataset {config.dataset} not supported. Please add the configuration for this dataset."
    )

In [5]:
from projects.gpt.modelgpt import ModelGPT
from projects.lstm.basic_lstm_model import ModelLSTM
from projects.qlstm.modelqlstm import ModelQLSTM

def construct_model(config):
    """Instantiate the model selected by ``config.model``.

    Args:
        config: Configuration object with attribute access. ``config.model``
            picks the class and the whole object is passed to its constructor
            as ``config=``.

    Returns:
        A ``ModelGPT`` for ``"gpt"``, a ``ModelLSTM`` for ``"lstm"`` or a
        ``ModelQLSTM`` for ``"qlstm"``.

    Raises:
        RuntimeError: If ``config.model`` is none of the names above.
    """
    if config.model == "gpt":
        return ModelGPT(config=config)
    if config.model == "lstm":
        return ModelLSTM(config=config)
    if config.model == "qlstm":
        return ModelQLSTM(config=config)

    # The message used to say "for this dataset" (copied from the dataset
    # factory), which pointed at the wrong setting.
    raise RuntimeError(
        f"Model {config.model} not supported. Please add the configuration for this model."
    )

In [6]:
import torch

def set_up_seeds(config):
    """Seed the random number generators from ``config.seed``.

    Seeds Python's ``random`` module, NumPy's legacy global generator (when
    NumPy is installed) and PyTorch, so code drawing from any of them repeats
    across runs that use the same seed.

    Args:
        config: Configuration object with attribute access providing an
            integer ``seed``.

    Returns:
        None.
    """
    import random

    seed = config.seed
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        # NumPy is optional here; without it there is no NumPy generator to seed.
        pass
    else:
        # np.random.seed only accepts values in [0, 2**32 - 1]; fold larger or
        # negative seeds into that range instead of raising.
        np.random.seed(seed % (2 ** 32))

    torch.manual_seed(seed)

In [7]:
def run(config):
    """Train one model from a fully prepared ``config``.

    Steps, in order: seed the generators, build the datasets, pick the trainer
    class, build the model and move it to ``config.device``, create the Adam
    optimizer and the experiment logger, then run the trainer.

    Args:
        config: Configuration object with attribute access. ``config.device``
            must already be set.

    Returns:
        None.
    """
    set_up_seeds(config)

    # construct_dataset can rewrite fields on config (e.g. num_input_classes),
    # so it stays ahead of construct_model, which receives the same object.
    train_loader, eval_loader = construct_dataset(config)
    trainer_cls = construct_trainerClass(config)
    network = construct_model(config).to(config.device)

    optimizer = torch.optim.Adam(
        network.parameters(),
        lr=0.001,
        betas=(0.9, 0.95),
        eps=1e-08,
    )
    experiment_logger = experiment_utils.setup_experiment(config)

    trainer_kwargs = dict(
        config=config,
        model=network,
        train_loader=train_loader,
        eval_loader=eval_loader,
        optimizer=optimizer,
        device=config.device,
        logger=experiment_logger,
    )
    trainer_cls(**trainer_kwargs).train()

In [8]:
def run_wrapper(config):
    """Finish ``config`` for one job, open a Weights & Biases run, and train.

    This is the callable that ``train`` hands to ``executor.submit``. It fills
    in ``device``, ``project_path``, ``exp_name`` and ``run_name`` on
    ``config``, logs in to W&B with ``WANDB_API_KEY`` (loaded from a ``.env``
    file if present), starts the W&B run and delegates to ``run``.

    Args:
        config: Configuration object with attribute access, as produced by
            ``load_config``; it is modified in place.

    Returns:
        Whatever ``run(config)`` returns.
    """
    # Imports and the path tweak live inside the function -- presumably so they
    # also take effect in the process that executes a submitted job (TODO confirm).
    import sys, os
    if './common_lib' not in sys.path:
        sys.path.append('./common_lib')

    if torch.cuda.is_available():
        config.device = torch.device("cuda")
    else:
        # `mprint` was never defined or imported, so this fallback used to die
        # with NameError instead of switching to the CPU.
        print("Cuda is not available, using CPU instead.")
        config.device = torch.device("cpu")

    config.project_path = os.path.join(os.getcwd(), 'projects', config.model)
    add_exp_name(config)
    print(config.dataset, config.model, config.train_seq_len, config.eval_seq_len)
    config.run_name = f"{config.model}_{config.dataset}_par_{config.num_parentheses}_depth_{config.depth}_{config.train_seq_len}_{config.eval_seq_len}"

    from dotenv import load_dotenv
    load_dotenv()
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    wandb.init(project=config.project_name, entity="www-vickyzeu", config=config, name=config.run_name)
    return run(config)

In [11]:
import submitit
import wandb
from common_lib import experiment_utils

# Sweep grid. Each length entry is a "<n>,<n>" string.
# Training lengths are the powers of two 16 .. 256.
train_seq_lens = [f"{x},{x}" for x in (1 << i for i in range(4, 9))]
# Evaluation lengths are 1.5x the matching training length: 24 .. 384.
eval_seq_lens = [f"{x},{x}" for x in (3 * (1 << (i - 1)) for i in range(4, 9))]
models = ["lstm", "gpt", "qlstm"]

def train(local=False):
    """Start training in this process or submit the sweep through submitit.

    With ``local=True`` the default configuration is trained in the current
    process. Otherwise one job per (sequence-length pair, model) combination is
    submitted; the sweep currently covers ``train_seq_lens[1:4]`` paired with
    ``eval_seq_lens[1:4]`` and the first entry of ``models``.

    Args:
        local: Train in-process instead of submitting jobs.

    Returns:
        None.
    """
    # Report only whether the key is available. Printing the key itself would
    # store the secret in the notebook's saved output.
    print("WANDB_API_KEY set:", bool(os.getenv("WANDB_API_KEY")))

    if local:
        # Go through run_wrapper: it sets config.device and the run names and
        # handles the W&B login. Calling run() directly failed because
        # config.device was never assigned.
        run_wrapper(load_config())
        return

    executor = submitit.AutoExecutor(folder="logs/slurm")
    executor.update_parameters(
        timeout_min=60*4,
        tasks_per_node=1,
        #cpus_per_task=4,
        account="pmlr",
        name="pmlr_training"
    )

    jobs = []
    for train_seq_len, eval_seq_len in zip(train_seq_lens[1:4], eval_seq_lens[1:4]):
        for model in models[:1]:
            # Fresh config per job, so no submitted job shares an object that a
            # later iteration mutates.
            config = load_config()
            config.train_seq_len = train_seq_len
            config.eval_seq_len = eval_seq_len
            # Pad to the longer of the two upper bounds ("lo,hi" strings).
            config.max_seq_len = max(int(train_seq_len.split(',')[-1]), int(eval_seq_len.split(',')[-1]))
            config.model = model
            config.block_length = min(8, config.max_seq_len)
            jobs.append(executor.submit(run_wrapper, config))

    print(f"Submitted {len(jobs)} jobs.")

In [12]:
from dotenv import load_dotenv
import os
# NOTE(review): env_path is assigned but never passed to load_dotenv() or used
# by any other visible cell.
env_path = os.path.join('.', '.env')  # won't work in Jupyter
# Load variables from a .env file into the process environment; train() reads
# WANDB_API_KEY from there.
load_dotenv()
# Called with the default local=False, i.e. the jobs are submitted via submitit.
train()

key <redacted>


sbatch: error: GPU count: 1
sbatch: error: CPU count: 2
sbatch: error: Memory: 24576MB
sbatch: error: QOSMaxSubmitJobPerUserLimit
sbatch: error: Batch job submission failed: Job violates accounting/QOS policy (job submit limit, user's size and/or time limits)


FailedJobError: sbatch: error: GPU count: 1
sbatch: error: CPU count: 2
sbatch: error: Memory: 24576MB
sbatch: error: QOSMaxSubmitJobPerUserLimit
sbatch: error: Batch job submission failed: Job violates accounting/QOS policy (job submit limit, user's size and/or time limits)