# Run Distribution Learning Benchmark

In [1]:
%cd ..

/home/adam/Projects/hybrid-transformer


In [3]:
import torch
import wandb

from hybrid_transformer.configs.task import TaskConfig
from hybrid_transformer.configs.model import ModelConfig
from hybrid_transformer.configs.trainer import TrainerConfig
from hybrid_transformer.configs.logger import LoggerConfig

from hybrid_transformer.utils.datasets.auto import AutoDataset
from hybrid_transformer.utils.tokenizers.auto import AutoTokenizer
from hybrid_transformer.models.auto import AutoModel
from hybrid_transformer.utils.loggers.wandb import WandbLogger

from hybrid_transformer.trainers.trainer import Trainer

from scripts.train import DEFAULT_CONFIG_FILES

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
# Load the default task / model / trainer / logger configs referenced by
# scripts.train.DEFAULT_CONFIG_FILES; each override sits next to the config
# it modifies.

task_config = TaskConfig.from_pretrained(DEFAULT_CONFIG_FILES['task'])
task_config.validate = False

model_config = ModelConfig.from_pretrained(DEFAULT_CONFIG_FILES['model'])
# Optional model overrides, currently disabled:
# model_config.model_type = 'HybridTransformer'
# model_config.task_p = 0.95

trainer_config = TrainerConfig.from_pretrained(DEFAULT_CONFIG_FILES['trainer'])

logger_config = LoggerConfig.from_pretrained(DEFAULT_CONFIG_FILES['logger'])
logger_config.wandb_log = False  # presumably keeps this run from logging to W&B — TODO confirm

You are using a model of type GPT to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


In [24]:
# Instantiate every component from the configs loaded in the previous cell.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'ModelConfig' object has no attribute 'model_name'";
# the traceback is not preserved, so the failing call is unknown — verify
# which constructor reads `model_name`.
dataset = AutoDataset.from_config(task_config)
tokenizer = AutoTokenizer.from_config(task_config)
model = AutoModel.from_config(model_config)
logger = WandbLogger(
    logger_config,
    [task_config, model_config, trainer_config],
)

# The same dataset object is passed as both the training and the evaluation set.
trainer = Trainer(
    config=trainer_config,
    model=model,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=tokenizer,
    logger=logger,
)
# Checkpoint restore is currently disabled:
# trainer.load_checkpoint()

# presumably turns off model compilation inside the trainer — TODO confirm
# that Trainer reads this attribute after construction
trainer.compile = False
# Calls the trainer's private initialisation routine directly.
trainer._train_init()

AttributeError: 'ModelConfig' object has no attribute 'model_name'

In [8]:
# Start training with the trainer built in the previous cell.
# NOTE(review): the recorded run failed with
# "AttributeError: 'function' object has no attribute 'train'" right after
# the "compiling the model..." message — presumably the compile step left a
# function where the model should be; verify in Trainer.
trainer.train()

num decayed parameter tensors: 63, with 38,115,840 parameters
num non-decayed parameter tensors: 25, with 12,800 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)


AttributeError: 'function' object has no attribute 'train'

In [18]:
# Display the module held by the trainer to inspect its architecture.
# The recorded output shows an OptimizedModule wrapping a GPT model.
trainer.model

OptimizedModule(
  (_orig_mod): GPT(
    (transformer): ModuleDict(
      (wte): Embedding(588, 512)
      (wpe): Embedding(128, 512)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x HybridTransformerBlock(
          (ln_1): LayerNorm()
          (attn_1): HybridSelfAttention(
            (q_proj): Linear(in_features=512, out_features=512, bias=False)
            (kv_proj): Linear(in_features=512, out_features=1024, bias=False)
            (out_proj): Linear(in_features=512, out_features=512, bias=False)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm()
          (mlp): MLP(
            (fc): Linear(in_features=512, out_features=2048, bias=False)
            (gelu): GELU(approximate='none')
            (proj): Linear(in_features=2048, out_features=512, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
  

In [17]:
# Display the model's `task_p` attribute (recorded value: 1.0).
# NOTE(review): presumably the same setting as the commented-out
# `model_config.task_p = 0.95` override in the config cell — confirm.
model.task_p

1.0

In [19]:
# Display the trainer's `is_ddp_run` flag (recorded value: False).
# presumably indicates a DistributedDataParallel run — TODO confirm
trainer.is_ddp_run

False

In [22]:
# Dump the model config as a dict for inspection.
# NOTE(review): the recorded output shows 'model_type': '' — likely related
# to the "model of type GPT to instantiate a model of type ." warning printed
# when the config was loaded; confirm that ModelConfig defines `model_type`.
model_config.to_dict()

{'model_type': '',
 'embedding_dim': 512,
 'num_heads': 8,
 'num_layers': 12,
 'bias': False,
 'dropout': 0.1,
 'layer_norm_eps': 1e-05,
 'block_size': 1024,
 'vocab_size': 588,
 'max_seq_len': 128,
 'return_dict': True,
 'output_hidden_states': False,
 'output_attentions': False,
 'torchscript': False,
 'torch_dtype': None,
 'use_bfloat16': False,
 'tf_legacy_loss': False,
 'pruned_heads': {},
 'tie_word_embeddings': True,
 'is_encoder_decoder': False,
 'is_decoder': False,
 'cross_attention_hidden_size': None,
 'add_cross_attention': False,
 'tie_encoder_decoder': False,
 'max_length': 20,
 'min_length': 0,
 'do_sample': False,
 'early_stopping': False,
 'num_beams': 1,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'typical_p': 1.0,
 'repetition_penalty': 1.0,
 'length_penalty': 1.0,
 'no_repeat_ngram_size': 0,
 'encoder_no_repeat_ngram_size': 0,
 'bad_words_ids': None,
 'num_return_sequences': 1,
 'chunk_size_feed_forward': 0,
 