In [None]:
# pip install datasets --user

In [1]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train import TrainLoop
from utils.data import load_data_text
from tokenizer import load_tokenizer, load_model_emb
from sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random
from datetime import datetime

2024-02-16 20:02:31.438532: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# with open('vocab_list.pickle', 'rb') as handle:
#     vocab_list = pickle.load(handle)
# vocab_list = list(vocab_list.values())

In [3]:
# Call the project's cache-clearing helper before any model is built
# (presumably releases cached GPU memory — TODO confirm in utils/dist_util).
dist_util.clear_cache()

In [4]:
# ---- Optimisation / training loop (handed to TrainLoop) ----
lr = 0.0001
weight_decay = 0.0001
batch_size = 20
microbatch = 10
epochs = 15_000
eval_interval = 1000
ema_rate = '0.9999'

# ---- Diffusion process ----
diffusion_steps = 1_000
noise_schedule = 'sqrt'
schedule_sampler = 'uniform'
predict_xstart = True
rescale_timesteps = True

# ---- Model / vocabulary ----
vocab = 'custom'
use_plm_init = 'no'  # embedding in transformer
vocab_size = 0  # placeholder; overwritten with the tokenizer's vocabulary size further down
config_name = 'bert-base-uncased'
seq_len = 128
hidden_t_dim = 300
hidden_dim = 300
dropout = 0.1
emb_scale_factor = 1.0

# ---- Reproducibility ----
seed = 102

In [5]:
# Known dataset locations, relative to the notebook's working directory.
cc_data_dir = 'data/commonsense'
ss_data_dir = 'data/shakespeare'
ss_small_data_dir = 'data/mini-shakespeare'
combined_data_dir = 'data/combined'
combined_small_data_dir = 'data/combined/small'
regular_data_dir = 'data'

# Dataset used by the first training run: pick one of the directories above.
data_dir = regular_data_dir

In [6]:
# Seed the random number generators through transformers.set_seed so that
# repeated runs start from the same random state.
set_seed(seed)

In [7]:
# Build the tokenizer with the project helper. The first argument presumably
# names the custom vocabulary to load — TODO confirm in tokenizer.py.
tokenizer = load_tokenizer('shakespeare_plays', config_name)

In [8]:
# Inspect the vocabulary size reported by the tokenizer.
tokenizer.vocab_size

33747

In [9]:
# Sanity check: encode one sample sentence and display the resulting token ids.
tokenizer.encode_token('find we a time for fright peace to pant')

[2, 604, 134, 24, 413, 119, 2987, 640, 90, 7166, 3]

In [10]:
# Create the embedding table used to initialise the data pipeline.
# Note that this call also rebinds `tokenizer` to the object returned by the helper.
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [11]:
# Display the embedding module returned by load_model_emb.
model_weight

Embedding(33747, 300)

In [12]:
## Very important: replace the placeholder `vocab_size` from the hyperparameter
## cell with the tokenizer's real vocabulary size. The value is passed to
## create_model_and_diffusion and to the replacement lm_head later on.
vocab_size = tokenizer.vocab_size

In [13]:
# Confirm the updated vocabulary size.
vocab_size

33747

Passed in as batch in TrainLoop - this is the batch data

In [14]:
# next(data)[0].shape # batch_size, seq_len, hidden_dim

In [15]:
# next(data)[0]

Passed in as cond in TrainLoop - this is a dictionary of input_ids and input_mask

In [16]:
# next(data)[1]

In [17]:
# next(data)[1]['input_ids'].shape # batch_size, seq_len

In [18]:
# next(data)[1]['input_mask'].shape # batch_size, seq_len

In [19]:
# Build the denoising network and the diffusion process from the
# hyperparameters defined above. All arguments are positional, so their order
# must match the signature of create_model_and_diffusion in run_train.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [20]:
# Move the model to the device chosen by dist_util.dev(). As the last
# expression of the cell, the returned module is displayed.
model.to(dist_util.dev())

