In [1]:
import os
# Environment flags; they have to be in place before the framework imports in
# the next cell, since the libraries read them when they load.
# - TF_ENABLE_ONEDNN_OPTS='0': turns off TensorFlow's oneDNN custom ops (per TF docs).
# - CUDA_LAUNCH_BLOCKING='1': makes CUDA kernel launches synchronous (debugging aid).
#   NOTE(review): this typically slows GPU training; consider unsetting it for long runs.
os.environ.update({
    'TF_ENABLE_ONEDNN_OPTS': '0',
    'CUDA_LAUNCH_BLOCKING': '1',
})

In [2]:
from model_arch.run_train import *
from utils.step_sample import create_named_schedule_sampler
from model_arch.train import TrainLoop
from utils.data import load_data_text
from model_arch.tokenizer import load_tokenizer, load_model_emb
from model_arch.sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

In [3]:
# Project helper from utils.dist_util; presumably frees cached GPU memory left by
# earlier runs in this kernel (TODO confirm what clear_cache actually does).
dist_util.clear_cache()

In [4]:
# ---- Optimization / training loop ------------------------------------------
lr = 0.0001
weight_decay = 0.0
batch_size = 30            # training batch size
val_batch_size = 30        # validation batch size
microbatch = 10            # passed to TrainLoop as `microbatch`
epochs = 30_000
eval_interval = 10
ema_rate = '0.9999'        # string form, passed as-is to TrainLoop
seed = 10275679

# ---- Diffusion process -------------------------------------------------------
diffusion_steps = 1000
noise_schedule = 'sqrt'
schedule_sampler = 'uniform'
predict_xstart = True
rescale_timesteps = True

# ---- Model / embeddings ------------------------------------------------------
config_name = 'bert-base-uncased'
use_plm_init = 'no'        # embedding in transformer
vocab = 'custom'
vocab_size = 0             # placeholder; set to tokenizer.vocab_size once the tokenizer is loaded
seq_len = 128
hidden_t_dim = 128
hidden_dim = 128
dropout = 0.1
emb_scale_factor = 1.0

In [5]:
# Dataset variants, all living under the base `data` directory.
regular_data_dir = 'data'
data_player_dir = f'{regular_data_dir}/with_player'
comedies_data_dir = f'{regular_data_dir}/comedies_only'

# Dataset used for the train/valid loaders in this notebook.
data_dir = comedies_data_dir

In [6]:
# Seed the RNGs via transformers.set_seed (presumably random / numpy / torch —
# see its docs). Runs before data loading and model construction so those are
# reproducible.
set_seed(seed)

2024-03-14 22:57:52.295966: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 22:57:52.296026: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 22:57:52.297571: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-14 22:57:52.307234: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
# Load the project tokenizer. The 'shakespeare' choice is hard-coded here; the
# `vocab` setting from the config cell is not used to select it.
# NOTE(review): confirm how load_tokenizer uses config_name for this tokenizer.
tokenizer = load_tokenizer('shakespeare', config_name)

shakespeare


In [8]:
# Build the word-embedding module for the tokenizer's vocabulary; the saved
# output below shows Embedding(30268, 128), i.e. (vocab size, hidden_dim).
# Note: this rebinds `tokenizer` to the object returned by load_model_emb.
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [9]:
# Display the embedding module (expected: Embedding(vocab size, hidden_dim)).
model_weight

Embedding(30268, 128)

In [10]:
## Very important: set this before building the model!
# Replace the `vocab_size = 0` placeholder from the config cell with the real
# tokenizer vocabulary size; it is passed to create_model_and_diffusion below.
vocab_size = tokenizer.vocab_size

# Inline invariant: the embedding table from load_model_emb must have exactly one
# row per tokenizer entry, otherwise the model and the data disagree on token ids.
assert model_weight.num_embeddings == vocab_size, (
    f"embedding rows ({model_weight.num_embeddings}) != tokenizer vocab_size ({vocab_size})"
)
vocab_size

30268

In [11]:
# Arguments shared by both splits: sequence length, dataset directory, tokenizer,
# and the model's embedding weights as initialization.
load_split = partial(
    load_data_text,
    seq_len=seq_len,
    data_dir=data_dir,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
)

# Training data (load_data_text's default split — the log reports the TRAIN set).
data = load_split(batch_size=batch_size)

# Validation data.
val = load_split(batch_size=val_batch_size, split='valid')

