In [4]:
import json
from fol import evaluate
from mlx_lm import load, generate 
from datasets import load_dataset
from folio_dataset import create_prompt, instruction, get_example_prompt_str, fol_to_nltk, convert_to_nltk_rep

folio_train = load_dataset('yale-nlp/FOLIO', split='train')
folio_val = load_dataset('yale-nlp/FOLIO', split='validation')

folio_train = folio_train.map(fol_to_nltk)
folio_val = folio_val.map(fol_to_nltk)

example_prompt_str = get_example_prompt_str(folio_train, n_shots=4)

folio_val = folio_val.map(
    create_prompt,
    fn_kwargs=dict(
        instruction=instruction,
        example_prompt_str=example_prompt_str
    )
)

Map: 100%|██████████| 203/203 [00:00<00:00, 13051.55 examples/s]


In [24]:
labels = [example['label'] for example in folio_val]

In [25]:
with open('gemma-1.1-7b-it_FOLIO_results.json', 'r') as fp:
    results = json.load(fp=fp)

In [26]:
len(results), len(labels)

(203, 203)

In [27]:
from sklearn.metrics import accuracy_score, f1_score

In [28]:
accuracy_score(labels, results)

0.3448275862068966

In [42]:
import random
random.sample([i for i in range(100)], 10)

[58, 64, 75, 63, 66, 77, 6, 39, 45, 13]

In [45]:
import pandas as pd
import itertools
import random
from sklearn.metrics import precision_score, recall_score

import random

for depth in range(6):
    if depth == 4:
        continue
    results = ['Unknown'] * len(all_labels)
    proofwriter_test = pd.read_json(f'proofwriter-dataset-V2020.12.3/OWA/depth-{depth}/meta-test.jsonl', lines=True)

    all_examples = []
    for index, row in proofwriter_test.iterrows():
        premises = row['theory']
        premises = premises.replace('. ', '\n ')
        question_dict = row['questions']
        num_q = len(question_dict)
        questions, answers = [], []
        for i in range(1, num_q + 1):
            questions.append(question_dict[f'Q{i}']['question'])
            answers.append(question_dict[f'Q{i}']['answer'])
        all_examples.append(dict(
            premises=premises,
            conclusion=questions,
            labels=answers
        ))
    all_examples = all_examples[:100]

    labels = [example['labels'] for example in all_examples]

    all_labels = list(itertools.chain(*labels))
    all_results = ['Unknown'] * len(all_labels)
    rand_idxs = random.sample([i for i in range(len(all_labels))], 50)
    for idx in rand_idxs:
        all_results[idx] = True

    print('depth:', depth)
    print('acuracy:', accuracy_score(all_labels, all_results))
    print('precision:', precision_score(all_labels, all_results, average='weighted'))
    print('recall:', recall_score(all_labels, all_results, average='weighted'))
    print('f1:', f1_score(all_labels, all_results, average='macro'))

depth: 0
acuracy: 0.4220430107526882
precision: 0.2727242369598611
recall: 0.4220430107526882
f1: 0.24901484480431849
depth: 1
acuracy: 0.4490861618798956
precision: 0.265997637002786
recall: 0.4490861618798956
f1: 0.22835736290819467


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


depth: 2
acuracy: 0.4435336976320583
precision: 0.2711435781921328
recall: 0.4435336976320583
f1: 0.22768681081747058


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


depth: 3
acuracy: 0.4531886024423338
precision: 0.2808826780296678
recall: 0.4531886024423338
f1: 0.22739399074956748


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


depth: 5
acuracy: 0.4453125
precision: 0.2839891953959382
recall: 0.4453125
f1: 0.22041362892426722


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [52]:
results = ['Uncertain'] * len(folio_val)
results[0] = 'True'
results[5] = 'True'
labels = [example['label'] for example in folio_val]

print('acuracy:', accuracy_score(labels, results))
print('precision:', precision_score(labels, results, average='weighted'))
print('recall:', recall_score(labels, results, average='weighted'))
print('f1:', f1_score(labels, results, average='macro'))

acuracy: 0.3399014778325123
precision: 0.2923314462171899
recall: 0.3399014778325123
f1: 0.17691024357691024


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [50]:
results

['Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Uncertain',
 'Unce