In [None]:
# pip install datasets --user

In [1]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train import TrainLoop
from utils.data import load_data_text
from tokenizer import load_tokenizer, load_model_emb
from sampling import sampling

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random

2024-02-07 11:40:33.668637: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# with open('vocab_list.pickle', 'rb') as handle:
#     vocab_list = pickle.load(handle)
# vocab_list = list(vocab_list.values())

In [3]:
dist_util.clear_cache()

In [4]:
# ---- optimisation / training loop (forwarded to TrainLoop) ----
lr = 1e-4
weight_decay = 0.0
batch_size = 10
microbatch = 5
epochs = 100
eval_interval = 1000
ema_rate = '0.9999'
# Sampler name; note that the training cell later rebinds `schedule_sampler`
# to the sampler object itself.
schedule_sampler = 'uniform'

# ---- diffusion process (forwarded to create_model_and_diffusion) ----
diffusion_steps = 1000
noise_schedule = 'sqrt'
predict_xstart = True
rescale_timesteps = True

# ---- model / tokenizer ----
config_name = 'bert-base-uncased'
use_plm_init = 'no'  # embedding in transformer
hidden_t_dim = 128
hidden_dim = 128
dropout = 0.1
vocab = 'custom'
# Placeholder only: overwritten with tokenizer.vocab_size once the tokenizer is loaded.
vocab_size = 0
emb_scale_factor = 1.0

# ---- data / reproducibility ----
seq_len = 500
seed = 102

In [5]:
# Candidate dataset locations (relative paths).
cc_data_dir = 'data/commonsense'
ss_data_dir = 'data/shakespeare'
ss_small_data_dir = 'data/mini-shakespeare'
combined_data_dir = 'data/combined'
regular_data_dir = 'data'

# Dataset actually used by the cells below; point this at one of the
# directories above to switch corpora.
data_dir = regular_data_dir

In [6]:
set_seed(seed)

In [7]:
tokenizer = load_tokenizer('shakespeare', config_name)

In [8]:
tokenizer.vocab_size

30473

In [9]:
tokenizer.encode_token('find we a time for fright peace to pant')

[2, 611, 139, 26, 413, 121, 3021, 649, 94, 7338, 3]

In [10]:
model_weight, tokenizer = load_model_emb(hidden_dim, tokenizer)

In [11]:
model_weight

Embedding(30473, 128)

In [12]:
## very very important to set this!!!!!
vocab_size = tokenizer.vocab_size

In [13]:
vocab_size

30473

In [14]:
# Build the training data source from `data_dir`, tokenised with the custom
# tokenizer and padded/truncated according to `seq_len`. The result is handed
# to TrainLoop as `data` further down.
# NOTE(review): the sampling code reads the analogous test iterator with
# `batch, cond = next(...)`; presumably this object yields the same pairs.
data = load_data_text(
        batch_size=batch_size,
        seq_len=seq_len,
        data_dir=data_dir,
        loaded_vocab=tokenizer,
        model_emb=model_weight # embedding returned by load_model_emb; use model's weights as init
    )

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TRAIN set...
### Data samples...
 ["proceed solinus to procure my fall <EOS> and by the doom of death end woes and all <EOS> therefore give out you are of epidamnum <EOS> lest that your goods too soon be confiscate <EOS> this very day a syracusian merchant <EOS> is apprehended for arrival here <EOS> and not be able to buy out his life <EOS> accord to the statute of the town <EOS> dies ere the weary sun set in the west <EOS> there is your money that i had to keep <EOS> antipholus <EOS> neither my husband nor the slave return'd <EOS> that in such haste i sent to seek his master! <EOS> sure luciana it is two o'clock <EOS> the gold i gave to dromio is laid up <EOS> safe at the centaur and the heedful slave <EOS> is wander'd forth in care to seek me out <EOS> by computation and mine host's report <EOS> i could not speak with dromio since at first <EOS> i sen

Running tokenizer on dataset (num_proc=4):   0%|          | 0/5083 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 5083
})
### tokenized_datasets...example [2, 1672, 13682, 94, 6298, 111, 704, 21, 84, 22, 90, 195, 83, 2928, 100, 454, 678, 3699, 90, 192, 21, 84, 22, 582, 353, 321, 95, 206, 100, 10631, 21, 84, 22, 1581, 116, 141, 5487, 344, 1237, 108, 11563, 21, 84, 22, 149, 506, 390, 26, 10571, 4200, 21, 84, 22, 126, 9739, 121, 11517, 252, 21, 84, 22, 90, 131, 108, 3085, 94, 2110, 321, 148, 476, 21, 84, 22, 2679, 94, 83, 13799, 100, 83, 1496, 21, 84, 22, 1935, 821, 83, 2565, 1037, 650, 109, 83, 2468, 21, 84, 22, 218, 126, 141, 1340, 116, 34, 335, 94, 595, 21, 84, 22, 1207, 21, 84, 22, 1351, 111, 844, 383, 83, 1738, 998, 7, 29, 21, 84, 22, 116, 109, 349, 1447, 34, 924, 94, 1060, 148, 452, 5, 21, 84, 22, 869, 7797, 137, 126, 542, 40, 7, 2239, 21, 84, 22, 83, 919, 34, 1184, 94, 3009, 126, 2189, 221, 21, 84, 22, 1815, 214, 83, 10590, 90, 83, 12337, 1738, 21, 84, 22, 126, 3618, 7, 29, 808, 109, 1029, 94, 1060, 125,

