In [1]:
from text_denoiser import TextDenoiser
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when one is available; the cells below create their inputs on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
denoiser = TextDenoiser(embed_dim=228, dim_feedforward=1024)
# map_location remaps the checkpoint tensors onto `device` while they are being read from disk.
denoiser.load_state_dict(torch.load("../logs/2000_lines/models/saved_model.pt", map_location=device))
# load_state_dict copies values into the module's existing parameters and does not relocate the
# module, so move it explicitly. This is a no-op when the module already lives on `device`.
denoiser.to(device)
# Both weights print as [vocab_size, embed_dim]; the shapes match, which is what would make the
# (currently disabled) weight tying below possible.
print(denoiser.decoder.weight.shape)
print(denoiser.embedder.weight.shape)
# denoiser.decoder.weight = nn.Parameter(denoiser.embedder.weight)

torch.Size([10053, 228])
torch.Size([10053, 228])


In [3]:
# Scratch check: scale each slice along axis 1 of `data` by a factor picked with a random
# timestep index, relying on broadcasting over the other two axes.
n_T = 40
data = torch.ones(32, 4, 228)
factor = torch.linspace(start=0, end=n_T - 1, steps=n_T)
# One random timestep index per entry of axis 1.
ts = torch.randint(low=0, high=n_T, size=(data.size(1),))
print(factor)
print(ts)

# factor[ts] has one value per axis-1 entry; reshape it to (1, len(ts), 1) so it broadcasts.
per_column_factor = factor[ts].reshape(1, -1, 1)
data = data * per_column_factor
# All values in column 1 equal factor[ts[1]], so the mean shows which factor was applied.
print(data[:, 1, :].mean())



tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.])
tensor([17, 29,  9,  0])
tensor(29.)


In [4]:
# Round-trip sanity check: tokens -> indices -> embeddings -> decoder logits -> indices -> tokens.
# The recovered tokens should match the randomly chosen ones.
texts = denoiser.vocab.lookup_tokens(torch.randint(0, len(denoiser.vocab), (5,)).tolist())
print(texts)
indices = torch.LongTensor(denoiser.vocab(texts)).to(device)

embeds = denoiser.embedder(indices)

decoded = denoiser.decoder(embeds)
print(decoded.shape)
# Pass dim explicitly: omitting it relies on softmax's deprecated implicit-dimension choice and
# emits a warning. For this 2-D input the implicit choice was already the last axis, so the
# result is unchanged.
new_indices = torch.argmax(F.softmax(decoded, dim=-1), dim=-1)
print(new_indices)
print(denoiser.vocab.lookup_tokens(new_indices.tolist()))

# Reference indices for comparison with `new_indices` above.
print(denoiser.vocab(texts))


['6th', 'entitled', 'brownish', 'steer', 'enraged']
torch.Size([5, 10053])
tensor([3635, 3160, 6918, 9531, 5221])
['6th', 'entitled', 'brownish', 'steer', 'enraged']
[3635, 3160, 6918, 9531, 5221]


  new_indices = torch.argmax(F.softmax(decoded), dim=-1)


### Validating that cross entropy is calculated correctly

In [5]:
batch_size = 64
seq_len = 128
# Create the indices on `device`, like the other cells do for model inputs, so this cell does not
# hit a device mismatch when the model is on the GPU.
indices = torch.randint(0, len(denoiser.vocab), (seq_len, batch_size), device=device)
embeds = denoiser.embedder(indices)

decoded = denoiser.decoder(embeds)
# decoded = torch.randn(seq_len, batch_size, len(denoiser.vocab))
print("indices shape:", indices.shape)
print("decoded shape:", decoded.shape)

# Calculate cross entropy loss
# F.cross_entropy expects the class axis at position 1. Here the logits are rearranged from
# (seq, batch, vocab) to (batch, vocab, seq) and the targets from (seq, batch) to (batch, seq).
loss = F.cross_entropy(decoded.permute(1, 2, 0), indices.T)
print(loss)


