In [None]:
# pip install datasets --user

In [None]:
from run_train import create_model_and_diffusion
from utils.step_sample import create_named_schedule_sampler
from train_util import TrainLoop
from utils.data import load_data_text

from transformers import AutoTokenizer, PreTrainedTokenizerFast, BertTokenizerFast, set_seed
import json, torch, os
from utils import dist_util
from functools import partial
import pickle
import random

In [None]:
# with open('vocab_list.pickle', 'rb') as handle:
#     vocab_list = pickle.load(handle)
# vocab_list = list(vocab_list.values())

In [None]:
# Free cached memory (presumably the GPU allocator cache — see utils.dist_util.clear_cache) before building anything.
dist_util.clear_cache()

In [None]:
# ---- Optimisation / training loop (passed to TrainLoop) ----
lr = 0.0001
batch_size = 64
microbatch = 20
epochs = 100
eval_interval = 1000
ema_rate = '0.9999'  # a string, as handed to TrainLoop
weight_decay = 0.0
schedule_sampler = 'uniform'  # the sampler cell later rebinds this name to a sampler object

# ---- Diffusion process ----
diffusion_steps = 1000
noise_schedule = 'sqrt'
predict_xstart = True
rescale_timesteps = True

# ---- Tokenizer / vocabulary ----
# 'bert' is what the tokenizer cell actually loads; 'custom' is not one of the
# options myTokenizer accepts ('bert', 'shakespeare', 'combined').
vocab = 'bert'
config_name = 'bert-base-uncased'
vocab_size = 0  # placeholder; overwritten with tokenizer.vocab_size before the model is built

# ---- Data ----
cc_data_dir = 'data/commonsense'
ss_data_dir = 'data/shakespeare'
data_dir = ss_data_dir  # use cc_data_dir for the commonsense dataset
seq_len = 128

# ---- Model ----
use_plm_init = 'no'  # embedding in transformer
hidden_t_dim = 128
hidden_dim = 64
dropout = 0.1
emb_scale_factor = 1.0  # NOTE: not passed to create_model_and_diffusion in this notebook

# ---- Reproducibility ----
seed = 102

In [None]:
# transformers.set_seed seeds Python's random, NumPy and PyTorch so runs are reproducible.
set_seed(seed)

In [None]:
class myTokenizer():
    """
    Load tokenizer from bert config or defined BPE vocab dict.

    Wraps either a Hugging Face fast tokenizer or a plain ``{token: id}`` dict
    behind a common ``encode_token`` / ``decode_token`` interface.

    Args:
        vocab: which tokenizer to build -- 'bert' (``AutoTokenizer`` loaded from
            ``config_name``), 'shakespeare' (``BertTokenizerFast`` over
            'shakespeare-tokenizer-bert/vocab.txt') or 'combined' (like 'bert',
            with ``custom_vocab_list`` added as extra tokens).
        config_name: Hugging Face model name/path used by 'bert' and 'combined'.
        custom_vocab_list: extra tokens for 'combined'; ignored otherwise.

    Raises:
        ValueError: if ``vocab`` is not one of the values above.
    """
    ##################################################
    ### You can customize your own tokenizer here. ###
    ##################################################
    def __init__(self, vocab, config_name, custom_vocab_list=None):
        if vocab == 'bert':
            tokenizer = AutoTokenizer.from_pretrained(config_name)
        elif vocab == 'shakespeare':
            tokenizer = BertTokenizerFast('shakespeare-tokenizer-bert/vocab.txt')
        elif vocab == 'combined':
            tokenizer = AutoTokenizer.from_pretrained(config_name)
            if custom_vocab_list:
                tokenizer.add_tokens(custom_vocab_list)
        else:
            # Fail early with a clear message instead of an AttributeError on
            # self.tokenizer further down.
            raise ValueError(
                f"unknown vocab {vocab!r}; expected 'bert', 'shakespeare' or 'combined'"
            )
        self.tokenizer = tokenizer
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id
        # Counts any tokens added for 'combined'.
        self.vocab_size = len(self.tokenizer)

    def encode_token(self, sentences):
        """Convert each sentence in ``sentences`` to a list of token ids.

        With a dict vocabulary each sentence is whitespace-split, words missing
        from the dict map to the '[UNK]' id, and ids 0 / 1 are prepended /
        appended. With a Hugging Face fast tokenizer, the tokenizer's own special
        tokens are added.

        Raises:
            TypeError: if the wrapped tokenizer is neither a dict nor a
                ``PreTrainedTokenizerFast``.
        """
        if isinstance(self.tokenizer, dict):
            return [[0] + [self.tokenizer.get(x, self.tokenizer['[UNK]']) for x in seq.split()] + [1]
                    for seq in sentences]
        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
            return self.tokenizer(sentences, add_special_tokens=True)['input_ids']
        raise TypeError("invalid type of vocab_dict")

    def decode_token(self, seq):
        """Decode a tensor of token ids to text, dropping trailing pad tokens.

        ``seq`` may be 1-D or carry a trailing size-1 axis (e.g. ``topk`` indices);
        it is flattened before decoding.

        Raises:
            TypeError: if the wrapped tokenizer is neither a dict nor a
                ``PreTrainedTokenizerFast``.
        """
        if isinstance(self.tokenizer, dict):
            ids = self._ids_without_padding(seq)
            rev_tokenizer = getattr(self, 'rev_tokenizer', None)
            if rev_tokenizer is None:
                # No reverse map is stored anywhere; derive it from the vocab dict.
                rev_tokenizer = {idx: tok for tok, idx in self.tokenizer.items()}
            return " ".join([rev_tokenizer[x] for x in ids]).replace('__ ', '').replace('@@ ', '')
        if isinstance(self.tokenizer, PreTrainedTokenizerFast):
            return self.tokenizer.decode(self._ids_without_padding(seq))
        raise TypeError("invalid type of vocab_dict")

    def _ids_without_padding(self, seq):
        """Flatten ``seq`` to a Python list of ids and strip trailing pad ids."""
        # reshape(-1), not squeeze(-1): squeezing a length-1 slice gives a 0-d
        # tensor whose tolist() is a bare int, which broke len() below.
        ids = seq.reshape(-1).tolist()
        while ids and ids[-1] == self.pad_token_id:
            ids.pop()
        return ids


