In [1]:
import torch
import torch.nn.functional as F

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [3]:
# Pick the compute device. The notebook prefers the second GPU (index 1), but
# only when that ordinal actually exists; otherwise fall back to the first GPU,
# and to the CPU when CUDA is unavailable. Constructing torch.device('cuda:1')
# does not validate the index, so on a single-GPU machine the hard-coded value
# would only fail later, when the model or a tensor is moved to the device.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [4]:
# Load the pretrained 'gpt2' checkpoint: the tokenizer that maps text to token
# ids, and the language-model head network that scores next tokens.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Build the model first, then move its weights onto the selected device.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to(device)

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [5]:
# Sample N tokens from the model one at a time, drawing each token from the
# model's next-token distribution, then print the average log-probability of
# the sampled tokens (the prompt tokens are not scored) and the decoded text.

# Start with start text.
starter_text = "I am"

generated = tokenizer.encode(starter_text)

# Alternative: start from the beginning-of-sequence token alone. GPT-2's BOS is
# `<|endoftext|>` (tokenizer.bos_token_id), not id 0, which is an ordinary
# vocabulary token.
# generated = [tokenizer.bos_token_id]

context = torch.tensor([generated]).to(device)
past = None  # key/value cache returned by the model; empty before the first step

N = 30  # number of tokens to sample
L = torch.full((1,), 0, dtype=torch.float32).to(device)  # running sum of log-probs
for i in range(N):

  # context shape: (batch_size, seq_len)
  with torch.no_grad():
    output = model(context, past_key_values=past)
    logits = output.logits # (batch_size, seq_len, vocab_len)
    past = output.past_key_values

    # log_softmax yields log-probabilities in one numerically stable step;
    # taking log(softmax(...)) separately loses precision for unlikely tokens.
    log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
    probs = log_probs.exp()
    token = torch.multinomial(probs, num_samples=1)
    token_int = token[0].item()

    generated.append(token_int)
    logprob = log_probs[0, token_int]
    L += logprob
    # Everything before this token is already in `past`, so only the newly
    # sampled token is fed to the model on the next step.
    context = token

sequence = tokenizer.decode(generated)

print(float(L/N))
print(sequence)

-4.6647047996521
I am having very desperate sevces issues and I am hoping to try to transfer from the SHO into French Point but I'm sure there's to be


In [12]:
# Fixed passage whose likelihood under the model is differentiated below.
text = "I WOULD NOT LIKE THEM HERE OR THERE. I WOULD NOT LIKE THEM ANYWHERE. I DO NOT LIKE GREEN EGGS AND HAM. I DO NOT LIKE THEM, SAM-I-AM."

# Encode to token ids, wrap in an outer list to form a batch of one sequence,
# and place the resulting tensor on the same device as the model.
token_ids = tokenizer.encode(text)
x = torch.tensor([token_ids]).to(device)

In [10]:
# Compute the gradient of the passage's log-likelihood with respect to every
# model parameter. `x` is the batch of token ids built from `text`; passing it
# as both input and labels makes the model return its language-modeling loss.
params = list(model.parameters())

# backward() adds into existing .grad buffers, so clear them first; otherwise
# re-running this cell would silently accumulate gradients across runs.
model.zero_grad()

out = model(x, labels=x)
ll = -out.loss  # negated loss, i.e. the log-likelihood objective
ll.backward()

# Store independent copies: p.grad itself is updated in place by any later
# backward pass, which would otherwise change these saved gradients too.
# Entries stay aligned with `params`; a parameter without a gradient keeps None.
grads = [
    p.grad.detach().clone() if p.grad is not None else None
    for p in params
]

In [11]:
# List the shape of every collected gradient, one per line, in the same order
# as the model's parameters.
for grad in grads:
    grad_shape = grad.shape
    print(grad_shape)

torch.Size([50257, 768])
torch.Size([1024, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 2304])
torch.Size([2304])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 2304])
torch.Size([2304])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 2304])
torch.Size([2304])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 2304])
torch.Size([2304])
torch.Size([768, 768])
torch.Size([768])
torch.Size([768])
torch.Size([768])
torch.Size([768, 3072])
torch.Size([3072])
torch.Size([3072, 768])
torch