merge and mask:   0%|          | 0/5083 [00:00<?, ? examples/s]

RAM used: 1105.20 MB


padding:   0%|          | 0/5083 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 5083
}) padded dataset
RAM used: 1187.09 MB
RAM used: 1159.08 MB


Passed in as batch in TrainLoop - this is the batch data

In [15]:
# next(data)[0].shape # batch_size, seq_len, hidden_dim

In [16]:
# next(data)[0]

Passed in as cond in TrainLoop - this is a dictionary of input_ids and input_mask

In [17]:
# next(data)[1]

In [18]:
# next(data)[1]['input_ids'].shape # batch_size, hidden_dim

In [19]:
# next(data)[1]['input_mask'].shape # batch_size, hidden_dim

In [20]:
# Instantiate the denoising network and the diffusion process from the
# hyper-parameters defined at the top of the notebook.
# All arguments are passed positionally, so their order must match the
# signature of run_train.create_model_and_diffusion (not visible here).
# `vocab_size` must already have been overwritten with tokenizer.vocab_size;
# its initial value in the config cell is the placeholder 0.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

In [21]:
model.to(dist_util.dev())

TransformerNetModel(
  (word_embedding): Embedding(30473, 128)
  (lm_head): Linear(in_features=128, out_features=30473, bias=True)
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=768, bias=True)
  )
  (input_up_proj): Sequential(
    (0): Linear(in_features=128, out_features=768, bias=True)
    (1): Tanh()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_transformers): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768,

In [22]:
pytorch_total_params = sum(p.numel() for p in model.parameters())

In [23]:
pytorch_total_params

91218953

In [24]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [25]:
# Build the timestep sampler used during training.
# NOTE(review): this rebinds `schedule_sampler`, which the config cell defined
# as the string 'uniform'; the sampler name is hard-coded here, so changing
# the config value has no effect on this cell.
schedule_sampler = create_named_schedule_sampler('uniform', diffusion)

# Run training with the hyper-parameters from the config cell. No evaluation
# data is supplied (the eval_data argument below is commented out).
TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=schedule_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
#         eval_data=data_valid,  # disabled: `data_valid` is not defined in this notebook
        eval_interval=eval_interval
    ).run_loop()

Epoch 0 Loss: 1.0473501682281494
Epoch 1 Loss: 1.0449873208999634
Epoch 2 Loss: 0.9932296276092529
Epoch 3 Loss: 0.9570184350013733
Epoch 4 Loss: 0.9466387033462524
Epoch 5 Loss: 0.9344589710235596
Epoch 6 Loss: 0.9199233651161194
Epoch 7 Loss: 0.9016681909561157
Epoch 8 Loss: 0.8636384606361389
Epoch 9 Loss: 0.7462528944015503
Epoch 10 Loss: 0.6727670431137085
Epoch 11 Loss: 0.5101889967918396
Epoch 12 Loss: 0.6523962020874023
Epoch 13 Loss: 0.8114101886749268
Epoch 14 Loss: 1.133377194404602
Epoch 15 Loss: 1.0206667184829712
Epoch 16 Loss: 0.9286031723022461
Epoch 17 Loss: 0.9162858128547668
Epoch 18 Loss: 0.8699871301651001
Epoch 19 Loss: 0.8511697053909302
Epoch 20 Loss: 0.846432089805603
Epoch 21 Loss: 0.7838535308837891
Epoch 22 Loss: 0.8382791876792908
Epoch 23 Loss: 0.8339088559150696
Epoch 24 Loss: 0.7853477001190186
Epoch 25 Loss: 0.803545355796814
Epoch 26 Loss: 0.8014354109764099
Epoch 27 Loss: 0.7737951278686523
Epoch 28 Loss: 0.7624783515930176
Epoch 29 Loss: 0.8051172494

