In [2]:
import os

# Disable TensorFlow's oneDNN custom ops in case TensorFlow gets imported
# (e.g. indirectly via `transformers`); avoids its startup notice.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That gives
# accurate stack traces for device-side errors but slows training noticeably,
# so it is opt-in. It must be set before CUDA initialises, i.e. in this first cell.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
from model_arch.run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from model_arch.train import TrainLoop
from utils.data import load_data_text
from model_arch.tokenizer import load_tokenizer, load_model_emb
from model_arch.sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

In [4]:
# Clear cached memory left over from earlier runs in this kernel.
# NOTE(review): presumably wraps torch.cuda.empty_cache() — confirm in utils.dist_util.
dist_util.clear_cache()

In [5]:
# --- Optimisation -------------------------------------------------------------
lr = 1e-4
weight_decay = 0.0
batch_size = 20
microbatch = 10            # chunk size within `batch_size` (presumably gradient accumulation)
epochs = 30_000
eval_interval = 100
ema_rate = '0.9999'        # passed to TrainLoop as a string
seed = 102

# --- Diffusion process --------------------------------------------------------
schedule_sampler = 'uniform'
diffusion_steps = 2000
noise_schedule = 'sqrt'
predict_xstart = True
rescale_timesteps = True

# --- Vocabulary / embeddings --------------------------------------------------
vocab = 'custom'
use_plm_init = 'no'        # embedding in transformer (presumably: no pretrained-LM init)
vocab_size = 0             # placeholder only; reassigned from the tokenizer later
config_name = 'bert-base-uncased'
emb_scale_factor = 1.0

# --- Model shape --------------------------------------------------------------
seq_len = 128
hidden_t_dim = 128
hidden_dim = 128
dropout = 0.1

In [6]:
# Two available corpora on disk: the plain one and the 'with_player' variant.
regular_data_dir, data_player_dir = 'data', 'data/with_player'

# Corpus used for training/validation; swap in `data_player_dir` to use the other one.
data_dir = regular_data_dir

In [7]:
# Seed the RNGs through transformers' helper (per its docs: random, NumPy and torch)
# so data shuffling and weight initialisation are repeatable.
set_seed(seed)

In [8]:
# Load the project's custom 'shakespeare' tokenizer (the call prints its name).
# NOTE(review): the role of `config_name` here is not visible from this notebook —
# presumably a base HF tokenizer/config; confirm in model_arch.tokenizer.
tokenizer = load_tokenizer('shakespeare', config_name)

shakespeare


In [9]:
# Build the initial word-embedding table for this tokenizer with `hidden_dim`
# columns (the next cell displays it as Embedding(vocab, hidden_dim)).
# NOTE: `tokenizer` is rebound to the object returned by `load_model_emb`.
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [10]:
# Inspect the embedding table: Embedding(num_embeddings, embedding_dim).
model_weight

Embedding(30268, 128)

In [11]:
# `create_model_and_diffusion` below sizes the network from `vocab_size`, so it
# must come from the tokenizer (the config cell only holds a 0 placeholder).
vocab_size = tokenizer.vocab_size

# Enforce the invariant instead of relying on a comment: the embedding table
# built by `load_model_emb` must have exactly one row per vocabulary entry.
assert vocab_size == model_weight.num_embeddings, (
    f'tokenizer vocab ({vocab_size}) != embedding rows ({model_weight.num_embeddings})'
)
vocab_size

30268

In [12]:
# Train and validation loaders share every argument except `split`; bind the
# common ones once so the two calls cannot drift apart.
load_split = partial(
    load_data_text,
    batch_size=batch_size,
    seq_len=seq_len,
    data_dir=data_dir,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
)

data = load_split()               # default split (logged as the TRAIN set)
val = load_split(split='valid')

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TRAIN set...
### Data samples...
 ['o hell! what have we here? a carrion death, within whose empty eye there is a written scroll!', 'and his disciples only envy at, ye blew the fire that burns ye now have at ye! enter king,'] ["i'll read the writing. all that glitters is not gold, often have you heard that told", 'frowning on them, takes his seat']
