In [None]:
# pip install datasets --user

In [None]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train import TrainLoop
from utils.data import load_data_text
from tokenizer import load_tokenizer, load_model_emb
from sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

In [None]:
# with open('vocab_list.pickle', 'rb') as handle:
#     vocab_list = pickle.load(handle)
# vocab_list = list(vocab_list.values())

In [None]:
# Clear cached memory before building the model (project helper in
# utils.dist_util; presumably empties the accelerator/CUDA cache — confirm).
dist_util.clear_cache()

In [None]:
# --- Optimisation / training loop ------------------------------------------
lr = 1e-4
weight_decay = 1e-4
batch_size = 10
microbatch = 5                 # passed to TrainLoop alongside batch_size
epochs = 20_000                # use a tiny value (e.g. 3) for a quick smoke test
eval_interval = 100
ema_rate = '0.9999'            # kept as a string, as passed to TrainLoop
seed = 102

# --- Diffusion process -----------------------------------------------------
diffusion_steps = 1_000
noise_schedule = 'sqrt'
schedule_sampler = 'uniform'   # name of the timestep sampler
predict_xstart = True
rescale_timesteps = True

# --- Model / embeddings ----------------------------------------------------
vocab = 'custom'
use_plm_init = 'no'            # embedding in transformer
vocab_size = 0                 # placeholder; overwritten from the tokenizer below
config_name = 'bert-base-uncased'
seq_len = 128
hidden_t_dim = 256
hidden_dim = 256
dropout = 0.1
emb_scale_factor = 1.0

In [None]:
# Candidate dataset locations, all relative to the notebook's working dir.
regular_data_dir = 'data'
cc_data_dir = f'{regular_data_dir}/commonsense'
ss_data_dir = f'{regular_data_dir}/shakespeare'
ss_small_data_dir = f'{regular_data_dir}/mini-shakespeare'
combined_data_dir = f'{regular_data_dir}/combined'
combined_small_data_dir = f'{combined_data_dir}/small'

# Dataset used by the main training run; swap in one of the paths above.
data_dir = regular_data_dir

In [None]:
# Seed the RNGs via transformers.set_seed with the configured seed so runs are reproducible.
set_seed(seed)

In [None]:
# Load the project tokenizer named 'shakespeare_plays'; config_name (from the
# config cell) is passed through to the loader as well.
tokenizer = load_tokenizer('shakespeare_plays', config_name)

In [None]:
# Vocabulary size reported by the tokenizer (later copied into vocab_size).
tokenizer.vocab_size

In [None]:
# Sanity check: encode a sample sentence with the loaded tokenizer.
tokenizer.encode_token('find we a time for fright peace to pant')

In [None]:
# Build the initial word-embedding weights for hidden_dim and rebind tokenizer
# to the one returned alongside them. model_weight is later passed to
# load_data_text as model_emb.
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [None]:
# Inspect the initial embedding (passed to the data loaders as model_emb).
model_weight

In [None]:
# Must run before create_model_and_diffusion: the config cell sets
# vocab_size=0 as a placeholder, and the model's vocabulary size is taken
# from this variable.
vocab_size = tokenizer.vocab_size

In [None]:
# Confirm vocab_size was overwritten (should no longer be the placeholder 0).
vocab_size

`next(data)[0]` is what `TrainLoop` receives as `batch` — the embedded batch data.

In [None]:
# next(data)[0].shape # batch_size, seq_len, hidden_dim

In [None]:
# next(data)[0]

`next(data)[1]` is what `TrainLoop` receives as `cond` — a dictionary holding `input_ids` and `input_mask`.

In [None]:
# next(data)[1]

In [None]:
# next(data)[1]['input_ids'].shape # batch_size, hidden_dim

In [None]:
# next(data)[1]['input_mask'].shape # batch_size, hidden_dim

In [None]:
# Build the denoising model and the diffusion process from the config cell.
# NOTE: arguments are positional, so their order must match the signature of
# run_train.create_model_and_diffusion exactly.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