In [26]:
# pickle.dump(model, open("models/model_0204.pkl", 'wb'))

In [27]:
# with open('models/model_0204.pkl', 'rb') as handle:
#     model = pickle.load(handle)

# Fine-tuning on Shakespeare only

In [28]:
# model.eval().to(dist_util.dev())

# model_emb = torch.nn.Embedding(
#         num_embeddings=tokenizer.vocab_size, 
#         embedding_dim=hidden_dim, 
#         _weight=model.word_embedding.weight.clone().cpu()
#     ).eval()

In [29]:
# ss_data = load_data_text(
#         batch_size=10,
#         seq_len=seq_len,
#         data_dir=ss_data_dir,
#         loaded_vocab=tokenizer,
#         model_emb=model_emb.cpu() # use model's weights as init
#     )

In [30]:
# model.train() # TURNING THE TRAIN MODE BACK ON TO ENABLE BATCHNORM/DROPOUT!!

# TrainLoop(
#         model=model,
#         diffusion=diffusion,
#         data=ss_data,
#         batch_size=batch_size,
#         microbatch=microbatch,
#         lr=lr,
#         ema_rate=ema_rate,
#         schedule_sampler=schedule_sampler,
#         weight_decay=weight_decay,
#         epochs=epochs,
# #         eval_data=data_valid,
#         eval_interval=eval_interval
#     ).run_loop()

# Generating sequences

In [36]:
import numpy as np
device = dist_util.dev()