def load_model_emb(hidden_dim, tokenizer):
    """Create a randomly initialised word-embedding table for ``tokenizer``.

    Replace the initialisation here to use pre-defined vectors (e.g. GloVe).

    Args:
        hidden_dim: embedding width.
        tokenizer: object exposing ``vocab_size`` (e.g. a ``myTokenizer``).

    Returns:
        ``(embedding, tokenizer)`` -- an ``nn.Embedding`` of shape
        ``(tokenizer.vocab_size, hidden_dim)`` with N(0, 1) weights, and the
        tokenizer that was passed in, unchanged.
    """
    embedding = torch.nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=hidden_dim)
    torch.nn.init.normal_(embedding.weight, mean=0.0, std=1.0)
    return embedding, tokenizer


def load_tokenizer(vocab, config_name, custom_vocab_list=None):
    """Build a ``myTokenizer``; thin factory kept for call-site symmetry with load_model_emb."""
    return myTokenizer(vocab, config_name, custom_vocab_list=custom_vocab_list)

In [None]:
# Pretrained BERT tokenizer named by `config_name`; 'bert' is passed explicitly
# here rather than read from the `vocab` config variable.
tokenizer = load_tokenizer(vocab='bert', config_name=config_name)

In [None]:
# Inspect the underlying Hugging Face tokenizer wrapped by myTokenizer.
tokenizer.tokenizer

In [None]:
# Sanity-check encoding of one sample line. NOTE: a single string is passed, not a
# list of sentences; the dict branch of encode_token would iterate it per character.
tokenizer.encode_token('find we a time for fright peace to pant')

In [None]:
# Fresh N(0, 1)-initialised embedding table sized to the tokenizer vocabulary.
model_weight, tokenizer = load_model_emb(hidden_dim=hidden_dim, tokenizer=tokenizer)

In [None]:
# Show the randomly initialised embedding module returned by load_model_emb.
model_weight

In [None]:
# The config cell sets vocab_size = 0 as a placeholder. Overwrite it with the real
# tokenizer size here, before it is passed to create_model_and_diffusion below;
# the trained word embedding must later have tokenizer.vocab_size rows.
vocab_size = tokenizer.vocab_size

In [None]:
# Vocabulary size taken from the tokenizer.
vocab_size

In [None]:
# Training data iterator; each item is a (batch, cond) pair consumed by TrainLoop.
# The random embedding from load_model_emb is handed over as the initial model_emb.
data = load_data_text(
    data_dir=data_dir,
    batch_size=batch_size,
    seq_len=seq_len,
    model_emb=model_weight,  # use model's weights as init
    loaded_vocab=tokenizer,
)

Each item yielded by `data` is a `(batch, cond)` pair; `batch` (index 0) is what `TrainLoop` receives as the batch data.

In [None]:
# next(data)[0].shape # batch_size, seq_len, hidden_dim

In [None]:
# next(data)[0]

`cond` (index 1) is what `TrainLoop` receives as the conditioning input: a dictionary holding the `input_ids` and `input_mask` tensors.

In [None]:
# next(data)[1]

In [None]:
# next(data)[1]['input_ids'].shape # batch_size, seq_len

In [None]:
# next(data)[1]['input_mask'].shape # batch_size, seq_len

