In [None]:
from google.colab import drive

# Expose Google Drive under /content/drive so the FACTOR checkout
# (data, cache, output) under MyDrive/factor is reachable from this runtime.
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [None]:
# List the FACTOR checkout on the mounted Drive to confirm the mount worked.
!ls /content/drive/MyDrive/factor

cache  eval_factuality.py      LICENSE	README.md
data   FACTOR_NEWS_LICENSE.md  output	requirements.txt


In [None]:
import argparse
import os

import numpy as np
import pandas as pd
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# load data
def extract_example(row):
    """Turn one data row into the example dict consumed by the evaluation code.

    ``row`` must expose ``full_prefix``, ``completion`` and
    ``contradiction_0`` .. ``contradiction_2`` as attributes.

    Returns:
        dict with keys ``full_prefix``, ``completion`` and ``contradictions``
        (a 3-element list in column order).
    """
    contradictions = [row.contradiction_0, row.contradiction_1, row.contradiction_2]
    example = dict(full_prefix=row.full_prefix, completion=row.completion)
    example['contradictions'] = contradictions
    return example


def read_data(path, prefix_col):
    """Load FACTOR examples from a CSV file.

    Args:
        path: CSV path (anything ``pd.read_csv`` accepts).
        prefix_col: name of the column holding the prefix text; it is exposed
            to ``extract_example`` as ``full_prefix``.

    Returns:
        A list with one ``extract_example`` dict per CSV row, in file order.

    Raises:
        KeyError: if any of the required columns (``prefix_col``, ``doc_id``,
            ``completion``, ``contradiction_0..2``) is missing.
    """
    columns = [prefix_col, 'doc_id', 'completion',
               'contradiction_0', 'contradiction_1', 'contradiction_2']
    df = pd.read_csv(path)[columns].rename(columns={prefix_col: 'full_prefix'})
    # itertuples(index=False) yields namedtuples whose fields are the column
    # names, so extract_example's attribute access works unchanged, without
    # building a Series per row as DataFrame.apply(axis=1) does.
    return [extract_example(row) for row in df.itertuples(index=False)]

# load model
def load_tokenizer(model_name, max_tokens):
    """Load the tokenizer for ``model_name`` configured for scoring.

    Padding is applied on the right and truncation drops tokens from the
    left, so an over-long text keeps its end (where the scored completion
    sits).  Sequences are capped at ``max_tokens``.  The EOS token is reused
    as the pad token.
    """
    tokenizer_kwargs = dict(
        padding_side='right',
        truncation_side='left',
        model_max_length=max_tokens,
    )
    tok = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    # Presumably needed because some causal-LM checkpoints (e.g. GPT-2)
    # define no pad token, which padding=True would otherwise require.
    tok.pad_token = tok.eos_token
    return tok