def sampling(
    model,
    diffusion,
    tokenizer,
    device=device,
    batch_size=16,
    seq_len=128,
    data_dir='data/',
    sampling_step=1000,
    diffusion_steps=1000,
    clip_denoised=False,
    model_kwargs=None,
    top_p=0,
    clamp_step=0,
    max_batches=1):
    """Generate target sequences for batches of the test split.

    For every processed batch the embeddings of positions whose ``input_mask``
    is 0 are kept as they are, every other position is replaced by Gaussian
    noise, and the diffusion sampler denoises the result. The final sample is
    mapped to token ids with ``model.get_logits`` (top-1) and decoded.

    Args:
        model: network exposing ``word_embedding``, ``get_embeds`` and
            ``get_logits``; it is switched to eval mode and moved to ``device``.
        diffusion: object exposing ``p_sample_loop`` and ``ddim_sample_loop``.
        tokenizer: object exposing ``vocab_size`` and ``decode_token``.
        device: device used for the model and the sampled tensors.
        batch_size: batch size used when loading the test split.
        seq_len: sequence length used for loading, for the sample shape and
            for locating the source/target boundary.
        data_dir: directory forwarded to ``load_data_text``.
        sampling_step: number of sampling steps. When it equals
            ``diffusion_steps`` ``p_sample_loop`` is used with a gap of 1,
            otherwise ``ddim_sample_loop`` with a gap of
            ``diffusion_steps // sampling_step``.
        diffusion_steps: number of diffusion steps the model was trained with.
        clip_denoised: forwarded to the sampling loop.
        model_kwargs: optional dict forwarded to the sampling loop
            (``None`` means an empty dict).
        top_p: forwarded to the sampling loop.
        clamp_step: forwarded to the sampling loop.
        max_batches: number of test batches to generate for. The default of 1
            matches the previous behaviour (stop after the first batch);
            ``None`` processes the whole test split.

    Returns:
        Tuple ``(word_lst_source, word_lst_recover, word_lst_ref)``: three
        lists with one ``tokenizer.decode_token`` result per sequence, holding
        the source part, the generated part and the reference part.
    """
    # A fresh dict per call avoids the shared-mutable-default pitfall and,
    # unlike before, a dict supplied by the caller is actually used.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # ---- putting the model into eval mode ----
    model.eval().to(device)

    hidden_dim = model.word_embedding.embedding_dim

    # Stand-alone copy of the model's word embedding, used both for loading
    # the data and for rounding samples to the nearest token embedding.
    model_emb = torch.nn.Embedding(
            num_embeddings=tokenizer.vocab_size,
            embedding_dim=hidden_dim,
            _weight=model.word_embedding.weight.clone().cpu()
        ).eval()

    # ---- getting test data ----
    data_test = load_data_text(
            batch_size=batch_size,
            seq_len=seq_len,
            deterministic=True,
            data_dir=data_dir,
            split="test",
            loaded_vocab=tokenizer,
            model_emb=model_emb.cpu(),  # using the same embedding weight as the training data
            loop=False
        )

    # Only read as many batches as will be used instead of materialising the
    # whole test split.
    all_test_data = []
    try:
        while max_batches is None or len(all_test_data) < max_batches:
            _, cond = next(data_test)
            all_test_data.append(cond)
    except StopIteration:
        pass
    print('### End of reading iteration...')

    model_emb.to(device)

    # The sampler choice does not depend on the batch, so decide it once.
    if sampling_step == diffusion_steps:
        sample_fn = diffusion.p_sample_loop
        step_gap = 1
    else:
        sample_fn = diffusion.ddim_sample_loop
        step_gap = diffusion_steps // sampling_step

    # ---- iterating through the test data to generate sequences ----
    word_lst_recover = []
    word_lst_ref = []
    word_lst_source = []

    # Inference only: do not build autograd graphs for the embedding lookup,
    # the sampling loop or the logits.
    with torch.no_grad():
        for cond in all_test_data:
            # Read (rather than pop) so the batch dict is left intact.
            input_ids_x = cond['input_ids'].to(device)
            input_ids_mask_ori = cond['input_mask']
            x_start = model.get_embeds(input_ids_x)

            noise = torch.randn_like(x_start)
            input_ids_mask = torch.broadcast_to(
                input_ids_mask_ori.unsqueeze(dim=-1), x_start.shape
            ).to(device)
            # Keep the embeddings where the mask is 0, start from noise elsewhere.
            x_noised = torch.where(input_ids_mask == 0, x_start, noise)

            sample_shape = (x_start.shape[0], seq_len, hidden_dim)

            samples = sample_fn(
                model,
                sample_shape,
                noise=x_noised,
                clip_denoised=clip_denoised,
                denoised_fn=partial(denoised_fn_round, model_emb),
                model_kwargs=model_kwargs,
                top_p=top_p,
                clamp_step=clamp_step,
                clamp_first=True,
                mask=input_ids_mask,
                x_start=x_start,
                gap=step_gap
            )

            # The loop returns one entry per step; the last one is the final sample.
            sample = samples[-1]

            logits = model.get_logits(sample)  # bsz, seqlen, vocab
            cands = torch.topk(logits, k=1, dim=-1)

            for pred_seq, src_seq, input_mask in zip(cands.indices, input_ids_x, input_ids_mask_ori):
                # Length of the leading part of the sequence.
                # NOTE(review): assumes the mask is 1 exactly on the trailing
                # target positions — confirm against load_data_text.
                len_x = seq_len - int(input_mask.sum())
                word_lst_recover.append(tokenizer.decode_token(pred_seq[len_x:]))
                word_lst_source.append(tokenizer.decode_token(src_seq[:len_x]))
                word_lst_ref.append(tokenizer.decode_token(src_seq[len_x:]))

    return word_lst_source, word_lst_recover, word_lst_ref


def get_efficient_knn(model_emb, text_emb):
    """Find, for every vector in ``text_emb``, the closest row of ``model_emb``.

    Squared Euclidean distances are obtained from the expansion
    ``|e|^2 + |x|^2 - 2 e.x`` so that a single matrix product suffices.

    Args:
        model_emb: 2-D tensor ``(vocab, d)`` of candidate embeddings.
        text_emb: tensor whose last dimension is ``d``; all leading
            dimensions are flattened.

    Returns:
        ``(values, indices)`` from a top-1 search over the vocabulary axis,
        each of shape ``(1, n)`` with ``n`` the number of flattened query
        vectors. ``values`` holds the negated squared distance, ``indices``
        the matching row of ``model_emb``.
    """
    queries = text_emb.view(-1, text_emb.size(-1))          # (n, d)
    vocab_sq_norm = model_emb.pow(2).sum(-1).view(-1, 1)    # (vocab, 1)
    query_sq_norm = text_emb.pow(2).sum(-1).view(1, -1)     # (1, n)
    cross_term = torch.mm(model_emb, queries.t())           # (vocab, n)

    sq_dist = vocab_sq_norm + query_sq_norm - 2.0 * cross_term
    # Rounding can push tiny distances slightly below zero; floor them at 0.
    sq_dist = sq_dist.clamp(min=0.0)

    nearest = torch.topk(-sq_dist, k=1, dim=0)
    return nearest.values, nearest.indices