############################## 
Loading text data...
############################## 
Loading dataset from data/comedies_only...
### Loading form the TRAIN set...
### Data samples...
 ['i would not have my right rosalind of this mind, for, i protest, her frown might kill me.', 'i thought, by your readiness in the office, you had continued in it some time. you say, seven years together?'] ['by this hand, it will not kill a fly. but come, now i will be your rosalind in a more coming-on disposition, and ask me what you will. i will grant it.', 'and a half, sir.']
RAM used: 1403.98 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 8706
})
RAM used: 1410.62 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/8706 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 8706
})
### tokenized_datasets...example [2, 31, 241, 125, 150, 105, 755, 2841, 94, 138, 727, 10, 115, 10, 31, 2566, 10, 175, 2191, 526, 803, 117, 12, 3]
RAM used: 1425.03 MB


merge and mask:   0%|          | 0/8706 [00:00<?, ? examples/s]

RAM used: 1441.84 MB


padding:   0%|          | 0/8706 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 8706
}) padded dataset
RAM used: 1454.34 MB
RAM used: 1454.34 MB
############################## 
Loading text data...
############################## 
Loading dataset from data/comedies_only...
### Loading form the VALID set...
### Data samples...
 ['to my petticoat, or what you will command me will i do, so well i know my duty to my elders.', "youth, thou bear'st thy father's face, frank nature, rather curious than in haste, hath well composed thee. thy father's moral parts mayst thou inherit too! welcome to paris."] ['of all thy suitors, here i charge thee, tell whom thou lovest best see thou dissemble not.', "my thanks and duty are your majesty's."]
RAM used: 1446.25 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 2167
})
RAM used: 1446.26 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/2167 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 2167
})
### tokenized_datasets...example [2, 88, 105, 10193, 10, 222, 164, 89, 159, 941, 117, 159, 31, 144, 10, 146, 254, 31, 251, 105, 1485, 88, 105, 13299, 12, 3]
RAM used: 1447.74 MB


merge and mask:   0%|          | 0/2167 [00:00<?, ? examples/s]

RAM used: 1451.43 MB


padding:   0%|          | 0/2167 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 2167
}) padded dataset
RAM used: 1462.77 MB
RAM used: 1462.77 MB


In [12]:
# Build the denoising model and the diffusion process from the config cell.
# All arguments are positional, so their order must match the signature of
# create_model_and_diffusion (presumably star-imported from model_arch.run_train;
# not shown here). NOTE(review): switch to keyword arguments once the signature
# is confirmed. vocab_size must already hold tokenizer.vocab_size, not the 0 placeholder.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

In [13]:
# Count every scalar parameter in the model (trainable and frozen alike).
param_counts = [p.numel() for p in model.parameters()]
pytorch_total_params = sum(param_counts)
pytorch_total_params

91192508

In [14]:
# Disabled experiment: re-initialize the last layer of model.input_transformers
# with BERT's initializer_range — Linear/Embedding weights ~ N(0, initializer_range)
# with zeroed bias / padding row, LayerNorm reset to weight=1, bias=0.
# Kept for reference; it sits before the training cell, so uncomment it there.
# from torch import nn
# from transformers import BertConfig

# config = BertConfig.from_pretrained("bert-base-uncased")

# for layer in model.input_transformers.layer[-1:]:
#     for module in layer.modules():
#         if isinstance(module, nn.Linear):
#             module.weight.data.normal_(mean=0.0, std=config.initializer_range)
#             if module.bias is not None:
#                 module.bias.data.zero_()
#         elif isinstance(module, nn.Embedding):
#             module.weight.data.normal_(mean=0.0, std=config.initializer_range)
#             if module.padding_idx is not None:
#                 module.weight.data[module.padding_idx].zero_()
#         elif isinstance(module, nn.LayerNorm):
#             module.bias.data.zero_()
#             module.weight.data.fill_(1.0)

In [15]:
# Timestep sampler for diffusion training, built from the configured strategy
# (`schedule_sampler` in the hyperparameter cell). It gets its own name so the
# config string is not overwritten by the sampler object.
timestep_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

model.to(dist_util.dev())

