In [1]:
# List the available conda environments; the active one is marked with '*'.
%conda env list


# conda environments:
#
base                   /home/rui.xing/miniconda3
prompt_agent         * /home/rui.xing/miniconda3/envs/prompt_agent


Note: you may need to restart the kernel to use updated packages.


In [2]:
# Reload imported modules automatically before each cell runs, so edits to
# project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [14]:
import os
import json
import jsonlines
import sys
import numpy as np
import pandas as pd
from typing import List, Set
# Make project modules under ../src (e.g. tasks.community_notes) importable.
# The path is resolved against the current working directory, presumably the
# notebook's own folder.
sys.path.append('../src')

In [4]:
def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``."""
    with open(file_path, 'r') as handle:
        return json.load(handle)


def write_json(data, file_path, is_friendly_format=True, is_verbose=False):
    """Serialize ``data`` as JSON into ``file_path`` (overwriting it).

    is_friendly_format: pretty-print with a 4-space indent when True,
        otherwise write compact single-line JSON.
    is_verbose: print the destination path after writing.
    """
    indent = 4 if is_friendly_format else None
    with open(file_path, 'w') as out_file:
        json.dump(data, out_file, indent=indent)
    if is_verbose:
        print(f"Data is saved to {file_path}")


def read_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list, one per line."""
    with jsonlines.open(file_path, 'r') as reader:
        records = list(reader)
    return records


def write_jsonl(file_path, data):
    """Write each item of ``data`` as one JSON line to ``file_path``, then report the path.

    Note the argument order (path first) differs from write_json (data first).
    """
    with jsonlines.open(file_path, 'w') as writer:
        writer.write_all(data)
    print(f"Data is saved to {file_path}")

# Transform data to have one entry per claim-note pair
def flatten_data(data):
    """Expand claim records into one flat record per (claim, note) pair.

    Each input item is expected to carry a "claim" and a list of "notes"; each
    note needs "text" and "label", while a missing "reasons" defaults to "".
    Returns a list of dicts with keys "claim", "note_text", "reasons", "label".
    """
    rows = []
    for record in data:
        claim_text = record["claim"]
        rows.extend(
            {
                "claim": claim_text,
                "note_text": note["text"],
                "reasons": note.get("reasons", ""),
                "label": note["label"],  # the label belongs to the note, not the claim
            }
            for note in record["notes"]
        )
    return rows

In [11]:
# Combine the train, validation and test splits into a single JSON file, with
# each split flattened to one record per claim-note pair.
train_data, val_data, test_data = (
    read_jsonl('../datasets/notes_en_train.jsonl'),
    read_jsonl('../datasets/notes_en_val.jsonl'),
    read_jsonl('../datasets/notes_en_test.jsonl'),
)
flattened_train_data, flattened_val_data, flattened_test_data = (
    flatten_data(split_records) for split_records in (train_data, val_data, test_data)
)

# Note: the validation split is stored under the key 'eval'.
combined_data = dict(
    train=flattened_train_data,
    eval=flattened_val_data,
    test=flattened_test_data,
)

write_json(combined_data, '../datasets/notes_en_combined.json')


In [12]:
# Sanity check: the combined data has one entry per split.
combined_data.keys()

dict_keys(['train', 'eval', 'test'])

In [10]:
from tasks.community_notes import CustomTask

In [27]:
# Mapping from Community Notes rating-reason names to single-letter option codes.
# Presumably the letters are the option labels shown in the task prompt
# ("this will appear in the options") -- TODO confirm against the prompt template.
# clean_response extracts these letters from model answers, and clean_labels
# converts ';'-separated ground-truth reason names into sets of these letters.
# Note: 'notHelpfulOpinionSpeculation' is a prefix of
# 'notHelpfulOpinionSpeculationOrBias', which matters for any substring/regex
# matching on reason names.
REASON_LABELS = {
    'helpfulAddressesClaim': 'a',
    'helpfulClear': 'b',
    'helpfulEmpathetic': 'c',
    'helpfulGoodSources': 'd',
    'helpfulImportantContext': 'e',
    'helpfulInformative': 'f',
    'helpfulUnbiasedLanguage': 'g',
    'helpfulUniqueContext': 'h',
    'notHelpfulArgumentativeOrBiased': 'i',
    'notHelpfulHardToUnderstand': 'j',
    'notHelpfulIncorrect': 'k',
    'notHelpfulIrrelevantSources': 'l',
    'notHelpfulMissingKeyPoints': 'm',
    'notHelpfulNoteNotNeeded': 'n',
    'notHelpfulOffTopic': 'o',
    'notHelpfulOpinionSpeculation': 'p',
    'notHelpfulOpinionSpeculationOrBias': 'q',
    'notHelpfulOther': 'r',
    'notHelpfulSourcesMissingOrUnreliable': 's',
    'notHelpfulSpamHarassmentOrAbuse': 't',
}