def denoised_fn_round(model, text_emb, t):
    """Replace each embedding in ``text_emb`` by its nearest row of ``model.weight``.

    Args:
        model: callable embedding module with a ``weight`` attribute; it is
            called with the nearest-row indices to produce the new embeddings.
        text_emb: tensor whose last dimension is the embedding size; inputs
            with more than two dimensions are flattened for the search.
        t: not used by this function.

    Returns:
        Tensor with the shape and device of the incoming ``text_emb``.
    """
    original_shape = text_emb.shape
    original_device = text_emb.device
    vocab_emb = model.weight  # input_embs

    # Flatten (bsz, seqlen, dim) style inputs to (n, dim) for the search.
    flat_emb = text_emb.reshape(-1, text_emb.size(-1)) if text_emb.dim() > 2 else text_emb

    _, nearest = get_efficient_knn(vocab_emb, flat_emb.to(vocab_emb.device))
    token_ids = nearest[0]

    return model(token_ids).view(original_shape).to(original_device)

In [33]:
data_dir

'data'

In [34]:
# ---- settings for the sampling run ----
sampling_step = 1000     # equals diffusion_steps, so the non-DDIM sampling loop is selected
clip_denoised = False
top_p = 0
clamp_step = 0
model_kwargs = {}

In [37]:
# Forward every sampling setting defined in the settings cell; previously only
# `sampling_step` was passed, so editing the other values had no effect.
# NOTE(review): `seq_len` is not forwarded, so sampling() falls back to its
# default of 128 while training used seq_len=500 — confirm this is intended.
word_lst_source, word_lst_recover, word_lst_ref = sampling(
    model,
    diffusion,
    tokenizer,
    batch_size=5,
    data_dir=data_dir,
    sampling_step=sampling_step,
    clip_denoised=clip_denoised,
    model_kwargs=model_kwargs,
    top_p=top_p,
    clamp_step=clamp_step,
)

############################## 
Loading text data...
############################## 
Loading dataset from data...
### Loading form the TEST set...
### Data samples...
 ["proceed solinus to procure my fall <EOS> and by the doom of death end woes and all <EOS> therefore give out you are of epidamnum <EOS> lest that your goods too soon be confiscate <EOS> this very day a syracusian merchant <EOS> is apprehended for arrival here <EOS> and not be able to buy out his life <EOS> accord to the statute of the town <EOS> dies ere the weary sun set in the west <EOS> there is your money that i had to keep <EOS> antipholus <EOS> neither my husband nor the slave return'd <EOS> that in such haste i sent to seek his master! <EOS> sure luciana it is two o'clock <EOS> the gold i gave to dromio is laid up <EOS> safe at the centaur and the heedful slave <EOS> is wander'd forth in care to seek me out <EOS> by computation and mine host's report <EOS> i could not speak with dromio since at first <EOS> i sent

Running tokenizer on dataset (num_proc=4):   0%|          | 0/5083 [00:00<?, ? examples/s]

### tokenized_datasets Dataset({
    features: ['input_id_x', 'input_id_y'],
    num_rows: 5083
})
### tokenized_datasets...example [2, 1672, 13682, 94, 6298, 111, 704, 21, 84, 22, 90, 195, 83, 2928, 100, 454, 678, 3699, 90, 192, 21, 84, 22, 582, 353, 321, 95, 206, 100, 10631, 21, 84, 22, 1581, 116, 141, 5487, 344, 1237, 108, 11563, 21, 84, 22, 149, 506, 390, 26, 10571, 4200, 21, 84, 22, 126, 9739, 121, 11517, 252, 21, 84, 22, 90, 131, 108, 3085, 94, 2110, 321, 148, 476, 21, 84, 22, 2679, 94, 83, 13799, 100, 83, 1496, 21, 84, 22, 1935, 821, 83, 2565, 1037, 650, 109, 83, 2468, 21, 84, 22, 218, 126, 141, 1340, 116, 34, 335, 94, 595, 21, 84, 22, 1207, 21, 84, 22, 1351, 111, 844, 383, 83, 1738, 998, 7, 29, 21, 84, 22, 116, 109, 349, 1447, 34, 924, 94, 1060, 148, 452, 5, 21, 84, 22, 869, 7797, 137, 126, 542, 40, 7, 2239, 21, 84, 22, 83, 919, 34, 1184, 94, 3009, 126, 2189, 221, 21, 84, 22, 1815, 214, 83, 10590, 90, 83, 12337, 1738, 21, 84, 22, 126, 3618, 7, 29, 808, 109, 1029, 94, 1060, 125,