# Second formulation: keep seq as the leading axis, i.e. logits (seq, vocab, batch) with targets
# (seq, batch). The explicit log_softmax is redundant for the value (cross_entropy applies
# log_softmax itself and log_softmax of log-probabilities leaves them unchanged), so both losses
# should print the same number.
y = F.log_softmax(decoded, dim=-1).permute(0, 2, 1)
reconstruction_loss = F.cross_entropy(y, indices)
print(reconstruction_loss)

indices shape: torch.Size([128, 64])
decoded shape: torch.Size([128, 64, 10053])
tensor(0.0003, grad_fn=<NllLoss2DBackward0>)
tensor(0.0003, grad_fn=<NllLoss2DBackward0>)


In [6]:
# Draw text from the loaded model on `device`. The cell output below is a list holding a single
# space-joined token string.
# NOTE(review): the positional arguments 1 and 128 are presumably the number of samples and the
# sequence length — confirm against TextDenoiser.sample.
denoiser.sample(device, 1, 128)

['of in in of of of of in of but in of in in of of of of in the in in of of to of of in the of in of in in of of of in of and the in in in in of of to in in of in of in of in of the in of in in in of the in of in in in in in in of of in of in in of in of of in of in in in in of of of in in in in in in in of in of in in in in in in in in in in in in of @-@ of in in of in of in of in in in of']

In [17]:
# Run the sampling loop by hand so intermediate states can be inspected afterwards.
seq_len, n = 16, 1
intermediates = []
denoiser.eval()
with torch.no_grad():
    # Start from Gaussian noise of shape (seq_len, n, embed_dim).
    x = torch.randn((seq_len, n, denoiser.embed_dim), device=device)
    # Walk the timesteps from n_T down to 1.
    for t in reversed(range(1, denoiser.n_T + 1)):
        x = denoiser.sample_step(x, t)
        # Keep a snapshot at every 100th timestep and at the final step.
        keep_snapshot = t == 1 or t % 100 == 0
        if keep_snapshot:
            intermediates.append(x)



In [25]:
# Stack the saved snapshots along a new leading axis (one entry per snapshot).
stacked = torch.stack(intermediates)
# Embedding of the token "of", used as a fixed reference vector.
of_ids = denoiser.vocab(["of"])
arbitrary_word_emb = denoiser.embedder(torch.LongTensor(of_ids).to(device))

# For each snapshot, compare its first entry along the leading axis with the reference.
# NOTE(review): assumes sample_step keeps the (seq_len, n, embed_dim) layout, so x_i[0] is
# sequence position 0 — confirm.
for x_i in stacked:
    first_position = x_i[0]
    print(F.cosine_similarity(first_position, arbitrary_word_emb, dim=-1))


tensor([0.0495], grad_fn=<SumBackward1>)
tensor([0.2288], grad_fn=<SumBackward1>)
tensor([0.4992], grad_fn=<SumBackward1>)
tensor([0.7436], grad_fn=<SumBackward1>)
tensor([0.9075], grad_fn=<SumBackward1>)
tensor([0.9684], grad_fn=<SumBackward1>)
tensor([0.9856], grad_fn=<SumBackward1>)
tensor([0.9943], grad_fn=<SumBackward1>)
tensor([0.9978], grad_fn=<SumBackward1>)
tensor([0.9994], grad_fn=<SumBackward1>)
tensor([0.9999], grad_fn=<SumBackward1>)


In [30]:

# Decode every saved snapshot into tokens by sampling from the decoder's distribution.
batch_select = 0  # which batch element (axis 1 of `probs`) to decode
for x_i in stacked:
    # Decoder logits -> probabilities over the vocabulary (last axis).
    probs = F.softmax(denoiser.decoder(x_i), dim=-1)
    print(probs.shape)
    # Draw one token id per row of probs[:, batch_select]; [:, 0] drops the sample axis.
    indices = torch.multinomial(probs[:, batch_select], 1)[:, 0]
    print(indices.shape)
    # indices = denoiser.emb_to_indices(x_i)[:, batch_select]
    tokens = denoiser.vocab.lookup_tokens(indices.tolist())
    print(" ".join(tokens))