TransformerNetModel(
  (word_embedding): Embedding(33747, 300)
  (lm_head): Linear(in_features=300, out_features=33747, bias=True)
  (time_embed): Sequential(
    (0): Linear(in_features=300, out_features=1200, bias=True)
    (1): SiLU()
    (2): Linear(in_features=1200, out_features=768, bias=True)
  )
  (input_up_proj): Sequential(
    (0): Linear(in_features=300, out_features=768, bias=True)
    (1): Tanh()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_transformers): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=76

In [21]:
# Total number of scalar parameters in the model (every parameter is counted,
# regardless of its requires_grad flag).
pytorch_total_params = sum(p.numel() for p in model.parameters())

In [22]:
# Display the parameter count computed above.
pytorch_total_params

98533683

In [23]:
# Request synchronous CUDA kernel launches so errors surface at the failing call.
# NOTE(review): the model has already been moved to the device in an earlier
# cell; CUDA normally reads this variable when it initialises, so setting it
# here may have no effect — confirm, and consider moving it before the imports.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [24]:
# Build the training data source from `data_dir`, tokenised with `tokenizer`
# and using `model_weight` as the embedding table. The result is passed to
# TrainLoop as its `data` argument.
data = load_data_text(
        batch_size=batch_size,
        seq_len=seq_len,
        data_dir=data_dir,
        loaded_vocab=tokenizer,
        model_emb=model_weight # use model's weights as init
    )

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TRAIN set...
### Data samples...
 ["something of moment then i will go meet him there's matter in't indeed, if he be angry. i prithee, do so.", 'conscience, which is, indeed, sir, a mender of bad soles. what trade, thou knave? thou naughty knave, what trade? nay, i beseech you, sir, be not out with me yet,'] ['exit iago something, sure, of state, either from venice, or some unhatched practise', 'if you be out, sir, i can mend you. what meanest thou by that? mend me, thou saucy fellow! why, sir, cobble you.']
RAM used: 2670.82 MB
This is raw_datasets:  Dataset({
    features: ['src', 'trg'],
    num_rows: 107264
})
RAM used: 2709.53 MB


Running tokenizer on dataset (num_proc=4):   0%|          | 0/107264 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 107264
})
### tokenized_datasets...example [2, 1340, 96, 4074, 265, 32, 163, 187, 791, 156, 213, 7, 42, 811, 105, 7, 43, 744, 10, 197, 108, 104, 2085, 11, 32, 1185, 10, 147, 151, 11, 3]
RAM used: 2766.52 MB


merge and mask:   0%|          | 0/107264 [00:00<?, ? examples/s]

RAM used: 2878.09 MB


padding:   0%|          | 0/107264 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 107264
}) padded dataset
RAM used: 3088.38 MB
RAM used: 3064.38 MB


In [None]:
# Create the timestep sampler for training.
# NOTE(review): this rebinds `schedule_sampler`, which the hyperparameter cell
# defined as the string 'uniform', to a sampler object, and hard-codes the
# sampler name instead of reading it from that setting.
schedule_sampler = create_named_schedule_sampler('uniform', diffusion)

# Run the initial training on `data`; run_loop() blocks until training ends.
# No evaluation data is supplied (eval_data is commented out).
TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=schedule_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
#         eval_data=data_valid,
        eval_interval=eval_interval
    ).run_loop()





Epoch 0/15000 Loss: 1.0451031923294067
Epoch 1/15000 Loss: 0.869488000869751
Epoch 2/15000 Loss: 0.781876802444458
Epoch 3/15000 Loss: 0.7459273934364319
Epoch 4/15000 Loss: 0.7409336566925049
Epoch 5/15000 Loss: 0.7291581630706787
Epoch 6/15000 Loss: 0.721569299697876
Epoch 7/15000 Loss: 0.6996433138847351
Epoch 8/15000 Loss: 0.7033931016921997
Epoch 9/15000 Loss: 0.7397572994232178
Epoch 10/15000 Loss: 0.7111833691596985
Epoch 11/15000 Loss: 0.7138040065765381
Epoch 12/15000 Loss: 0.7373813390731812
Epoch 13/15000 Loss: 0.7302639484405518
Epoch 14/15000 Loss: 0.7417792081832886
Epoch 15/15000 Loss: 0.685096263885498
Epoch 16/15000 Loss: 0.711911141872406
Epoch 17/15000 Loss: 0.7061750888824463
Epoch 18/15000 Loss: 0.7069834470748901
Epoch 19/15000 Loss: 0.7299482226371765
Epoch 20/15000 Loss: 0.6960338354110718
Epoch 21/15000 Loss: 0.7062292695045471
Epoch 22/15000 Loss: 0.7168128490447998
Epoch 23/15000 Loss: 0.7048308253288269
Epoch 24/15000 Loss: 0.7262097597122192
Epoch 25/15

In [None]:
# dt = datetime.now().strftime("%m%d_%I%M%p")
# deets = f"diff_steps_{diffusion_steps}_epochs_{epochs}"
# pickle.dump(model, open(f"models/transfer_learning/model_{dt}_{deets}.pkl", 'wb'))
# pickle.dump(diffusion, open(f"models/transfer_learning/diffusion_{dt}_{deets}.pkl", 'wb'))

