In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
from exp1 import load_jsonl
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn.functional as F
from torch import nn


  from .autonotebook import tqdm as notebook_tqdm


In [6]:

def evaluate_priming_effect(model, tokenizer, primer, index, device, k=5):
    """Return the model's top-k next-token predictions at one position of ``primer``.

    For a causal LM, the logits at position ``index`` are the distribution over
    the token that *follows* token ``index``. For example, with ``index=2`` on
    "This is a crazy world." the predictions are for the word after " a".

    Args:
        model: Language model called as ``model(**encoded)`` whose output has a
            ``.logits`` tensor; presumably shaped (batch, seq_len, vocab) since
            the code reads ``logits[0, index, :]`` -- TODO confirm for other
            models. It is expected to already live on ``device``.
        tokenizer: Hugging Face-style tokenizer: callable with
            ``return_tensors='pt'`` (result supports ``.to(device)``) and
            providing ``decode``.
        primer: Text to encode and feed to the model.
        index: Token position to read logits from. Negative values count from
            the end, as in normal Python indexing.
        device: Device the encoded inputs are moved to.
        k: Number of predictions to return; clamped to the vocabulary size.

    Returns:
        Dict mapping decoded token string -> probability (Python float) in
        descending probability order, or ``{"error": "Index out of bounds"}``
        when ``index`` falls outside the tokenized sequence.
    """
    model.eval()
    with torch.no_grad():
        encoded_input = tokenizer(primer, return_tensors='pt').to(device)
        logits = model(**encoded_input).logits

        seq_len = logits.size(1)
        # Accept Python-style negative indices, but reject anything that would
        # raise IndexError (previously only the upper bound was checked).
        if not -seq_len <= index < seq_len:
            return {"error": "Index out of bounds"}

        probs_at_index = F.softmax(logits[0, index, :], dim=-1)
        # torch.topk raises if k exceeds the number of vocabulary entries.
        top_k = min(k, probs_at_index.size(-1))
        top_probs, top_indices = torch.topk(probs_at_index, top_k)

        predictions = {}
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist()):
            token = tokenizer.decode([idx])
            # Distinct ids can decode to the same string (e.g. partial UTF-8
            # byte tokens). Keep the first, i.e. most probable, entry instead
            # of overwriting it with a less likely one.
            predictions.setdefault(token, prob)
        return predictions
            
def evaluate_priming_effect_at_index(sentence, index, k=5, model_name='gpt2'):
    """Load a GPT-2 checkpoint and return its top-k predictions at ``index`` of ``sentence``.

    Picks CUDA when available (otherwise CPU), prints the chosen device, loads
    the tokenizer and model, and delegates to ``evaluate_priming_effect``.

    Each call reloads the weights. When querying many positions or sentences,
    load the model once and call ``evaluate_priming_effect`` directly.

    Args:
        sentence: Text to feed the model.
        index: Token position whose logits are read (the logits predict the
            token that follows this position).
        k: Number of predictions to return.
        model_name: Hugging Face hub name or local path of a GPT-2-family
            checkpoint; defaults to ``'gpt2'``.

    Returns:
        The dict produced by ``evaluate_priming_effect``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # evaluate_priming_effect moves the encoded inputs to ``device``, so the
    # model must be there as well; without this, a CUDA machine fails with a
    # device-mismatch RuntimeError.
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
    return evaluate_priming_effect(model, tokenizer, sentence, index, device, k)
    
def main():
    """Demo: print the top-5 predictions read at token position 2 of a sample sentence."""
    predictions = evaluate_priming_effect_at_index("This is a crazy world.", 2, k=5)
    print("Predictions:", predictions)

In [7]:
# Run the demo. Inside a notebook kernel __name__ is always "__main__", so
# this always executes; the guard only matters if this code is exported to a
# .py module and imported elsewhere.
if __name__ == "__main__":
    main()

Using device: cpu




Predictions: {' very': 0.05330051854252815, ' great': 0.042423609644174576, ' good': 0.024849338456988335, ' big': 0.013589226640760899, ' huge': 0.011689825914800167}