merge and mask:   0%|          | 0/5083 [00:00<?, ? examples/s]

RAM used: 3298.45 MB


padding:   0%|          | 0/5083 [00:00<?, ? examples/s]

Dataset({
    features: ['input_id_x', 'input_id_y', 'input_ids', 'input_mask'],
    num_rows: 5083
}) padded dataset
RAM used: 3356.30 MB
RAM used: 3323.42 MB
### End of reading iteration...


  0%|          | 0/1000 [00:00<?, ?it/s]

Generating 20 sentences takes 5 minutes

In [38]:
word_lst_source

['[CLS] proceed solinus to procure my fall < eos > and by the doom of death end woes and all < eos > therefore give out you are of epidamnum < eos > lest that your goods too soon be confiscate < eos > this very day a syracusian merchant < eos > is apprehended for arrival here < eos > and [SEP] [SEP]',
 '[CLS] merchant of syracuse plead no more < eos > i am not partial to infringe our laws [UNK] < eos > the enmity and discord which of late < eos > sprung from the rancorous outrage of your duke < eos > to merchants our well - deal countrymen < eos > who want guilders to redeem their lives < eos [SEP] [SEP]',
 '[CLS] yet this my comfort [UNK] when your words are done < eos > my woes end likewise with the even sun < eos > many a man would take you at your word < eos > and go indeed hav so good a mean < eos > exit < eos > antipholus < eos > why should their liberty than ours [SEP] [SEP]',
 "[CLS] well syracusian say in brief the cause < eos > why thou departed'st from thy native home < eos 

In [39]:
word_lst_recover

['eos eos < < [PAD] > < spied < > > > i disheartens < < [PAD] < eos [PAD] < < eos < sociable beagle [PAD] < eos > < eos < eos justly scales < < > < justly < eos sociable i [PAD] spied justly < eos eos > spied < eos > > i < > < < < stre',
 '< > i [PAD] [PAD] lewdly eos < [PAD] sociable > > < subscribed justly [PAD] > < eos < i eos eos < spied < < < [PAD] eos eos < < < i eos < < < eos < spied lewdly sociable < < < < < < < > < < > [PAD] justly eos < eos > sociable',
 'lewdly eos < scales eos [PAD] < justly eos [PAD] sociable >urday spied > < [PAD] < < subscribed < < rasher eos < < > i [PAD] < < eos eos eos justly > < < < > < > le sociable < < < beagle < > < eos eos < eosurday lewdly < spied eos < justly > i',
 '< < stre > eos [PAD] < < < sociable > < > < < < eos < < eos sociable < > < lewdly < i scales eos > > < < > eos and < i eos i > < eos eos > < < < eos eos < > < eos < < [PAD] and spied < eos sociable < i',
 '##ends < i justly < > > < > justly eos < eos < < [PAD] < [PAD]ends [PAD] [PA

In [40]:
word_lst_ref

["[CLS] well syracusian say in brief the cause < eos > why thou departed'st from thy native home < eos > and for what cause thou camest to ephesus < eos > a trusty villain sir that very oft < eos > when i am dull with care and melancholy < eos > lightens my humour with his merry jests [SEP]",
 "[CLS] a heavier task could not have been imposed < eos > than i to speak my griefs unspeakable [UNK] < eos > yet that the world may witness that my end < eos > was wrought by nature not by vile offence < eos > i'll utter what my sorrows give me leave < eos > in syracusa was [SEP]",
 "[CLS] nay forward old man don't break off so < eos > for we may pity though not pardon thee < eos > farewell till then [UNK] i will go lose myself < eos > and wander up and down to view the city < eos > o know he is the bridle of your will < eos > i am [SEP]",
 "[CLS] o had the gods done so i had not now < eos > worthily term'd them merciless to us! < eos > for ere the ships could meet by twice five leagues < eos > 