torch.Size([16, 1, 10053])
torch.Size([16])
be to . by that = as that with that is were were as is that
torch.Size([16, 1, 10053])
torch.Size([16])
be as real and itself were over that this and team to this around was before
torch.Size([16, 1, 10053])
torch.Size([16])
see as 80 up = from arms name god be but <unk> more are team not
torch.Size([16, 1, 10053])
torch.Size([16])
uses to @ from , an in the a became which <unk> they consists , october
torch.Size([16, 1, 10053])
torch.Size([16])
of of their this in including in in of as in of in and in as
torch.Size([16, 1, 10053])
torch.Size([16])
of of on been in . in in of act in of in in in and
torch.Size([16, 1, 10053])
torch.Size([16])
of of jordan of in s in in of of in of in in in of
torch.Size([16, 1, 10053])
torch.Size([16])
of of jordan of in of in in of of in of in in in of
torch.Size([16, 1, 10053])
torch.Size([16])
of of in of in of in in of of in of in in in of
torch.Size([16, 1, 10053])
torch.Size([16])
of of in of in of in in

In [28]:
# Cosine similarity between the learned embeddings of two specific words.
words = ["in", "of", "the", ",", "mario", "peach"]
word_ids = denoiser.vocab(words)
indices = torch.LongTensor(word_ids).to(device)
embeds = denoiser.embedder(indices)
# Rows of `embeds` follow the order of `words`.
mario_row = words.index("mario")
peach_row = words.index("peach")
# Indexing with None adds a leading axis so each operand is a batch of one vector.
F.cosine_similarity(embeds[None, mario_row], embeds[None, peach_row])


tensor([0.0077], grad_fn=<SumBackward1>)

## BERT TESTING

In [36]:
from transformers import BertConfig, BertTokenizer, BertModel

# Load the tokenizer and the bare BertModel from the same pretrained checkpoint name.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer, model = (
    BertTokenizer.from_pretrained(BERT_CHECKPOINT),
    BertModel.from_pretrained(BERT_CHECKPOINT),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [97]:
# Setup bert pipeline

def get_embedding(text, model, tokenizer):
    """Tokenize ``text`` and return the first output of ``model`` for it.

    Args:
        text: Input passed to ``tokenizer``.
        model: Callable invoked with the tokenizer's output unpacked as keyword arguments.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``; must return
            a mapping.

    Returns:
        Element 0 of the model output. NOTE(review): for the BertModel used in this
        notebook this is presumably the last hidden states — confirm.
    """
    encoded = tokenizer(text, return_tensors="pt")
    return model(**encoded)[0]


# text = "shovel bucket"
# tokens = tokenizer.tokenize(text)
# tokens = ["[CLS]"] + tokens + ["[SEP]"]

# token_ids = tokenizer.convert_tokens_to_ids(tokens)
# token_ids = torch.LongTensor(token_ids).unsqueeze(0)

# with torch.no_grad():
#     outputs = model(token_ids)
#     last_hidden_states = outputs[0]

# print(last_hidden_states[0, 1, 0:3])

# Embed an empty string and a single word, in that order.
emb1, emb2 = (get_embedding(text, model, tokenizer) for text in ("", "bathroom"))

# Compare the L2 norm of one output vector (batch 0, position 1) with sqrt(hidden size).
vector_norm = emb1[0, 1].norm()
print(vector_norm)
hidden_size = emb1.shape[-1]
print(hidden_size ** 0.5)

# Cosine similarity between "shovel" and "bucket" embeddings
# print(torch.nn.functional.cosine_similarity(last_hidden_states[0, 1:2], last_hidden_states[0, 2:3]))
# print(model.get_output_embeddings())

# Cosine similarity between embeddings
# print(torch.nn.functional.cosine_similarity(emb1[0, 0:1], emb2[0, 0]))

# Inverse direction: recover text from token ids (text -> ids -> tokens -> string).
encoding = tokenizer("Hello I'm the new king")
tokens = encoding["input_ids"]
print(tokens)
word_pieces = tokenizer.convert_ids_to_tokens(tokens)
tokenizer.convert_tokens_to_string(word_pieces)



tensor(15.0457, grad_fn=<NormBackward1>)
27.712812921102035
[101, 7592, 1045, 1005, 1049, 1996, 2047, 2332, 102]


"[CLS] hello i ' m the new king [SEP]"