In [None]:
# Move the model to the device chosen by dist_util (the cell shows the model repr).
model.to(dist_util.dev())

In [None]:
# Total number of scalar parameters in the model, trainable and frozen alike.
pytorch_total_params = sum(map(torch.numel, model.parameters()))

In [None]:
# Display the parameter count computed in the previous cell.
pytorch_total_params

In [None]:
# NOTE(review): CUDA_LAUNCH_BLOCKING only takes effect if it is set before CUDA
# is initialised. If dist_util.dev() is a CUDA device, the model.to(...) call
# above has likely initialised it already, so this may be a no-op here.
# It is a debugging aid (synchronous kernel launches) and slows training when
# it is active.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# The train and validation loaders share every setting except the split; both
# embed tokens with model_weight (the embedding returned by load_model_emb).
shared_loader_kwargs = dict(
    batch_size=batch_size,
    seq_len=seq_len,
    data_dir=data_dir,
    loaded_vocab=tokenizer,
    model_emb=model_weight,  # use model's weights as init
)

data = load_data_text(**shared_loader_kwargs)
val = load_data_text(split='valid', **shared_loader_kwargs)

In [None]:
# Main training run.
# The timestep sampler is built from the name configured in the config cell
# (schedule_sampler) instead of a hard-coded literal, so editing the config
# actually takes effect. The sampler object gets its own name so the config
# string is not shadowed, which keeps this cell safe to re-run.
train_schedule_sampler = create_named_schedule_sampler(schedule_sampler, diffusion)

TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=train_schedule_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
        eval_data=val,  # validation loader built in the previous cell
        eval_interval=eval_interval
    ).run_loop()

In [None]:
# dt = datetime.now().strftime("%m%d_%I%M%p")
# deets = f"diff_steps_{diffusion_steps}_epochs_{epochs}"
# pickle.dump(model, open(f"models/model_{dt}_{deets}.pkl", 'wb'))
# pickle.dump(diffusion, open(f"models/diffusion_{dt}_{deets}.pkl", 'wb'))

In [None]:
# with open('models/transfer_learning/model_0213_1014AM_diff_steps_1500_epochs_10000.pkl', 'rb') as handle:
#     model = pickle.load(handle)

In [None]:
# with open('models/transfer_learning/diffusion_0213_1014AM_diff_steps_1500_epochs_10000.pkl', 'rb') as handle:
#     diffusion = pickle.load(handle)

# Transfer learning (optional): fine-tuning on Shakespeare only

In [None]:
# Freeze every parameter; selected parts are made trainable again below and
# in later cells.
for param in model.parameters():
    param.requires_grad = False

# Replace the output head with a freshly initialised, trainable projection.
# It is created on the same device as the existing weights, so the model is
# not left split across CPU and accelerator until the next .to() call.
# NOTE(review): if the original lm_head shared its weight with
# word_embedding, this replacement breaks that tie — confirm it is intended.
model.lm_head = torch.nn.Linear(hidden_dim, vocab_size).to(model.word_embedding.weight.device)
model.lm_head.weight.requires_grad_(True)

## Check which parameters have gradients enabled (`requires_grad`)

In [None]:
# Expect False here: the freeze cell set requires_grad=False on every parameter.
model.word_embedding.weight.requires_grad

In [None]:
# Expect True: the replacement lm_head weight was explicitly made trainable.
model.lm_head.weight.requires_grad

In [None]:
# Put the model on the device in eval mode, then snapshot its current word
# embeddings into a standalone CPU nn.Embedding. The snapshot is only used to
# embed the fine-tuning data (it is passed to load_data_text next).
model.eval().to(dist_util.dev())

# detach() + requires_grad_(False): nn.Embedding wraps `_weight` in a trainable
# Parameter by default. This snapshot is a fixed lookup table for data loading,
# so it should neither track gradients nor stay linked to the live parameter.
model_emb = torch.nn.Embedding(
        num_embeddings=tokenizer.vocab_size,
        embedding_dim=hidden_dim,
        _weight=model.word_embedding.weight.detach().clone().cpu()
    ).eval().requires_grad_(False)