In [7]:
import re
def clean_response(response, reason_labels=None):
    '''
    <task specific>
    Parse a model response into the set of predicted reason option letters.

    Primary path: take the last <answer>...</answer> block (tags matched
    case-insensitively) and collect every option letter that stands alone,
    i.e. is not part of a longer word. "(a);(d)", "a, d" and "a and d" all
    parse to {'a', 'd'}.

    Fallback: if there is no non-empty answer block, use the last full reason
    name (e.g. "helpfulClear") mentioned anywhere in the response, matched
    case-insensitively, and return its letter as a one-element set.

    response: raw model output text.
    reason_labels: mapping from reason name to option letter; defaults to the
        module-level REASON_LABELS.
    Returns a set of option letters, or the string 'N/A: Format error' when
    nothing parseable is found.
    '''
    if reason_labels is None:
        reason_labels = REASON_LABELS

    answers = re.findall(r"<answer>([\s\S]*?)<\/answer>", response.lower())

    if not answers or not answers[-1].strip():
        # Try longer names first: regex alternation takes the first alternative
        # that matches, so 'notHelpfulOpinionSpeculation' would otherwise cut
        # 'notHelpfulOpinionSpeculationOrBias' short.
        names = sorted(reason_labels, key=len, reverse=True)
        name_pattern = '|'.join(re.escape(name) for name in names)
        found_names = re.findall(name_pattern, response, re.IGNORECASE) if names else []
        if not found_names:
            return 'N/A: Format error'
        # Keys are camelCase, so the match is looked up case-insensitively;
        # indexing reason_labels with the lowercased match would raise KeyError.
        letter_by_name = {name.lower(): letter for name, letter in reason_labels.items()}
        return {letter_by_name[found_names[-1].lower()]}

    letter_class = ''.join(re.escape(letter) for letter in reason_labels.values())
    if not letter_class:
        return 'N/A: Format error'
    # The answer text is already lowercased. A letter only counts when it is not
    # adjacent to other letters, so words inside the answer (e.g. "and",
    # "option") do not contribute stray letters.
    found_letters = re.findall(r"(?<![a-z])[" + letter_class + r"](?![a-z])", answers[-1])
    if not found_letters:
        return 'N/A: Format error'
    return set(found_letters)

In [18]:
def cal_correct(preds:List[Set], labels:List[Set]) -> List[float]:
    '''
    <task specific>
    Score predictions against ground-truth labels for the community notes task.

    Each prediction/label pair is scored with the Jaccard similarity
    |p & l| / |p | l|. An exact match scores 1.0, disjoint sets score 0, and a
    partial overlap scores a fraction in between. Two empty sets score 1.

    preds: List of sets, each set contains the predicted labels for a claim-note pair.
        A string prediction (e.g. the 'N/A: Format error' returned by
        clean_response) is treated as unparseable and scores 0.
    labels: List of sets, each set contains the true labels for a claim-note pair.
    Returns a list of scores in [0, 1], one per pair.
    Raises ValueError if preds and labels have different lengths (zip would
    otherwise silently drop the extra items).
    '''
    if len(preds) != len(labels):
        raise ValueError(
            f"preds and labels must have the same length, got {len(preds)} and {len(labels)}"
        )
    comparisons = []
    for p, l in zip(preds, labels):
        if isinstance(p, str):
            # Format-error sentinel from response parsing: count as incorrect.
            comparisons.append(0)
            continue
        p, l = set(p), set(l)
        union = p | l
        if not union:
            # Both empty: nothing was expected and nothing was predicted.
            comparisons.append(1)
        else:
            # Evaluates to 0 when the sets share nothing.
            comparisons.append(len(p & l) / len(union))
    return comparisons

In [29]:
def clean_labels(labels, reason_labels=None):
    '''
    <task specific>
    Convert the task's ground-truth answers into the List[set] form expected by
    the function "cal_correct".

    Each label is a string of reason names separated by ';' (for example
    "helpfulAddressesClaim;helpfulGoodSources"). Every name is mapped to its
    option letter. Any number of reasons is supported, whitespace around names
    is ignored, and empty pieces are skipped. Labels that are already
    sets/frozensets of letters are passed through (as a set copy), so data
    already in that form is left as it is.

    labels: iterable of ';'-separated reason-name strings (or sets of letters).
    reason_labels: mapping from reason name to option letter; defaults to the
        module-level REASON_LABELS.
    Returns a list of sets of option letters, one per label.
    Raises KeyError for a reason name missing from reason_labels.
    '''
    if reason_labels is None:
        reason_labels = REASON_LABELS
    cleaned_labels = []
    for label in labels:
        if isinstance(label, (set, frozenset)):
            cleaned_labels.append(set(label))
            continue
        reason_names = [name.strip() for name in label.split(';') if name.strip()]
        cleaned_labels.append({reason_labels[name] for name in reason_names})
    return cleaned_labels

In [30]:
# Test examples for clean_labels: ';'-separated reason names become sets of option letters.
labels = ["notHelpfulNoteNotNeeded;notHelpfulMissingKeyPoints",
          "helpfulAddressesClaim;helpfulGoodSources",
          "notHelpfulOpinionSpeculationOrBias;notHelpfulHardToUnderstand"]
cleaned = clean_labels(labels)
assert cleaned == [{'n', 'm'}, {'a', 'd'}, {'q', 'j'}], cleaned
cleaned

[{'m', 'n'}, {'a', 'd'}, {'j', 'q'}]

In [20]:
# Test examples for cal_correct: each score is the Jaccard similarity of the
# predicted and true label sets, so a partial overlap scores below 1.
preds = [{'a', 'b'}, {'c', 'd'}, {'d', 'e'}]
labels = [{'a', 'b'}, {'c', 'd'}, {'d', 'e', 'f'}]
scores = cal_correct(preds, labels)
# {'d', 'e'} vs {'d', 'e', 'f'}: 2 shared labels out of 3 in the union -> 2/3.
assert scores == [1.0, 1.0, 2 / 3], scores
scores

[1.0, 1.0, 0.6666666666666666]


In [32]:
# Sanity check: the letters inside the last <answer>...</answer> block are parsed into a set.
parsed_answer = clean_response('<answer>(a);(d)</answer>')
assert parsed_answer == {'a', 'd'}, parsed_answer
parsed_answer

{'a', 'd'}