In [None]:
# with open('models/transfer_learning/model_0213_1014AM_diff_steps_1500_epochs_10000.pkl', 'rb') as handle:
#     model = pickle.load(handle)

In [None]:
# with open('models/transfer_learning/diffusion_0213_1014AM_diff_steps_1500_epochs_10000.pkl', 'rb') as handle:
#     diffusion = pickle.load(handle)

# (Only for transfer learning) Fine-tuning on Shakespeare only

In [None]:
# Freeze the whole pretrained network: Module.requires_grad_ switches the flag
# off on every parameter of the model.
model.requires_grad_(False)

# model.lm_head0 = nn.Linear(hidden_dim, hidden_dim)
# model.lm_head0.weight.requires_grad_(True)
# Swap in a freshly initialised output head and make sure its weight is trainable.
# The new layer is created on the CPU; a later cell moves the model to the device.
model.lm_head = torch.nn.Linear(hidden_dim, vocab_size)
model.lm_head.weight.requires_grad_(True)

## Check if we have enabled/disabled grad

In [None]:
# Display whether the word-embedding weight currently receives gradients.
model.word_embedding.weight.requires_grad

In [None]:
# Display whether the output-head weight currently receives gradients.
model.lm_head.weight.requires_grad

In [None]:
# Switch to eval mode and move the model (including any newly attached
# layers) to the device.
model.eval().to(dist_util.dev())

# Stand-alone CPU copy of the model's word-embedding table, used as the
# embedding for the data pipeline below. clone() means later updates to the
# model's embedding are not reflected in this copy.
model_emb = torch.nn.Embedding(
        num_embeddings=tokenizer.vocab_size, 
        embedding_dim=hidden_dim, 
        _weight=model.word_embedding.weight.clone().cpu()
    ).eval()

In [None]:
# Data source for the fine-tuning run.
# NOTE(review): batch_size is hard-coded to 10 here while TrainLoop below is
# given the global `batch_size`; confirm the mismatch is intended.
# NOTE(review): this reads from `regular_data_dir`, the same directory as the
# first training run, rather than `ss_data_dir` — confirm this is the intended
# fine-tuning set.
ss_data = load_data_text(
        batch_size=10,
        seq_len=seq_len,
        data_dir=regular_data_dir,
        loaded_vocab=tokenizer,
        model_emb=model_emb.cpu() # use model's weights as init
    )

In [None]:
# The model was put into eval mode when the embedding copy was made; switch
# back to training mode before fine-tuning.
model.train() # TURNING THE TRAIN MODE BACK ON TO ENABLE BATCHNORM/DROPOUT!!

In [None]:
# Make the word-embedding weight trainable again, so fine-tuning updates the
# embedding as well as the new output head.
model.word_embedding.weight.requires_grad_(True)

In [None]:
# Verify the embedding weight is trainable before fine-tuning starts.
model.word_embedding.weight.requires_grad

In [None]:
# Verify the output-head weight is trainable before fine-tuning starts.
model.lm_head.weight.requires_grad

In [None]:
# Create the timestep sampler for the fine-tuning run (sampler name is hard-coded).
schedule_sampler = create_named_schedule_sampler('uniform', diffusion)

# Fine-tune on `ss_data` with the same optimisation settings as the first run.
# NOTE(review): `batch_size` passed here is the global value, whereas `ss_data`
# was loaded with a hard-coded batch size — confirm TrainLoop tolerates that.
TrainLoop(
        model=model,
        diffusion=diffusion,
        data=ss_data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=schedule_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
#         eval_data=data_valid,
        eval_interval=eval_interval
    ).run_loop()

In [None]:
# dt = datetime.now().strftime("%m%d_%I%M%p")
# deets = f"diff_steps_{diffusion_steps}_epochs_{epochs}"
# pickle.dump(model, open(f"models/transfer_learning/model_{dt}_{deets}.pkl", 'wb'))
# pickle.dump(diffusion, open(f"models/transfer_learning/diffusion_{dt}_{deets}.pkl", 'wb'))

## Comparing with model before transfer learning

In [None]:
# with open('models/transfer_learning/diffusion_0213_0609AM_diff_steps_1500_epochs_2.pkl', 'rb') as handle:
#     diffusion_ori = pickle.load(handle)
    
# with open('models/transfer_learning/model_0213_0609AM_diff_steps_1500_epochs_2.pkl', 'rb') as handle:
#     model_ori = pickle.load(handle)

In [None]:
# model_ori.lm_head.weight == model.lm_head.weight

