In [1]:
from tqdm import tqdm
import csv

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [9]:
# --- Model / hardware ------------------------------------------------------
model_name, model_precision = 'EleutherAI/gpt-j-6B', "float16"
device = 'cuda'

# --- Scoring ---------------------------------------------------------------
# Vocabulary id whose next-token probability is recorded; 11 is the comma ","
# in the GPT-2/GPT-J BPE vocabulary.
target_token_idx = 11
# Prefixes longer than this many tokens are truncated (from the left, see the
# tokenizer setup). Presumably chosen to match GPT-J's context window — confirm.
max_length = 2048

# --- I/O -------------------------------------------------------------------
input_fn = '../out/oxford_comma/head_train_nocommas_extract.csv'
output_fn = '../out/oxford_comma/head_train_nocommas_scored.csv'

In [3]:
# Tokenizer truncates from the left, so an over-long prefix keeps its final
# tokens — the ones immediately before the position being scored.
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side='left')

# In float16 mode, pull the repo's "float16" revision (presumably a branch with
# half-precision weights) and keep the parameters in torch.float16.
precision_kwargs = (
    {'revision': "float16", 'torch_dtype': torch.float16}
    if model_precision == "float16"
    else {}
)
model = AutoModelForCausalLM.from_pretrained(
    model_name, return_dict=True, **precision_kwargs
).to(device)

In [10]:
# Read the extracted examples. The first CSV row is the header; each remaining row
# is unpacked by the scoring loop as (line_idx, sentence, contains, char_idx).
# `with` closes the file handle (the original left it open), and newline='' is what
# the csv module requires so quoted fields containing newlines parse correctly.
with open(input_fn, 'rt', newline='') as in_fh:
    in_reader = csv.reader(in_fh)
    header = next(in_reader, None)
    if header is None:
        raise ValueError(f'{input_fn} is empty; expected a header row')
    in_data = list(in_reader)

In [11]:
# Output CSV for the scoring loop. It stays open across cells (closed in the last
# cell), so a `with` block is not possible here. newline='' lets the csv writer
# control line endings; without it, Windows gets a blank line after every row.
# Rows written: [line_idx, n_prefix_tokens, contains, p(target token), argmax token id]
out_fh = open(output_fn, 'wt', newline='')
out = csv.writer(out_fh)

In [12]:
# Score every example: feed the sentence prefix (up to char_idx) to the model and
# record the next-token probability of `target_token_idx` (the comma).
model.eval()  # idempotent; hoisted out of the loop instead of re-running per row

for line in tqdm(in_data, total=len(in_data)):
    line_idx, sentence, contains, char_idx = line
    contains, char_idx = contains == 'True', int(char_idx)

    prefix = sentence[:char_idx]
    # truncation=True makes the max_length cut explicit (without it transformers
    # warns and truncates implicitly). The tokenizer was created with
    # truncation_side='left', so an over-long prefix keeps its LAST tokens.
    input_ids = tokenizer.encode(
        prefix,
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
        padding=False,
    ).to(device)
    if input_ids.numel() == 0:
        raise ValueError(
            f'example {line_idx}: prefix is empty (char_idx={char_idx}); '
            'there is no context to score a next token from'
        )

    with torch.no_grad():
        # No `labels=`: only the logits are used, so skip computing the LM loss.
        logits = model(input_ids).logits

    # Next-token distribution after the final prefix token (batch of one).
    # Softmax in float32: with float16 weights the logits are half precision, and
    # a half-precision softmax can underflow/round the small probabilities we record.
    probs = torch.softmax(logits[0, -1].float(), dim=-1)
    final_prob = probs[target_token_idx]

    out.writerow([line_idx, input_ids.shape[1], contains, final_prob.item(), probs.argmax().item()])

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 2106/2106 [02:32<00:00, 13.77it/s]


In [13]:
# Close the scored-output CSV written by the scoring loop. Rows are buffered, so
# they are only guaranteed to be on disk once the handle is closed; run this after
# the loop cell.
out_fh.close()