In [1]:
import json
import gzip
import csv
import sys
import numpy as np
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
device = 'cuda'
float16 = True
max_length = 100

input_fn = "Oxford_Comma_Data/train_oxford.csv"
output_fn = "Oxford_Comma_Data/output:train_oxford.csv"

In [3]:
model_name = 'EleutherAI/gpt-j-6B'
tokenizer = AutoTokenizer.from_pretrained(model_name)

if float16:
    model = AutoModelForCausalLM.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16, return_dict=True).to(device)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, return_dict=True).to(device)

In [4]:
in_data = list(csv.reader(open(input_fn, 'rt')))
header = in_data[0]
in_data = in_data[1:]

In [5]:
in_data[0]

['0',
 'I used FDT and Starling creating an Adobe AIR (ActionScript) project, all tools or frameworks I already had some knowledge with.',
 'False',
 '79']

In [6]:
out_fh = open(output_fn, 'wt')
out = csv.writer(out_fh)

In [7]:
for i, line in tqdm(enumerate(in_data), total=len(in_data)):
    line_idx, sentence, contains, char_idx = line
    contains, char_idx = contains == 'True', int(char_idx)

    prefix = sentence[:char_idx]
    input_ids = tokenizer.encode(prefix, \
                                 return_tensors='pt', \
                                 padding=False, \
                                 max_length=max_length \
                                ).to(device)

    
    # Evaluate the loss of the sequence with the GPT-2 model
    with torch.no_grad():
        model.eval()
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
        logits = outputs.logits

    # Get the loss at each token
    last_logits = logits[...,-1,:].contiguous().squeeze(0)
    probs = torch.nn.Softmax(dim=-1)(last_logits)

    
    comma_idx = 11 #tokenizer.encode(',') no leading white space
    comma_prob = probs[comma_idx]
    
    
    if False:
        print(contains)
        print(sentence)
        print(prefix)
        print(input_ids)
        print(comma_prob)
        print()
        
        if i == 3:
            break
            
    out.writerow([i, prefix, contains, comma_prob.item()])

  0%|          | 0/13082 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [8]:
out_fh.close()

In [9]:
tokenizer.encode('The stellar-backed Wirex Stablecoins is unique for international remittance, and offers faster,')

[464,
 25041,
 12,
 17078,
 14712,
 87,
 520,
 540,
 14624,
 318,
 3748,
 329,
 3230,
 816,
 47912,
 11,
 290,
 4394,
 5443,
 11]