In [73]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train import TrainLoop
from utils.data import load_data_text
from tokenizer import load_tokenizer, load_model_emb
from sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

In [74]:
# CUDA_LAUNCH_BLOCKING is only read when the CUDA context is created, so it has
# to be set before anything that may touch the GPU. clear_cache() may do so
# (its body is not visible here — TODO confirm), hence the env var goes first.
# If CUDA was already initialised earlier in this kernel (e.g. by an import),
# restart the kernel for this setting to take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
dist_util.clear_cache()

In [75]:
# --- optimisation / training loop ---
lr = 1e-5
batch_size = 20
microbatch = 5
epochs = 30000
eval_interval = 100
ema_rate = '0.9999'  # a string, not a float — presumably parsed by TrainLoop; confirm
weight_decay = 0.0

# --- diffusion process ---
schedule_sampler = 'uniform'
diffusion_steps = 2000
noise_schedule = 'sqrt'
predict_xstart = True
rescale_timesteps = True

# --- vocabulary / embeddings ---
vocab = 'custom'
use_plm_init = 'no'  # embedding in transformer
vocab_size = 0  # placeholder; overwritten with tokenizer.vocab_size once the tokenizer is loaded
config_name = 'bert-base-uncased'
emb_scale_factor = 1.0

# --- model shape ---
seq_len = 128
hidden_t_dim = 128
hidden_dim = 128
dropout = 0.1

# --- reproducibility ---
seed = 102

In [76]:
# Candidate dataset roots, all relative to the top-level 'data' directory.
regular_data_dir = 'data'
cc_data_dir = f'{regular_data_dir}/commonsense'
ss_data_dir = f'{regular_data_dir}/shakespeare'
ss_small_data_dir = f'{regular_data_dir}/mini-shakespeare'
combined_data_dir = f'{regular_data_dir}/combined'
combined_small_data_dir = f'{combined_data_dir}/small'

# The directory actually used for the train/valid loaders below.
data_dir = regular_data_dir

In [77]:
# Seed the RNGs via transformers.set_seed using the `seed` from the config cell.
set_seed(seed=seed)

In [78]:
# Build the project tokenizer; config_name is forwarded to load_tokenizer.
tokenizer_source = 'shakespeare_plays'
tokenizer = load_tokenizer(tokenizer_source, config_name)

In [79]:
# load_model_emb returns a pair: the embedding module (displayed below as
# Embedding(...)) and a tokenizer, which replaces the one built above.
emb_and_tokenizer = load_model_emb(hidden_dim, tokenizer)
model_weight, tokenizer = emb_and_tokenizer

In [80]:
# Inspect the embedding returned by load_model_emb (its row count should match vocab_size).
model_weight

Embedding(30267, 128)

In [81]:
# Replace the placeholder `vocab_size = 0` from the config cell with the real
# tokenizer size. create_model_and_diffusion receives this value; leaving it at 0
# would build a model whose output layer cannot cover the vocabulary.
vocab_size = tokenizer.vocab_size

# Fail fast if the tokenizer and the embedding from load_model_emb disagree,
# instead of hitting an index error deep inside training.
assert vocab_size == model_weight.num_embeddings, (
    f'tokenizer vocab_size={vocab_size} != embedding rows={model_weight.num_embeddings}'
)
vocab_size

30267

In [82]:
# The train and validation loaders share every argument except the split, so
# bind the common ones once. The training loader relies on load_data_text's
# default split.
load_split = partial(
    load_data_text,
    batch_size=batch_size,
    seq_len=seq_len,
    data_dir=data_dir,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
)

data = load_split()
val = load_split(split='valid')

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TRAIN set...
### Data samples...
 ['o hell! what have we here? a carrion death, within whose empty eye there is a written scroll!', 'and his disciples only envy at, ye blew the fire that burns ye now have at ye! enter king,'] ["i'll read the writing. all that glitters is not gold, often have you heard that told", 'frowning on them, takes his seat']
RAM used: 3832.46 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 48627
})
RAM used: 3844.47 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/48627 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 48627
})
### tokenized_datasets...example [2, 36, 1299, 5, 163, 149, 132, 236, 21, 22, 7134, 431, 9, 905, 568, 3065, 755, 209, 120, 22, 4179, 7421, 5, 3]
RAM used: 3825.27 MB


merge and mask:   0%|          | 0/48627 [00:00<?, ? examples/s]

RAM used: 3855.66 MB


padding:   0%|          | 0/48627 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 48627
}) padded dataset
RAM used: 3945.82 MB
RAM used: 3945.82 MB
############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the VALID set...
### Data samples...
 ["petruchio is my name, antonio's son, a man well known throughout all italy.", 'the matter is to me, sir, as concerning jaquenetta. the manner of it is,'] ['i know him well you are welcome for his sake.', 'i was taken with the manner.']
