In [1]:
import torch
import pandas as pd
import numpy as np
from wordfreq import word_frequency
from string import punctuation

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [2]:
MODEL_NAME = {"gpt2":"gpt2"}

def get_predictions(sentence, model):
	model_name = MODEL_NAME[model]
	model = GPT2LMHeadModel.from_pretrained(model_name)
	tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

	sent_tokens = tokenizer.tokenize(sentence, add_special_tokens=True)
	indexed_tokens = tokenizer.convert_tokens_to_ids(sent_tokens)

	tokens_tensor = torch.tensor(indexed_tokens).unsqueeze(0)

	with torch.no_grad():
		probs = model(tokens_tensor)[0].softmax(dim=2).squeeze()

	return list(zip(sent_tokens, indexed_tokens, (None,) + probs.unbind()))

In [3]:
def get_surprisals(predictions):
	result = []
	for j, (word, word_idx, preds) in enumerate(predictions):
		if preds is None:
			surprisal = 0.0
		else:
			surprisal = - np.log(preds[word_idx].item() / np.log(2))
		result.append( (j+1, word, surprisal) )
	return result

In [66]:
def get_entropies(predictions):
    result = []
    for j, (word, word_idx, preds) in enumerate(predictions):
        if preds is None:
            entropy = 0.0
        else:
            surprisals = -1 * np.log(preds) / np.log(2)
            probs = preds
            entropy = (probs * surprisals).sum(-1)
        result.append((j, word, entropy))
    return result

In [38]:
predictions = get_predictions("I went to San Francisco and saw the golden gate pizza", "gpt2")

In [67]:
surprisals = get_surprisals(predictions)
print(surprisals)

entropies = get_entropies(predictions)
print(entropies)

[(1, 'I', 0.0), (2, 'Ġwent', 6.655623408750759), (3, 'Ġto', 0.5739821277583708), (4, 'ĠSan', 6.502838898518721), (5, 'ĠFrancisco', 0.5309294051813819), (6, 'Ġand', 1.8128097141296469), (7, 'Ġsaw', 2.681927388243818), (8, 'Ġthe', 1.4548104602111378), (9, 'Ġgolden', 7.920232738339079), (10, 'Ġgate', 5.872006630643133), (11, 'Ġpizza', 12.128894344681402)]


AttributeError: 'float' object has no attribute 'detach'

In [16]:
print(torch.log(predictions[1][2]))

tensor([ -6.5131,  -5.8791,  -9.7528,  ..., -14.6839, -12.4492,  -7.0846])