In [None]:
# model_ori.word_embedding.weight == model.word_embedding.weight

In [None]:
# model_ori.input_up_proj[0].weight == model.input_up_proj[0].weight

### Nice! Worked

# Generating sequences

In [None]:
import numpy as np
# Resolve the compute device once. The value is captured as the default
# `device` argument of sampling() when that function is defined below.
device = dist_util.dev()

def sampling(
    model,
    diffusion,
    tokenizer,
    device=device,
    batch_size=16,
    seq_len=128,
    data_dir='data/',
    split='test',
    clip_denoised=False,
    model_kwargs=None,
    top_p=0,
    step_gap=1,
    clamp_step=0):
    """Generate a target sequence for every example of a data split.

    For each batch the input ids are embedded, every position whose mask is
    non-zero is replaced by Gaussian noise, and ``diffusion.p_sample_loop``
    denoises the result while rounding intermediate predictions to the
    nearest vocabulary embedding. The final sample is converted to token ids
    with a top-1 pick over ``model.get_logits`` and decoded.

    NOTE(review): this definition shadows the ``sampling`` function imported
    from the ``sampling`` module at the top of the notebook.

    Side effects:
        The model is put in eval mode, gradients are disabled on all of its
        parameters, and it is moved to ``device``. None of this is undone on
        return, so re-enable ``requires_grad`` / ``train()`` before training
        the same model again.

    Args:
        model: network exposing ``word_embedding``, ``get_embeds`` and
            ``get_logits``.
        diffusion: object exposing ``p_sample_loop``.
        tokenizer: object exposing ``vocab_size`` and ``decode_token``.
        device: device used for the model and all tensors.
        batch_size: batch size requested from ``load_data_text``.
        seq_len: sequence length requested from ``load_data_text``; also used
            to locate the source/target boundary of each row.
        data_dir: directory handed to ``load_data_text``.
        split: name of the split handed to ``load_data_text``.
        clip_denoised: forwarded to ``p_sample_loop``.
        model_kwargs: optional dict forwarded to ``p_sample_loop``. ``None``
            (the default) means an empty dict. A shallow copy is passed for
            each batch, so the caller's dict is never mutated.
        top_p: forwarded to ``p_sample_loop``.
        step_gap: forwarded to ``p_sample_loop`` as ``gap``.
        clamp_step: forwarded to ``p_sample_loop``.

    Returns:
        ``(word_lst_source, word_lst_recover, word_lst_ref)``: three lists
        with one decoded entry per example — the source part of the input,
        the generated continuation, and the reference continuation.
    """
    # ---- putting the model into eval mode ----
    model.eval().requires_grad_(False).to(device)

    hidden_dim = model.word_embedding.embedding_dim

    # Frozen CPU copy of the embedding table, shared by the data pipeline and
    # the rounding callback.
    model_emb = torch.nn.Embedding(
        num_embeddings=tokenizer.vocab_size,
        embedding_dim=hidden_dim,
        _weight=model.word_embedding.weight.clone().cpu()
    ).eval().requires_grad_(False)

    # ---- getting test data ----
    data_test = load_data_text(
        batch_size=batch_size,
        seq_len=seq_len,
        deterministic=True,
        data_dir=data_dir,
        split=split,
        loaded_vocab=tokenizer,
        model_emb=model_emb.cpu(),  # using the same embedding weight as the training data
        loop=False
    )

    # Only the conditioning dicts are needed; the pre-embedded batch is discarded.
    all_test_data = [cond for _, cond in data_test]
    print('### End of reading iteration...')

    model_emb.to(device)

    # Previously the argument was overwritten with {} on every batch, so any
    # caller-supplied kwargs were silently dropped.
    base_kwargs = {} if model_kwargs is None else model_kwargs

    # ---- iterating through the test data to generate sequences ----
    word_lst_recover = []
    word_lst_ref = []
    word_lst_source = []

    for cond in all_test_data:
        input_ids_x = cond.pop('input_ids').to(device)
        x_start = model.get_embeds(input_ids_x)
        input_ids_mask_ori = cond.pop('input_mask')

        # Keep the embeddings where the mask is 0 and start from pure noise
        # everywhere else.
        noise = torch.randn_like(x_start)
        input_ids_mask = torch.broadcast_to(
            input_ids_mask_ori.unsqueeze(dim=-1), x_start.shape
        ).to(device)
        x_noised = torch.where(input_ids_mask == 0, x_start, noise)

        samples = diffusion.p_sample_loop(
            model,
            (x_start.shape[0], seq_len, hidden_dim),
            noise=x_noised,
            clip_denoised=clip_denoised,
            denoised_fn=partial(denoised_fn_round, model_emb),
            model_kwargs=dict(base_kwargs),
            top_p=top_p,
            clamp_step=clamp_step,
            clamp_first=True,
            mask=input_ids_mask,
            x_start=x_start,
            gap=step_gap
        )

        # p_sample_loop returns one entry per step; the last one is the final sample.
        sample = samples[-1]

        logits = model.get_logits(sample)  # bsz, seqlen, vocab
        cands = torch.topk(logits, k=1, dim=-1)

        for pred_seq, src_seq, input_mask in zip(cands.indices, input_ids_x, input_ids_mask_ori):
            # Number of leading positions with mask 0 (presumably the source
            # part; assumes a 0/1 mask — TODO confirm against the data loader).
            len_x = seq_len - int(input_mask.sum().item())
            word_lst_recover.append(tokenizer.decode_token(pred_seq[len_x:]))
            word_lst_source.append(tokenizer.decode_token(src_seq[:len_x]))
            word_lst_ref.append(tokenizer.decode_token(src_seq[len_x:]))

    return word_lst_source, word_lst_recover, word_lst_ref