RAM used: 672.36 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 48627
})
RAM used: 691.26 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/48627 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 48627
})
### tokenized_datasets...example [2, 37, 1300, 6, 164, 150, 133, 237, 22, 23, 7135, 432, 10, 906, 569, 3066, 756, 210, 121, 23, 4180, 7422, 6, 3]
RAM used: 734.46 MB


merge and mask:   0%|          | 0/48627 [00:00<?, ? examples/s]

RAM used: 774.05 MB


padding:   0%|          | 0/48627 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 48627
}) padded dataset
RAM used: 868.59 MB
RAM used: 868.59 MB
############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the VALID set...
### Data samples...
 ["petruchio is my name, antonio's son, a man well known throughout all italy.", 'the matter is to me, sir, as concerning jaquenetta. the manner of it is,'] ['i know him well you are welcome for his sake.', 'i was taken with the manner.']
RAM used: 829.69 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 12147
})
RAM used: 829.69 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/12147 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 12147
})
### tokenized_datasets...example [2, 3886, 121, 105, 520, 10, 2546, 9, 41, 478, 10, 23, 211, 254, 1233, 9840, 187, 4043, 12, 3]
RAM used: 842.63 MB


merge and mask:   0%|          | 0/12147 [00:00<?, ? examples/s]

RAM used: 851.73 MB


padding:   0%|          | 0/12147 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 12147
}) padded dataset
RAM used: 874.86 MB
RAM used: 874.86 MB


In [13]:
# Build the network and its diffusion process from the config cell, then move
# the network to the device chosen by `dist_util.dev()`.
# create_model_and_diffusion takes these positionally, in exactly this order.
model_and_diffusion_args = (
    hidden_t_dim,
    hidden_dim,
    vocab_size,
    config_name,
    use_plm_init,
    dropout,
    diffusion_steps,
    noise_schedule,
    predict_xstart,
    rescale_timesteps,
)
model, diffusion = create_model_and_diffusion(*model_and_diffusion_args)

model.to(dist_util.dev())

TransformerNetModel(
  (word_embedding): Embedding(30268, 128)
  (lm_head): Linear(in_features=128, out_features=30268, bias=True)
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=768, bias=True)
  )
  (input_up_proj): Sequential(
    (0): Linear(in_features=128, out_features=768, bias=True)
    (1): Tanh()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_transformers): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768,

In [14]:
# Total parameter count of the network (trainable and frozen alike).
pytorch_total_params = 0
for param in model.parameters():
    pytorch_total_params += param.numel()
pytorch_total_params

91192508

In [None]:
# Training hyperparameters that were previously inline magic numbers.
warm_up_steps = 500
llrd_rate = 0.99  # presumably layer-wise LR decay factor
# NOTE(review): the per-parameter LR log shows transformer layer 0 at lr/0.99
# (> base lr) — confirm the intended decay direction inside TrainLoop.

# Build the timestep sampler from the config value instead of a hard-coded
# 'uniform'. Keep it under its own name so the config string `schedule_sampler`
# is not clobbered by the sampler object (re-running this cell stays safe).
timestep_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=timestep_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
        eval_data=val,
        eval_interval=eval_interval,
        warm_up_steps=warm_up_steps,
        use_llrd=True,
        llrd_rate=llrd_rate,
    ).run_loop()





name: word_embedding.weight, lr: 0.0001
name: lm_head.bias, lr: 0.0001
name: time_embed.0.weight, lr: 0.0001
name: time_embed.0.bias, lr: 0.0001
name: time_embed.2.weight, lr: 0.0001
name: time_embed.2.bias, lr: 0.0001
name: input_up_proj.0.weight, lr: 0.0001
name: input_up_proj.0.bias, lr: 0.0001
name: input_up_proj.2.weight, lr: 0.0001
name: input_up_proj.2.bias, lr: 0.0001
name: input_transformers.layer.0.attention.self.query.weight, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.self.query.bias, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.self.key.weight, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.self.key.bias, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.self.value.weight, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.self.value.bias, lr: 0.00010101010101010101
name: input_transformers.layer.0.attention.output.dense.weight, lr: 0.00010101010101010101
na