# Train. warm_up_steps / use_llrd / llrd_rate are set here, not in the config cell.
# NOTE(review): the per-parameter learning rates this cell printed show
# input_transformers layer 0 at lr/0.9 (~1.11e-4), above the base lr used for the
# embeddings — confirm this is the intended layer-wise lr decay direction in TrainLoop.
# If run_loop() is interrupted (KeyboardInterrupt), train_loss / val_loss are never
# assigned and the final model below is not saved.
train_loss, val_loss = TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=batch_size,
    microbatch=microbatch,
    lr=lr,
    ema_rate=ema_rate,
    schedule_sampler=timestep_sampler,
    weight_decay=weight_decay,
    epochs=epochs,
    eval_data=val,
    eval_interval=eval_interval,
    warm_up_steps=500,
    use_llrd=True,
    llrd_rate=0.9,
).run_loop()

# Save the final model under models/<MMDD>/. open() does not create missing parent
# directories, so make the dated folder first. The `with` block ensures the file is
# flushed and closed even if pickling fails.
# The whole module is pickled (the checkpoint-loading cell below uses pickle.load);
# torch.save(model.state_dict(), ...) would be more portable if both sides change.
dt = datetime.now().strftime("%m%d")
final_model_dir = f"models/{dt}"
os.makedirs(final_model_dir, exist_ok=True)
with open(f"{final_model_dir}/final_model_df{diffusion_steps}.pkl", 'wb') as handle:
    pickle.dump(model, handle)





name: word_embedding.weight, lr: 0.0001
name: lm_head.bias, lr: 0.0001
name: time_embed.0.weight, lr: 0.0001
name: time_embed.0.bias, lr: 0.0001
name: time_embed.2.weight, lr: 0.0001
name: time_embed.2.bias, lr: 0.0001
name: input_up_proj.0.weight, lr: 0.0001
name: input_up_proj.0.bias, lr: 0.0001
name: input_up_proj.2.weight, lr: 0.0001
name: input_up_proj.2.bias, lr: 0.0001
name: input_transformers.layer.0.attention.self.query.weight, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.self.query.bias, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.self.key.weight, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.self.key.bias, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.self.value.weight, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.self.value.bias, lr: 0.00011111111111111112
name: input_transformers.layer.0.attention.output.dense.weight, lr: 0.00011111111111111112
na

KeyboardInterrupt: 

In [None]:
# Plot training vs. validation loss. Only usable once the training cell has
# completed: train_loss / val_loss come from run_loop()'s return value, and the
# saved run above ended in a KeyboardInterrupt before they were assigned.
# plot(train_loss, val_loss)

In [None]:
# Debug helper (disabled): print each parameter's index, name and requires_grad
# flag, collecting the names in param_names.
# param_names = []
# for i, (name, param) in enumerate(model.named_parameters()):
#     param_names.append(name)
#     print(f'{i}: {name} {param.requires_grad}')

In [None]:
# Load a previously saved checkpoint from the 03/08 run (the filename suggests the
# epoch with the lowest validation loss).
# NOTE: pickle.load can execute arbitrary code — only load checkpoints you created.
best_model_dir = 'models/0308'
best_model_fp = f'{best_model_dir}/model_best_epoch_23930_min_val_loss_0.02459999918937683.pkl'
with open(best_model_fp, 'rb') as handle:
    best_model = pickle.load(handle)

# Generating sequences

In [None]:
# Diffusion process used at sampling time, with its own number of steps.
# NOTE(review): the config cell trains with diffusion_steps=1000; confirm that
# sampling with a 2000-step schedule is intended for the loaded checkpoint.
sampling_diffusion_steps = 2000
diffusion_2000 = SpacedDiffusion(
    betas=get_named_beta_schedule(noise_schedule, sampling_diffusion_steps),
    predict_xstart=predict_xstart,
    rescale_timesteps=rescale_timesteps,
)

In [None]:
# Generate sequences with the loaded checkpoint and the sampling-time diffusion.
# seq_len comes from the config cell (instead of repeating the literal 128), so
# sampling stays consistent with the length the data was padded to for training.
# NOTE(review): training/validation used `data_dir` (comedies_only), while this
# samples the 'test_custom' split from `regular_data_dir` — confirm that is intended.
word_lst_source, word_lst_recover, word_lst_ref, inter_lst_recover = sampling(
    best_model,
    diffusion_2000,
    tokenizer,
    data_dir=regular_data_dir,
    batch_size=10,
    split='test_custom',
    seq_len=seq_len,
    show_intermediate_results=False,
)

Generating 20 sentences takes 5 minutes

In [None]:
# Inspect the `word_lst_source` output of sampling() (the source-side sequences).
word_lst_source

In [None]:
# Inspect the `word_lst_recover` output of sampling() (presumably the generated sequences).
word_lst_recover