def load_model_and_tokenizer(model_name, cache_dir=None, max_tokens=1024):
    """Load a causal LM and its tokenizer ready for evaluation.

    Args:
        model_name: Hugging Face hub id or local path.
        cache_dir: optional download cache, used for both the config and the
            model weights on every device.
        max_tokens: maximum sequence length passed to the tokenizer.

    Returns:
        ``(model, tokenizer, device)``.  ``device`` is ``"cuda"`` when a GPU
        is visible, otherwise ``"cpu"``.  With more than one GPU the model is
        sharded with ``device_map="auto"`` instead of being moved to
        ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    multi_gpus = torch.cuda.device_count() > 1

    # Pass cache_dir to every download (config included), regardless of device.
    hub_args = {} if cache_dir is None else {"cache_dir": cache_dir}
    config = AutoConfig.from_pretrained(model_name, **hub_args)

    model_args = dict(hub_args)
    if multi_gpus:
        model_args["device_map"] = "auto"
        model_args["low_cpu_mem_usage"] = True
    # Use the checkpoint's stored dtype (e.g. float16) only on GPU; on CPU
    # stay with the float32 default, since half-precision kernels are missing
    # or very slow there for many ops.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is not None and device != "cpu":
        model_args["torch_dtype"] = config_dtype

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_args).eval()
    if not multi_gpus:
        model = model.to(device)
    tokenizer = load_tokenizer(model_name, max_tokens)
    print(model.dtype)
    # load_tokenizer pads with EOS; keep the model config consistent with it.
    model.config.pad_token_id = model.config.eos_token_id
    return model, tokenizer, device

# prepare examples for evaluation
def format_data(ex):
    """Build the four texts to score for one FACTOR example.

    Args:
        ex: dict with ``full_prefix`` (str), ``completion`` (str) and
            ``contradictions`` (list of str), as produced by extract_example.

    Returns:
        ``(batch, labels_batch)``.  ``batch[k]`` is the prefix joined with
        candidate ``k``, and ``labels_batch[k]`` is the exact suffix of
        ``batch[k]`` that gets scored.  Index 0 is the true completion; the
        remaining entries are the contradictions in order.
    """
    prefix = ex['full_prefix']
    # Candidates are glued to the prefix below, so drop any LEADING spaces
    # they carry (only ' ' is stripped; newlines and tabs are kept).
    candidates = [ex['completion'].lstrip(' ')] + [cont.lstrip(' ') for cont in ex['contradictions']]

    if prefix.endswith(' '):
        # Move the prefix's single trailing space onto each candidate, so
        # the space is tokenized together with the candidate's first word.
        # prep_batch tokenizes the label text on its own and expects it to
        # be a token-level suffix of the full text.
        prefix = prefix[:-1]
        labels_batch = [f" {cand}" for cand in candidates]
    else:
        # Any other ending (e.g. a newline): concatenate as-is, with no space added.
        labels_batch = candidates
    batch = [f"{prefix}{label}" for label in labels_batch]
    return batch, labels_batch


def prep_batch(ex, tokenizer, device):
    """Tokenize one example and build the label tensor for scoring.

    Args:
        ex: example dict accepted by format_data.
        tokenizer: tokenizer to use; assumed to pad on the right, as
            configured by load_tokenizer.
        device: torch device for the returned tensors.

    Returns:
        ``(encoding, labels, target_lens)``:
        - ``encoding``: the tokenizer output as a dict of tensors on ``device``.
        - ``labels``: same shape as ``encoding['input_ids']``.  It holds the
          input ids at each candidate's token positions and -100 everywhere else.
        - ``target_lens``: the number of candidate tokens in each row.

    Raises:
        AssertionError: if a candidate, tokenized on its own, is not the tail
            of its tokenized full text (the labels would then be misaligned).
    """
    batch, labels_batch = format_data(ex)
    tok_args = dict(padding=True, truncation=True, return_tensors='pt', add_special_tokens=False)

    # Full texts (prefix + candidate).  With right padding, row k's real
    # tokens occupy positions [0, input_lens[k]).
    encoding = {k: v.to(device) for k, v in tokenizer(batch, **tok_args).items()}
    input_ids = encoding['input_ids']
    # Candidates tokenized on their own; their lengths locate the candidate
    # tokens at the end of each full text.
    labels_encoding = {k: v.to(device) for k, v in tokenizer(labels_batch, **tok_args).items()}

    input_lens = encoding['attention_mask'].sum(dim=-1)
    target_lens = labels_encoding['attention_mask'].sum(dim=-1)
    offsets = input_lens - target_lens
    positions = torch.arange(input_ids.size(-1), device=input_ids.device)[None, :]
    # Score only real (non-pad) tokens at or after each row's offset.
    labels_mask = (positions >= offsets[:, None]) & encoding['attention_mask'].bool()
    labels = input_ids.masked_fill(~labels_mask, -100)

    # Sanity checks.  The labelled span must equal the candidate's own tokens
    # (checked against the separate tokenization, not against input_ids
    # itself), and nothing before or after that span may be scored.
    cand_ids = labels_encoding['input_ids']
    for row_ids, row_labels, target_len, offset, row_cand in zip(
            input_ids, labels, target_lens.tolist(), offsets.tolist(), cand_ids):
        assert torch.equal(row_ids[offset:offset + target_len], row_cand[:target_len]), \
            "labels don't appear in input ids"
        assert torch.all(row_labels[:offset] == -100), "labels include redundant prefix"
        assert torch.all(row_labels[offset + target_len:] == -100), "labels include redundant suffix"
    return encoding, labels, target_lens


def get_losses(logits, labels):
    """Per-position cross-entropy (negative log-likelihood), returned on CPU.

    Args:
        logits: scores whose last dimension is the vocabulary; the leading
            dimensions must flatten to the same length as ``labels``.
        labels: integer target ids.  Positions equal to -100 (the
            CrossEntropyLoss ignore index) get a loss of 0.

    Returns:
        A CPU tensor with the same shape as ``labels``.
    """
    criterion = CrossEntropyLoss(reduction="none")
    vocab_size = logits.size(-1)
    flat_nll = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))
    return flat_nll.cpu().view(labels.size())

import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize

# Fetch the NLTK resources used by word_tokenize / pos_tag below (a no-op once
# they are present).  NLTK 3.9+ looks for 'punkt_tab' and
# 'averaged_perceptron_tagger_eng' instead of the older names.  Requesting
# both sets therefore works on either version: nltk.download reports a
# failure by returning False rather than raising (raise_on_error=False).
for _nltk_resource in ('punkt', 'averaged_perceptron_tagger',
                       'punkt_tab', 'averaged_perceptron_tagger_eng'):
    nltk.download(_nltk_resource)

def find_word_differences(sentence_a, sentence_b):
    """Return the whitespace-separated words unique to each sentence.

    The comparison is exact (case- and punctuation-sensitive) and ignores
    word order and repetition.

    Returns:
        ``(only_in_a, only_in_b)`` as two sets of strings.
    """
    vocab_a = set(sentence_a.split())
    vocab_b = set(sentence_b.split())
    return vocab_a.difference(vocab_b), vocab_b.difference(vocab_a)

def get_part_of_speech(word):
    """Return the NLTK POS tag of the first token of ``word``.

    The word is tokenized with word_tokenize and tagged in isolation (no
    sentence context).  Only the tag of the first resulting token is
    returned, so a word that splits into several tokens is represented by its
    first piece.  Returns None when tokenization yields no tokens.
    """
    tagged = pos_tag(word_tokenize(word))
    if not tagged:
        return None
    _first_token, first_tag = tagged[0]
    return first_tag

# POS-tag groups used to bucket an edit into an error type (Penn Treebank
# tag names, as produced by nltk.pos_tag).
COREFERENCE_TAGS = {'PRP', 'PRP$'}
ENTITY_TAGS = {'JJ', 'JJR', 'JJS'}
# VBG (gerund/present participle) belongs with the other verb forms.
PREDICATE_TAGS = {'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RB', 'RBR', 'RBS'}
CIRCUMSTANCE_TAGS = {'NN', 'NNS', 'NNP', 'NNPS', 'CD'}
LINK_TAGS = {'IN'}
# A contradiction changing more words than this is labelled an entity error.
MAX_CHANGED_WORDS = 3


def classify_pos_tags(pos_a, pos_b, max_changed_words=MAX_CHANGED_WORDS):
    """Map the POS tags of the differing words to an error-type label.

    Args:
        pos_a: tags of the words that appear only in the true completion.
        pos_b: tags of the words that appear only in the contradiction.
        max_changed_words: more than this many entries in ``pos_b`` means
            'Entity'.

    Returns:
        The first matching label, checking in this order:
        'Coreference' (a pronoun tag on either side), 'Entity', 'Predicate',
        'Circumstance', 'Link'.  Returns 'Unclear' when nothing matches.
    """
    tags_a, tags_b = set(pos_a), set(pos_b)
    if (tags_a | tags_b) & COREFERENCE_TAGS:
        return 'Coreference'
    if len(pos_b) > max_changed_words or tags_b & ENTITY_TAGS:
        return 'Entity'
    if tags_b & PREDICATE_TAGS:
        return 'Predicate'
    if tags_b & CIRCUMSTANCE_TAGS:
        return 'Circumstance'
    if tags_b & LINK_TAGS:
        return 'Link'
    return 'Unclear'


def determine_error_type(sentence1, sentence2):
    """Heuristically label the kind of edit turning sentence1 into sentence2.

    Args:
        sentence1: the reference text (run_eval passes the true completion).
        sentence2: the edited text (run_eval passes a contradiction).

    Returns:
        One of 'Coreference', 'Entity', 'Predicate', 'Circumstance', 'Link'
        or 'Unclear' (see classify_pos_tags for the rules).
    """
    only_in_1, only_in_2 = find_word_differences(sentence1, sentence2)
    pos_a = [get_part_of_speech(word) for word in only_in_1]
    pos_b = [get_part_of_speech(word) for word in only_in_2]
    return classify_pos_tags(pos_a, pos_b)

def run_eval(model, tokenizer, data, device):
    """Score every example and tally which contradiction types fool the model.

    Each candidate (true completion plus three contradictions) is scored by
    its mean per-token NLL under the model.  An example counts as correct
    when the true completion (index 0) has the lowest score.

    Args:
        model: causal LM whose output exposes ``.logits``.
        tokenizer: tokenizer passed through to prep_batch.
        data: list of example dicts (see read_data).
        device: device the model inputs are placed on.

    Returns:
        ``(all_scores, fault_record)``:
        - ``all_scores``: a ``(len(data), 4)`` CPU tensor of mean NLLs, with
          column 0 holding the true completion.
        - ``fault_record``: maps each error type to the fraction of
          contradictions of that type that received the lowest score of all
          four candidates.  Types never encountered stay 0.
    """
    error_types = ('Coreference', 'Entity', 'Predicate', 'Circumstance', 'Link', 'Unclear')
    all_scores = torch.empty((len(data), 4))
    fault_record = dict.fromkeys(error_types, 0)
    fault_num = dict.fromkeys(error_types, 0)
    # Wrap `data` itself (not enumerate(data)) so tqdm knows the total and
    # can show progress and an ETA.
    for i, ex in enumerate(tqdm(data)):
        encoding, target, target_lens = prep_batch(ex, tokenizer, device=device)
        with torch.no_grad():
            out = model(**encoding)
            # Shift by one: the logits at position t predict the token at t + 1.
            nll = get_losses(out.logits[..., :-1, :], target[:, 1:])

        # Mean NLL per candidate over its own tokens (ignored positions add 0).
        scores = nll.sum(dim=-1) / target_lens.to('cpu')
        all_scores[i] = scores
        # Candidate the model prefers, i.e. the one with the lowest mean NLL.
        preferred = int(torch.argmin(scores))
        real = ex['completion']
        for j, fake in enumerate(ex['contradictions'], start=1):
            error_type = determine_error_type(real, fake)
            fault_num[error_type] += 1
            if preferred == j:
                fault_record[error_type] += 1
        if i % 100 == 0:
            acc = (all_scores[:i + 1].argmin(dim=1) == 0).sum().item() / (i + 1)
            print(f"processed: {i+1}/{len(data)} examples. accuracy: {acc}")
    for key, count in fault_num.items():
        if count != 0:
            fault_record[key] /= count
    return all_scores, fault_record


def main(data_file, output_folder, model_name, max_tokens, cache_dir):
    """Evaluate ``model_name`` on a FACTOR CSV and save per-example scores.

    Prints the overall accuracy and the per-error-type rates returned by
    run_eval.  Writes each example plus its four candidate scores as JSON
    lines to ``get_results_path(output_folder, model_name)``, creating
    ``output_folder`` if needed.
    """
    # Column name exactly as it appears in the FACTOR CSV; keep the spelling.
    prefix_col = 'turncated_prefixes'
    examples = read_data(data_file, prefix_col)
    model, tokenizer, device = load_model_and_tokenizer(model_name, cache_dir, max_tokens=max_tokens)
    all_scores, fault_record = run_eval(model, tokenizer, examples, device)

    scores = all_scores.to('cpu').numpy()
    results = pd.DataFrame(examples)
    results['scores'] = list(scores)
    # Correct when the true completion (column 0) has the lowest score.
    acc = np.sum(np.argmin(scores, axis=1) == 0) / len(results)
    print(f"acc = {acc}")
    for key, rate in fault_record.items():
        print(f'{key} error: {rate}')
    # DataFrame.to_json does not create missing directories.
    os.makedirs(output_folder, exist_ok=True)
    results.to_json(get_results_path(output_folder, model_name), lines=True,
                    orient='records')
    print("Done!")


def get_results_path(output_folder, model_name):
    """Return the JSON-lines results path for ``model_name``.

    Only the last ``/``-separated part of the model name is used, so
    ``'EleutherAI/gpt-neo-1.3B'`` maps to ``<output_folder>/gpt-neo-1.3B.jsonl``.
    """
    short_name = model_name.rsplit('/', 1)[-1]
    return os.path.join(output_folder, f"{short_name}.jsonl")

# Default FACTOR locations on the mounted Drive.  A model is evaluated with e.g.
#     main(DEFAULT_DATA_FILE, DEFAULT_OUTPUT_FOLDER, 'gpt2', 1024, DEFAULT_CACHE_DIR)
# These are plain assignments, so the cell echoes no output.
DEFAULT_DATA_FILE = "/content/drive/MyDrive/factor/data/wiki_factor.csv"
DEFAULT_OUTPUT_FOLDER = "/content/drive/MyDrive/factor/output"
DEFAULT_CACHE_DIR = "/content/drive/MyDrive/factor/cache"

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


'\ndata_file = "/content/drive/MyDrive/factor/data/wiki_factor.csv"\noutput_folder = "/content/drive/MyDrive/factor/output"\nmodel_name = \'gpt2\'\nmax_tokens = 1024\ncache_dir = "/content/drive/MyDrive/factor/cache"\nmain(data_file, output_folder, model_name, max_tokens, cache_dir)\n'

In [None]:

# Evaluate EleutherAI/gpt-neo-1.3B on wiki_factor.csv.
model_name = 'EleutherAI/gpt-neo-1.3B'
data_file = "/content/drive/MyDrive/factor/data/wiki_factor.csv"
output_folder = "/content/drive/MyDrive/factor/output"
cache_dir = "/content/drive/MyDrive/factor/cache"
max_tokens = 1024
main(data_file=data_file, output_folder=output_folder, model_name=model_name,
     max_tokens=max_tokens, cache_dir=cache_dir)

torch.float32


1it [00:03,  3.12s/it]

processed: 1/2994 examples. accuracy: 1.0


101it [02:16,  1.11s/it]

processed: 101/2994 examples. accuracy: 0.46534653465346537


201it [04:28,  1.93s/it]

processed: 201/2994 examples. accuracy: 0.4079601990049751


301it [06:47,  1.66s/it]

processed: 301/2994 examples. accuracy: 0.4152823920265781


401it [09:13,  1.27s/it]

processed: 401/2994 examples. accuracy: 0.4114713216957606


501it [11:40,  2.13s/it]

processed: 501/2994 examples. accuracy: 0.43313373253493015


601it [14:05,  1.06s/it]

processed: 601/2994 examples. accuracy: 0.4209650582362729


701it [16:45,  1.21s/it]

processed: 701/2994 examples. accuracy: 0.4094151212553495


801it [19:15,  1.13s/it]

processed: 801/2994 examples. accuracy: 0.41448189762796506


901it [21:34,  1.53s/it]

processed: 901/2994 examples. accuracy: 0.4239733629300777


1001it [24:06,  1.64s/it]

processed: 1001/2994 examples. accuracy: 0.4125874125874126


1101it [26:36,  1.55it/s]

processed: 1101/2994 examples. accuracy: 0.4123524069028156


1201it [29:05,  1.63s/it]

processed: 1201/2994 examples. accuracy: 0.4138218151540383


1301it [31:31,  1.27it/s]

processed: 1301/2994 examples. accuracy: 0.415065334358186


1401it [34:09,  1.42s/it]

processed: 1401/2994 examples. accuracy: 0.4182726623840114


1501it [36:38,  1.04it/s]

processed: 1501/2994 examples. accuracy: 0.4217188540972685


1601it [39:25,  1.95s/it]

processed: 1601/2994 examples. accuracy: 0.42286071205496567


1701it [41:44,  1.29it/s]

processed: 1701/2994 examples. accuracy: 0.42034097589653147


1801it [44:25,  1.40s/it]

processed: 1801/2994 examples. accuracy: 0.41921154913936703


1901it [46:58,  2.11s/it]

processed: 1901/2994 examples. accuracy: 0.4139926354550237


2001it [49:31,  1.11s/it]

processed: 2001/2994 examples. accuracy: 0.41529235382308843


2101it [52:00,  1.92s/it]

processed: 2101/2994 examples. accuracy: 0.4145644930985245


2201it [54:03,  1.13s/it]

processed: 2201/2994 examples. accuracy: 0.4134484325306679


2301it [56:43,  2.22s/it]

processed: 2301/2994 examples. accuracy: 0.41286397218600607


2401it [59:17,  2.27s/it]

processed: 2401/2994 examples. accuracy: 0.4148271553519367


2501it [1:01:24,  1.75s/it]

processed: 2501/2994 examples. accuracy: 0.41423430627748903


2601it [1:03:51,  1.54s/it]

processed: 2601/2994 examples. accuracy: 0.41753171856978083


2701it [1:06:08,  1.23s/it]

processed: 2701/2994 examples. accuracy: 0.4183635690485005


2801it [1:08:25,  1.29s/it]

processed: 2801/2994 examples. accuracy: 0.41949303820064265


2901it [1:10:39,  1.11s/it]

processed: 2901/2994 examples. accuracy: 0.41744226128921064


2994it [1:12:50,  1.46s/it]


acc = 0.41850367401469607
Coreference error: 0.11054421768707483
Entity error: 0.2188679245283019
Predicate error: 0.0740072202166065
Circumstance error: 0.24453406919974527
Link error: 0.06837606837606838
Unclear error: 0.14331210191082802
Done!


In [None]:

# Evaluate facebook/opt-350m on wiki_factor.csv.
model_name = 'facebook/opt-350m'
data_file = "/content/drive/MyDrive/factor/data/wiki_factor.csv"
output_folder = "/content/drive/MyDrive/factor/output"
cache_dir = "/content/drive/MyDrive/factor/cache"
max_tokens = 1024
main(data_file=data_file, output_folder=output_folder, model_name=model_name,
     max_tokens=max_tokens, cache_dir=cache_dir)

config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

torch.float16


2it [00:00,  3.72it/s]

processed: 1/2994 examples. accuracy: 1.0


102it [00:15,  8.54it/s]

processed: 101/2994 examples. accuracy: 0.37623762376237624


203it [00:29,  6.38it/s]

processed: 201/2994 examples. accuracy: 0.3482587064676617


302it [00:44,  6.72it/s]

processed: 301/2994 examples. accuracy: 0.34219269102990035


400it [00:59,  6.71it/s]

processed: 401/2994 examples. accuracy: 0.3266832917705736


502it [01:15,  5.05it/s]

processed: 501/2994 examples. accuracy: 0.3253493013972056


601it [01:30,  8.08it/s]

processed: 601/2994 examples. accuracy: 0.3161397670549085


701it [01:48,  7.45it/s]

processed: 701/2994 examples. accuracy: 0.3181169757489301


802it [02:04,  8.74it/s]

processed: 801/2994 examples. accuracy: 0.317103620474407


903it [02:18,  7.71it/s]

processed: 901/2994 examples. accuracy: 0.32297447280799113


1002it [02:35,  6.81it/s]

processed: 1001/2994 examples. accuracy: 0.3166833166833167


1103it [02:51, 12.71it/s]

processed: 1101/2994 examples. accuracy: 0.31607629427792916


1201it [03:06,  6.24it/s]

processed: 1201/2994 examples. accuracy: 0.3164029975020816


1302it [03:22, 11.39it/s]

processed: 1301/2994 examples. accuracy: 0.3143735588009224


1400it [03:38,  6.02it/s]

processed: 1401/2994 examples. accuracy: 0.3140613847251963


1503it [03:55, 11.00it/s]

processed: 1501/2994 examples. accuracy: 0.3157894736842105


1602it [04:13,  5.57it/s]

processed: 1601/2994 examples. accuracy: 0.3198001249219238


1701it [04:28,  8.48it/s]

processed: 1701/2994 examples. accuracy: 0.3186360964138742


1801it [04:46,  6.64it/s]

processed: 1801/2994 examples. accuracy: 0.31593559133814547


1901it [05:03,  4.92it/s]

processed: 1901/2994 examples. accuracy: 0.3145712782745923


2003it [05:21,  8.75it/s]

processed: 2001/2994 examples. accuracy: 0.3173413293353323


2102it [05:38,  5.98it/s]

processed: 2101/2994 examples. accuracy: 0.3165159447881961


2203it [05:50,  9.73it/s]

processed: 2201/2994 examples. accuracy: 0.3125851885506588


2301it [06:08,  4.39it/s]

processed: 2301/2994 examples. accuracy: 0.3098652759669709


2403it [06:26,  5.82it/s]

processed: 2401/2994 examples. accuracy: 0.31153685964181593


2503it [06:39,  7.56it/s]

processed: 2501/2994 examples. accuracy: 0.31067572970811674


2603it [06:55,  8.97it/s]

processed: 2601/2994 examples. accuracy: 0.31295655517108806


2701it [07:10,  7.13it/s]

processed: 2701/2994 examples. accuracy: 0.3113661606812292


2801it [07:26,  8.37it/s]

processed: 2801/2994 examples. accuracy: 0.31131738664762587


2901it [07:40,  9.09it/s]

processed: 2901/2994 examples. accuracy: 0.3098931402964495


2994it [07:55,  6.30it/s]


acc = 0.3102872411489646
Coreference error: 0.15306122448979592
Entity error: 0.2779874213836478
Predicate error: 0.10770156438026474
Circumstance error: 0.27297813627679895
Link error: 0.10256410256410256
Unclear error: 0.17834394904458598
Done!


In [None]:
import torch

import gc

# First collect unreachable Python objects (e.g. reference cycles) so any
# CUDA tensors they still hold are freed.  Then hand the allocator's now
# unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()