Epoch 1/30000 Training Loss: 1.1961716413497925
Epoch 2/30000 Training Loss: 1.1991360187530518
Epoch 3/30000 Training Loss: 1.1946220397949219
Epoch 4/30000 Training Loss: 1.1927778720855713
Epoch 5/30000 Training Loss: 1.1906840801239014
Epoch 6/30000 Training Loss: 1.1799726486206055
Epoch 7/30000 Training Loss: 1.1712337732315063
Epoch 8/30000 Training Loss: 1.164445400238037
Epoch 9/30000 Training Loss: 1.1491776704788208
Epoch 10/30000 Training Loss: 1.1362547874450684
Epoch 11/30000 Training Loss: 1.1196528673171997
Epoch 12/30000 Training Loss: 1.1070992946624756
Epoch 13/30000 Training Loss: 1.0855627059936523
Epoch 14/30000 Training Loss: 1.0640943050384521
Epoch 15/30000 Training Loss: 1.0497028827667236
Epoch 16/30000 Training Loss: 1.0301544666290283
Epoch 17/30000 Training Loss: 1.003342866897583
Epoch 18/30000 Training Loss: 0.9832772016525269
Epoch 19/30000 Training Loss: 0.9728710651397705
Epoch 20/30000 Training Loss: 0.9377175569534302
Epoch 21/30000 Training Loss: 0

In [None]:
# Debug helper (disabled): uncomment to print every parameter's index, name and
# `requires_grad` flag, e.g. to cross-check the per-parameter LR log of training.
# param_names = []
# for i, (name, param) in enumerate(model.named_parameters()):
#     param_names.append(name)
#     print(f'{i}: {name} {param.requires_grad}')

In [None]:
# Optional (disabled): reload a previously saved best checkpoint.
# NOTE(review): `pickle.load` can execute arbitrary code — only load files you produced.
# NOTE(review): `dt` is computed but unused; the path hard-codes the '0221' folder.
# NOTE(review): the sampling cells call `sampling(model, ...)`, so `best_model` is never
# used there; rebind `model = best_model` if sampling from the checkpoint is intended.
# dt = datetime.now().strftime("%m%d")
# best_model_fp = f'models/0221/model_best_epoch_20600_min_val_loss_0.028699999675154686.pkl'
# with open(best_model_fp, 'rb') as handle:
#     best_model = pickle.load(handle)

# Generating sequences

In [None]:
# Re-create the diffusion process used for sampling. The step count comes from
# the config instead of a hard-coded 2000, so it stays in sync with training;
# change `sampling_diffusion_steps` here only deliberately.
sampling_diffusion_steps = diffusion_steps

_unused_model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        sampling_diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )
# Only the diffusion object is needed; the freshly built (untrained) network is not.
# Drop it explicitly: binding it to `_` kept ~91M parameters alive in the kernel.
del _unused_model

In [None]:
# Generate from the trained `model`. Corpus and sequence length come from the
# config cell (instead of re-hard-coding `regular_data_dir` / 128), so changing
# `data_dir` or `seq_len` there is picked up here as well.
sampling_batch_size = 10

# Disable dropout for generation; harmless if `sampling` already does this.
model.eval()

word_lst_source, word_lst_recover, word_lst_ref = sampling(
    model,
    diffusion,
    tokenizer,
    data_dir=data_dir,
    batch_size=sampling_batch_size,
    split='test_custom',
    seq_len=seq_len,
)

Note: generating 20 sentences takes about 5 minutes.

In [None]:
# Source (conditioning) sentences from the 'test_custom' split, as returned by
# `sampling` (presumably; confirm against model_arch.sampling).
word_lst_source

In [None]:
# Sequences generated by the model (presumably decoded from the denoised output).
word_lst_recover

In [None]:
# Reference targets to compare the generations against (presumably ground truth).
word_lst_ref