In [None]:
# Loader for the fine-tuning stage. It embeds text with the snapshot of the
# model's current word embeddings (model_emb).
# It uses the configured batch_size instead of a hard-coded 10, so it stays in
# sync with the first loader and with the TrainLoop below (which also receives
# batch_size).
# NOTE(review): this section is titled "Finetuning on only Shakespeare", but the
# loader reads regular_data_dir, not ss_data_dir — confirm which is intended.
ss_data = load_data_text(
        batch_size=batch_size,
        seq_len=seq_len,
        data_dir=regular_data_dir,
        loaded_vocab=tokenizer,
        model_emb=model_emb.cpu()  # embedding snapshot taken from the model
    )

In [None]:
model.train() # back to training mode (the snapshot cell set eval mode) so dropout (config dropout=0.1) is active again

In [None]:
# Unfreeze the word embeddings so they are fine-tuned together with lm_head.
model.word_embedding.weight.requires_grad_(True)

In [None]:
# Expect True after the previous cell unfroze the embeddings.
model.word_embedding.weight.requires_grad

In [None]:
# Expect True: lm_head has stayed trainable since it was replaced.
model.lm_head.weight.requires_grad

In [None]:
# Fine-tuning run: a fresh uniform timestep sampler plus the ss_data loader
# built above. No eval_data is passed for this stage.
schedule_sampler = create_named_schedule_sampler('uniform', diffusion)

TrainLoop(
    # what is trained, and how timesteps are drawn
    model=model,
    diffusion=diffusion,
    schedule_sampler=schedule_sampler,
    # data and batching
    data=ss_data,
    batch_size=batch_size,
    microbatch=microbatch,
    # schedule length and evaluation cadence
    epochs=epochs,
    eval_interval=eval_interval,
    # optimiser settings
    lr=lr,
    weight_decay=weight_decay,
    ema_rate=ema_rate,
).run_loop()

In [None]:
# dt = datetime.now().strftime("%m%d_%I%M%p")
# deets = f"diff_steps_{diffusion_steps}_epochs_{epochs}"
# pickle.dump(model, open(f"models/transfer_learning/model_{dt}_{deets}.pkl", 'wb'))
# pickle.dump(diffusion, open(f"models/transfer_learning/diffusion_{dt}_{deets}.pkl", 'wb'))

## Comparing with the model before transfer learning

In [None]:
# with open('models/transfer_learning/diffusion_0213_0609AM_diff_steps_1500_epochs_2.pkl', 'rb') as handle:
#     diffusion_ori = pickle.load(handle)
    
# with open('models/transfer_learning/model_0213_0609AM_diff_steps_1500_epochs_2.pkl', 'rb') as handle:
#     model_ori = pickle.load(handle)

In [None]:
# model_ori.lm_head.weight == model.lm_head.weight

In [None]:
# model_ori.word_embedding.weight == model.word_embedding.weight

In [None]:
# model_ori.input_up_proj[0].weight == model.input_up_proj[0].weight

### Nice! Worked

# Generating sequences

In [None]:
# Generate from the trained model on the 'test_custom' split. The three
# returned lists are unpacked as source, recovered and reference text.
# NOTE: seq_len=20 here differs from the training seq_len in the config cell.
word_lst_source, word_lst_recover, word_lst_ref = sampling(
    model,
    diffusion,
    tokenizer,
    data_dir=regular_data_dir,
    split='test_custom',
    batch_size=10,
    seq_len=20,
)

Note: generating 20 sentences takes about 5 minutes.

In [None]:
# First output of sampling(): the source-side sentences.
word_lst_source

In [None]:
# Second output of sampling(): the sentences recovered by the diffusion model.
word_lst_recover

In [None]:
# Third output of sampling(): the reference sentences.
word_lst_ref