In [None]:
# Build the denoising network and the diffusion process from the config values.
# NOTE(review): every argument is positional, so the order must match the
# signature of run_train.create_model_and_diffusion exactly — confirm there.
# emb_scale_factor from the config cell is not passed.
model, diffusion = create_model_and_diffusion(
                        hidden_t_dim,
                        hidden_dim,
                        vocab_size,
                        config_name,
                        use_plm_init,
                        dropout,
                        diffusion_steps,
                        noise_schedule,
                        predict_xstart,
                        rescale_timesteps,
                    )

In [None]:
# Move the model to the device chosen by dist_util.dev(); the module repr is displayed.
model.to(dist_util.dev())

In [None]:
# Vocabulary size and embedding width the model was built with.
vocab_size, hidden_dim

In [None]:
# Device returned by dist_util.dev(), used for the model and tensors below.
dist_util.dev()

In [None]:
# Count every parameter element in the model, trainable or not.
pytorch_total_params = sum(map(lambda param: param.numel(), model.parameters()))

In [None]:
# Total number of parameter elements (trainable and frozen).
pytorch_total_params

In [None]:
# Build the timestep sampler. NOTE: this rebinds `schedule_sampler`, which the
# config cell set to the string 'uniform'; the sampler name is hard-coded here.
schedule_sampler = create_named_schedule_sampler('uniform', diffusion)

In [None]:
# Inspect the sampler object created above.
schedule_sampler

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the failing call
# (a debugging aid; it slows training).
# NOTE(review): CUDA reads this variable when its context is created. The earlier
# model.to(dist_util.dev()) has probably already initialised CUDA, in which case
# this line has no effect — consider setting it in the imports/config cell instead.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# Train the model on `data` using the config values; runs the full training loop.
TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        schedule_sampler=schedule_sampler,
        weight_decay=weight_decay,
        epochs=epochs,
        # eval_data=data_valid,  # TODO: re-enable once a validation iterator (data_valid) is loaded
        eval_interval=eval_interval
    ).run_loop()

In [None]:
# Switch to inference: eval() disables dropout, requires_grad_(False) freezes every
# parameter, and the model stays on dist_util.dev().
model.eval().requires_grad_(False).to(dist_util.dev())

In [None]:
# Frozen CPU copy of the trained word-embedding table. It is used to embed the
# test data and to round denoised vectors to their nearest token embedding.
# Embedding.from_pretrained is the public API for wrapping an existing tensor.
model_emb = torch.nn.Embedding.from_pretrained(
    model.word_embedding.weight.detach().clone().cpu(),
    freeze=True,
).eval().requires_grad_(False)
# from_pretrained infers the size from the tensor, so check explicitly that the
# table still matches the tokenizer vocabulary and the hidden size.
assert tuple(model_emb.weight.shape) == (tokenizer.vocab_size, hidden_dim), (
    f"trained embedding has shape {tuple(model_emb.weight.shape)}, "
    f"expected ({tokenizer.vocab_size}, {hidden_dim})"
)

In [None]:
# Test split with deterministic=True and loop=False (presumably a single,
# unshuffled pass — see utils.data.load_data_text). The frozen trained embedding
# is moved to CPU and passed in, so test inputs use the same embedding weights
# as the training data.
data_test = load_data_text(
    data_dir=data_dir,
    split="test",
    batch_size=20,
    seq_len=seq_len,
    deterministic=True,
    loop=False,
    loaded_vocab=tokenizer,
    model_emb=model_emb.cpu(),
)

In [None]:
# Drain the test iterator once, keeping only the conditioning dicts that the
# decoding loop below needs; `idx` counts the batches read.
all_test_data = []
idx = 0
for batch, cond in data_test:
    all_test_data.append(cond)
    idx += 1
print('### End of reading iteration...')

# The embedding was moved to CPU for data loading; bring it back to the model's device.
model_emb.to(dist_util.dev())


In [None]:
len(all_test_data)  # number of test batches collected from data_test

In [None]:
import numpy as np

def get_efficient_knn(model_emb, text_emb):
    """Find the nearest ``model_emb`` row (squared L2) for every vector in ``text_emb``.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so that all
    distances come from one matrix product.

    Args:
        model_emb: ``(V, d)`` embedding table.
        text_emb: any tensor whose last axis is ``d``; it is flattened to ``(N, d)``.

    Returns:
        ``(values, indices)``, each of shape ``(1, N)``: the negated (clamped,
        non-negative) squared distance to, and the row index of, the closest
        table entry for each flattened vector.
    """
    flat_text = text_emb.view(-1, text_emb.size(-1))              # (N, d)
    vocab_sq = model_emb.pow(2).sum(dim=-1).view(-1, 1)          # (V, 1)
    text_sq = text_emb.pow(2).sum(dim=-1).view(-1, 1)            # (N, 1)
    cross = torch.mm(model_emb, flat_text.t())                   # (V, N)
    # Rounding can make the expansion slightly negative; clamp to a valid distance.
    sq_dist = (vocab_sq + text_sq.t() - 2.0 * cross).clamp(min=0.0)
    best = torch.topk(-sq_dist, k=1, dim=0)
    return best.values, best.indices