RAM used: 3827.22 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 12147
})
RAM used: 3827.22 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/12147 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 12147
})
### tokenized_datasets...example [2, 3885, 120, 104, 519, 9, 2545, 8, 40, 477, 9, 22, 210, 253, 1232, 9839, 186, 4042, 11, 3]
RAM used: 3827.13 MB


merge and mask:   0%|          | 0/12147 [00:00<?, ? examples/s]

RAM used: 3838.37 MB


padding:   0%|          | 0/12147 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 12147
}) padded dataset
RAM used: 3856.27 MB
RAM used: 3856.27 MB


In [83]:
# create_model_and_diffusion takes positional arguments, so this tuple's order
# must match its signature exactly.
model_args = (
    hidden_t_dim,
    hidden_dim,
    vocab_size,
    config_name,
    use_plm_init,
    dropout,
    diffusion_steps,
    noise_schedule,
    predict_xstart,
    rescale_timesteps,
)
model, diffusion = create_model_and_diffusion(*model_args)

# Move the model to the device selected by dist_util; the module repr is displayed.
model.to(dist_util.dev())

TransformerNetModel(
  (word_embedding): Embedding(30267, 128)
  (lm_head): Linear(in_features=128, out_features=30267, bias=True)
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=768, bias=True)
  )
  (input_up_proj): Sequential(
    (0): Linear(in_features=128, out_features=768, bias=True)
    (1): Tanh()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_transformers): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768,

In [84]:
# Count every scalar parameter in the model, trainable or not.
pytorch_total_params = 0
for tensor in model.parameters():
    pytorch_total_params += tensor.numel()
pytorch_total_params

91192379

In [85]:
# Build the timestep sampler from the configured name (config cell) instead of a
# hard-coded literal. Bind it to its own name so the `schedule_sampler` config
# string is not overwritten and this cell can be re-run safely.
timestep_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=timestep_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
        eval_data=val,
        eval_interval=eval_interval,
        warm_up_steps=500,  # LR warm-up length, in TrainLoop's step units
        llrd_rate=0.999,    # presumably layer-wise LR decay (per-parameter lrs are printed) — TODO confirm in TrainLoop
    ).run_loop()





name: word_embedding.weight, lr: 1e-05
name: lm_head.bias, lr: 1e-05
name: time_embed.0.weight, lr: 1e-05
name: time_embed.0.bias, lr: 1e-05
name: time_embed.2.weight, lr: 1e-05
name: time_embed.2.bias, lr: 1e-05
name: input_up_proj.0.weight, lr: 1e-05
name: input_up_proj.0.bias, lr: 1e-05
name: input_up_proj.2.weight, lr: 1e-05
name: input_up_proj.2.bias, lr: 1e-05
name: input_transformers.layer.0.attention.self.query.weight, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.self.query.bias, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.self.key.weight, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.self.key.bias, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.self.value.weight, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.self.value.bias, lr: 1.3333333333333335e-05
name: input_transformers.layer.0.attention.output.dense.weight, lr: 1.3333333333333335e-05
name: input_

In [None]:
# param_names = []
# for i, (name, param) in enumerate(model.named_parameters()):
#     param_names.append(name)
#     print(f'{i}: {name} {param.requires_grad}')

In [None]:
# dt = datetime.now().strftime("%m%d")
# best_model_fp = f'models/0221/model_best_epoch_20600_min_val_loss_0.028699999675154686.pkl'
# with open(best_model_fp, 'rb') as handle:
#     best_model = pickle.load(handle)

# Generating sequences

In [None]:
# Re-create the diffusion process used for generation. The step count is kept in
# its own variable so it can differ from the training `diffusion_steps`; it
# currently has the same value.
sampling_diffusion_steps = 2000

# create_model_and_diffusion also builds a fresh, untrained model that is not
# needed here. Bind it to an explicit throwaway name rather than `_`: in IPython,
# assigning `_` yourself stops the output cache from updating `_`, and the
# model would stay referenced for the rest of the session.
_unused_sampling_model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        sampling_diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )
del _unused_sampling_model  # release the untrained model's memory

In [None]:
# Generate from the trained model on the 'test_custom' split. seq_len is taken
# from the config cell (the same value passed to load_data_text for training),
# so the two stay consistent if the config changes. The three returned lists are
# presumably source, recovered (generated), and reference texts — see below.
word_lst_source, word_lst_recover, word_lst_ref = sampling(
    model,
    diffusion,
    tokenizer,
    data_dir=regular_data_dir,
    batch_size=10,
    split='test_custom',
    seq_len=seq_len,
)

Note: generating 20 sentences takes about 5 minutes.

In [None]:
# First list returned by sampling() (named as the source sequences).
word_lst_source

In [None]:
# Second list returned by sampling() (named as the recovered/generated sequences).
word_lst_recover

In [None]:
# Third list returned by sampling() (named as the reference sequences).
word_lst_ref