def get_efficient_knn(model_emb, text_emb):
    """Find the nearest vocabulary embedding for every row of ``text_emb``.

    Squared Euclidean distances are computed with the expansion
    ``|e|^2 + |x|^2 - 2 e.x`` and clamped at zero to absorb rounding error.

    Args:
        model_emb: tensor of shape (vocab, d) holding the embedding table.
        text_emb: tensor whose last dimension is d; leading dimensions are
            flattened into n rows.

    Returns:
        ``(values, indices)`` from a top-1 search over the vocabulary axis,
        each of shape (1, n). ``values`` holds the negated squared distance
        to the nearest entry and ``indices`` its row in ``model_emb``.
    """
    flat_queries = text_emb.view(-1, text_emb.size(-1))          # n, d
    vocab_sq_norm = (model_emb ** 2).sum(-1).view(-1, 1)         # vocab, 1
    query_sq_norm = (text_emb ** 2).sum(-1).view(-1, 1)          # n, 1
    cross_term = torch.mm(model_emb, torch.transpose(flat_queries, 0, 1))  # vocab, n

    sq_dist = vocab_sq_norm + query_sq_norm.transpose(0, 1) - 2.0 * cross_term
    sq_dist = torch.clamp(sq_dist, min=0.0)

    # The largest negated distance is the smallest distance.
    nearest = torch.topk(-sq_dist, k=1, dim=0)
    return nearest.values, nearest.indices

def denoised_fn_round(model, text_emb, t):
    """Snap every embedding in ``text_emb`` to its nearest vocabulary embedding.

    Args:
        model: embedding module; ``model.weight`` is used as the lookup table
            and calling ``model(ids)`` maps token ids back to embeddings.
        text_emb: tensor of embeddings whose last dimension is the embedding
            size; tensors with more than two dimensions are flattened to rows.
        t: not used by this function; kept so the signature matches what the
            sampler passes to its ``denoised_fn`` callback.

    Returns:
        Tensor with the same shape and device as ``text_emb`` in which each
        vector is replaced by the embedding of its nearest token.
    """
    original_shape, original_device = text_emb.shape, text_emb.device
    vocab_emb = model.weight  # input_embs

    flat_emb = text_emb.reshape(-1, text_emb.size(-1)) if text_emb.dim() > 2 else text_emb

    # Only the indices are needed; the distances are discarded.
    _, nearest = get_efficient_knn(vocab_emb, flat_emb.to(vocab_emb.device))
    token_ids = nearest[0]

    return model(token_ids).view(original_shape).to(original_device)

In [None]:
# Options for sampling(); these values equal that function's own defaults.
clip_denoised = False
top_p = 0
clamp_step = 0

In [None]:
# Generate continuations for the custom test split. The sampling options
# defined in the cell above are now passed explicitly, so editing them takes
# effect (previously they were defined but never used).
word_lst_source, word_lst_recover, word_lst_ref = sampling(
    model,
    diffusion,
    tokenizer,
    data_dir=regular_data_dir,
    batch_size=5,
    split='test_custom',
    seq_len=20,
    clip_denoised=clip_denoised,
    top_p=top_p,
    clamp_step=clamp_step,
)

Generating 20 sentences takes 5 minutes

In [None]:
# Decoded source part of each test example.
word_lst_source

In [None]:
# Decoded model output for each test example.
word_lst_recover

In [None]:
# Decoded reference continuation for each test example.
word_lst_ref