def denoised_fn_round(model, text_emb, t):
    """Snap every vector in ``text_emb`` to its nearest embedding in ``model``.

    Args:
        model: embedding module; ``model.weight`` is the lookup table and
            ``model(ids)`` maps token ids back to vectors.
        text_emb: tensor whose last axis is the embedding width; any leading
            axes are flattened for the search and restored afterwards.
        t: diffusion timestep; accepted for the denoised_fn interface, unused.

    Returns:
        Tensor with ``text_emb``'s shape and device, holding the nearest embeddings.
    """
    emb_table = model.weight
    target_shape, target_device = text_emb.shape, text_emb.device
    flat = text_emb.reshape(-1, text_emb.size(-1)) if text_emb.dim() > 2 else text_emb
    _, nearest = get_efficient_knn(emb_table, flat.to(emb_table.device))
    # nearest is (1, N); row 0 holds the token id chosen for each vector.
    return model(nearest[0]).view(target_shape).to(target_device)

In [None]:
# Sampling settings for the decoding loop. With step == diffusion_steps the
# ancestral sampler (p_sample_loop) is used; fewer steps switch to DDIM.
step, clip_denoised = 1000, False
top_p, clamp_step = 0, 0
model_kwargs = {}

In [None]:
# Decode test batches with the trained model and collect source, reference and
# generated text. `cond` is read, not popped, so `all_test_data` is left intact
# and this cell can be re-run (popping made a second run fail with KeyError).
max_batches = 1  # number of test batches to decode (>= 1); None decodes them all

if not 0 < step <= diffusion_steps:
    raise ValueError(f"step must be in (0, {diffusion_steps}], got {step}")

# The sampler choice is the same for every batch, so decide it once: the full
# schedule uses p_sample_loop, a shorter one uses DDIM with stride step_gap.
if step == diffusion_steps:
    use_ddim = False
    step_gap = 1
else:
    use_ddim = True
    step_gap = diffusion_steps // step
sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
denoised_fn = partial(denoised_fn_round, model_emb)

word_lst_recover = []
word_lst_ref = []
word_lst_source = []

for batch_idx, cond in enumerate(all_test_data):
    input_ids_x = cond['input_ids'].to(dist_util.dev())
    x_start = model.get_embeds(input_ids_x)
    input_ids_mask_ori = cond['input_mask']

    # Where the mask is 0 the input embedding is kept; elsewhere sampling starts from noise.
    noise = torch.randn_like(x_start)
    input_ids_mask = torch.broadcast_to(
        input_ids_mask_ori.unsqueeze(dim=-1), x_start.shape
    ).to(dist_util.dev())
    x_noised = torch.where(input_ids_mask == 0, x_start, noise)

    model_kwargs = {}
    sample_shape = (x_start.shape[0], seq_len, hidden_dim)

    samples = sample_fn(
        model,
        sample_shape,
        noise=x_noised,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
        top_p=top_p,
        clamp_step=clamp_step,
        clamp_first=True,
        mask=input_ids_mask,
        x_start=x_start,
        gap=step_gap,
    )
    sample = samples[-1]  # output of the final denoising step

    # Pick the highest-scoring token id at every position.
    logits = model.get_logits(sample)
    cands = torch.topk(logits, k=1, dim=-1)

    # Assumes the mask's 1s are the trailing target positions, so the target
    # segment starts at len_x = seq_len - mask.sum() — confirm against utils.data.
    for seq, input_mask in zip(cands.indices, input_ids_mask_ori):
        len_x = seq_len - int(input_mask.sum())
        word_lst_recover.append(tokenizer.decode_token(seq[len_x:]))

    for seq, input_mask in zip(input_ids_x, input_ids_mask_ori):
        len_x = seq_len - int(input_mask.sum())
        word_lst_source.append(tokenizer.decode_token(seq[:len_x]))
        word_lst_ref.append(tokenizer.decode_token(seq[len_x:]))

    # Break at the end of the body so `cond` still refers to the last decoded batch.
    if max_batches is not None and batch_idx + 1 >= max_batches:
        break
    

In [None]:
# Conditioning dict of the last batch processed by the decoding loop above.
cond

In [None]:
# Decoded source segments (positions before len_x) of the decoded test examples.
word_lst_source

In [None]:
# Model generations decoded from the target segment (positions from len_x on).
word_lst_recover

In [None]:
# Ground-truth target segments, for comparison with word_lst_recover.
word_lst_ref