In [None]:
from collections import namedtuple, Counter
from lorax import Client
from pydantic import BaseModel
from tqdm import tqdm
from typing import List
import json
import os
import pandas as pd

In [None]:
# Prompt for the NER task. `{input}` is filled in by str.format() (see
# `generate`); the literal JSON braces in the example are doubled ({{ / }})
# so that format() leaves them as single braces.
# NOTE(review): presumably this must match the prompt the adapter was
# fine-tuned on — confirm before changing wording or whitespace.
prompt_template = """
Your task is a Named Entity Recognition (NER) task. Predict the category of
each entity, then place the entity into the list associated with the 
category in an output JSON payload. Below is an example:

Input: EU rejects German call to boycott British lamb . Output: {{"person":
[], "organization": ["EU"], "location": [], "miscellaneous": ["German",
"British"]}}

Now, complete the task.

Input: {input} Output:"""

In [None]:
# Define output schema using a Pydantic model: one list of entity strings per
# entity category. (Comments rather than a class docstring on purpose —
# pydantic copies a class docstring into the JSON schema's "description".)
class Output(BaseModel):
    person: List[str]
    organization: List[str]
    location: List[str]
    miscellaneous: List[str]

# JSON schema sent to LoRAX as the `response_format` schema for constrained
# generation. Pydantic v2 renamed `.schema()` to `.model_json_schema()` and
# deprecated the old name (same output); fall back for pydantic v1.
if hasattr(Output, "model_json_schema"):
    response_schema = Output.model_json_schema()
else:
    response_schema = Output.schema()

# Fine-tuned adapter to benchmark. Set the ADAPTER_ID environment variable
# instead of editing the placeholder in the notebook.
adapter_id = os.environ.get("ADAPTER_ID", "<ADAPTER_ID>")

In [None]:
# Load the dataset and keep only the rows whose `split` code is 2; those rows
# form the test set that every benchmark below is run and scored against.
dset = pd.read_csv("conllpp.csv")
test_set = dset.loc[dset["split"].eq(2)]

In [None]:
# Instantiate LoRAX client. Defaults to a server on localhost:8080; set the
# LORAX_URL environment variable to point at a different deployment.
client = Client(os.environ.get("LORAX_URL", "http://127.0.0.1:8080"))

def generate(input, adapter_id=None, schema=None, max_new_tokens=128, adapter_source="s3"):
    """Run one NER prompt through the LoRAX server.

    Args:
        input: Sentence to tag; substituted into ``prompt_template``.
            (Shadows the builtin, but kept for keyword-call compatibility.)
        adapter_id: LoRA adapter to apply, or None for the base model.
        schema: JSON schema dict for constrained generation, or None for
            unconstrained (raw) text.
        max_new_tokens: Maximum number of tokens to generate.
        adapter_source: Source LoRAX loads ``adapter_id`` from.

    Returns:
        The response object from ``client.generate``; the generated text is
        available as ``.generated_text``.
    """
    # Only send a response_format when output should be constrained to a schema.
    response_format = None if schema is None else {"type": "json_object", "schema": schema}

    return client.generate(
        prompt_template.format(input=input),
        adapter_id=adapter_id,
        adapter_source=adapter_source,
        max_new_tokens=max_new_tokens,
        response_format=response_format,
        details=False,
    )

def benchmark(adapter_id=None, schema=None):
    """Generate a prediction for every ``test_set`` row and save them to CSV.

    Writes ``benchmarks/{base|finetuned}_{raw|constrained}.csv`` with a single
    ``output`` column and exactly one row per ``test_set`` row, in order, so
    CSV row i corresponds to ``test_set.iloc[i]``. A failed generation is
    stored as the placeholder ``"<generation failed>"`` (not valid JSON)
    instead of being dropped, which would shift every later prediction
    against the wrong target.

    Args:
        adapter_id: LoRA adapter to benchmark, or None for the base model.
        schema: JSON schema for constrained generation, or None for raw output.
    """
    model_label = "base" if adapter_id is None else "finetuned"
    schema_label = "raw" if schema is None else "constrained"
    output_file = f"benchmarks/{model_label}_{schema_label}.csv"
    failed_placeholder = "<generation failed>"

    # Create the output directory up front so a fresh checkout doesn't fail
    # only after the slow generation loop has finished.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    print(f"=====BENCHMARK:{model_label}+{schema_label}=====\n")

    final_outputs = []
    failed_rows = []
    for idx, (_, row) in tqdm(enumerate(test_set.iterrows()), total=len(test_set), ncols=100):
        try:
            result = generate(row.input, adapter_id=adapter_id, schema=schema)
            final_outputs.append(result.generated_text)
        except Exception as e:
            # Best-effort: one failed request shouldn't abort the run, but
            # report why (and let KeyboardInterrupt through), and keep a
            # placeholder so output rows stay aligned with test_set rows.
            tqdm.write(f"Failed generation for row {idx}: {e!r}")
            failed_rows.append(idx)
            final_outputs.append(failed_placeholder)

    print("\n========DONE========")
    print(f"Success: {len(final_outputs) - len(failed_rows)} Failed: {len(failed_rows)}\n")

    print("Writing outputs to file...")
    # Building from a dict keeps the "output" header even with zero rows.
    pd.DataFrame({"output": final_outputs}).to_csv(output_file, index=False)


# Run all four benchmarking combinations: {base model, fine-tuned adapter}
# crossed with {raw, schema-constrained} generation, base/raw first.
for bench_adapter in (None, adapter_id):
    for bench_schema in (None, response_schema):
        benchmark(adapter_id=bench_adapter, schema=bench_schema)

In [None]:
# Exact set of top-level keys a prediction must have (the fields of `Output`).
KEYS = {'person', 'organization', 'location', 'miscellaneous'}
# (is_valid, reason) pair returned by validate_json. The typename matches the
# bound name: pickle looks classes up by __name__, so the previous typename
# 'Result' made instances unpicklable and gave a misleading repr.
ValidateJsonResult = namedtuple('ValidateJsonResult', ['is_valid', 'reason'])

def validate_json(json_str):
    """Check that a model output is a JSON object usable for scoring.

    Valid means: parses as JSON, is an object, has exactly the keys in
    ``KEYS``, and every value is a list (as scoring iterates the lists).

    Args:
        json_str: Raw model output. Non-string input (e.g. NaN from an empty
            CSV cell) is reported as invalid rather than raising.

    Returns:
        ValidateJsonResult(is_valid, reason): ``reason`` is '' when valid,
        otherwise a human-readable explanation.
    """
    try:
        d = json.loads(json_str)
    except (TypeError, ValueError) as e:
        # ValueError covers json.JSONDecodeError; TypeError covers non-str input.
        return ValidateJsonResult(False, f'invalid json (original_msg: {e}). got {json_str}')

    # Valid JSON that isn't an object (list, number, string, null) has no keys.
    if not isinstance(d, dict):
        return ValidateJsonResult(False, f'expected a json object, got {type(d).__name__}: {json_str}')

    d_keys = set(d)
    if d_keys != KEYS:
        # Sorted so the message is deterministic; covers missing and extra keys.
        return ValidateJsonResult(False, f'wrong json keys. expected: {sorted(KEYS)}, got: {sorted(d_keys)}')

    # Without this, e.g. {"person": "EU", ...} passed validation and then
    # crashed the scoring step instead of being counted as invalid.
    non_list_keys = sorted(k for k, v in d.items() if not isinstance(v, list))
    if non_list_keys:
        return ValidateJsonResult(False, f'expected list values, got non-lists for keys: {non_list_keys}')

    return ValidateJsonResult(True, '')

def flatten_dict(d):
    """Flatten a category -> entity-list dict into tagged entity strings.

    Each entity becomes one ``__CATEGORY__::entity`` string, so the same
    surface string under two categories stays distinct, and repeated
    entities are kept (for multiset counting).

    Args:
        d: Mapping of category name to a list of entities.

    Returns:
        List of tagged strings, one per entity.

    Raises:
        TypeError: if a category's value is not a list.
    """
    res = []
    for k, v in d.items():
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(v, list):
            raise TypeError(f'expected a list of entities for {k!r}, got {type(v).__name__}')
        # One item per entity (previously one item per whole list), so a
        # partially correct category earns partial credit and list order
        # does not matter.
        res.extend(f'__{k.upper()}__::{e}' for e in v)
    return res

def compute_jaccard(target_dict, pred_dict):
    """Multiset Jaccard similarity between target and predicted entities.

    Both dicts are flattened with ``flatten_dict`` and counted as multisets.
    The numerator sums, over every tagged entity, min(target count, predicted
    count), and the denominator sums max(target count, predicted count).

    Returns:
        Float in [0, 1]. Returns 1.0 when neither side has any entities: an
        empty prediction for an entity-free target is a perfect match, and
        this avoids a ZeroDivisionError.
    """
    target_count = Counter(flatten_dict(target_dict))
    pred_count = Counter(flatten_dict(pred_dict))

    # Counter `&` keeps per-key minimums and `|` keeps per-key maximums.
    num = sum((target_count & pred_count).values())
    den = sum((target_count | pred_count).values())
    if den == 0:
        return 1.0
    return num / den

def calc_results(benchmark_file, targets=None):
    """Validate and score every prediction in one benchmark CSV.

    Args:
        benchmark_file: CSV with an ``output`` column holding one raw model
            output per target row, in the same order as ``targets``.
        targets: DataFrame whose ``output`` column holds the gold JSON; row i
            is the target for prediction row i. Defaults to ``test_set``.

    Returns:
        (is_valid_data, scores): per-row ``{'is_valid', 'reason'}`` dicts and
        per-row Jaccard scores (0 for invalid predictions).

    Raises:
        ValueError: if the file's row count differs from ``targets``, since
            predictions could then not be matched to targets by position.
    """
    if targets is None:
        targets = test_set

    # Read outputs as raw strings: by default pandas turns empty cells and
    # outputs such as "NA", "null" or "None" into NaN.
    preds_df = pd.read_csv(benchmark_file, dtype={'output': str}, keep_default_na=False)
    # Targets are matched by position; a length mismatch (e.g. dropped failed
    # rows) would otherwise silently score predictions against wrong targets.
    if len(preds_df) != len(targets):
        raise ValueError(
            f'{benchmark_file} has {len(preds_df)} predictions but there are '
            f'{len(targets)} targets; rows cannot be aligned by position'
        )

    is_valid_data = []
    scores = []
    for pred_str, target_str in zip(preds_df['output'], targets['output']):
        validate_json_result = validate_json(pred_str)
        is_valid_data.append({
            'is_valid': validate_json_result.is_valid,
            'reason': validate_json_result.reason,
        })

        score = 0
        if validate_json_result.is_valid:
            score = compute_jaccard(json.loads(target_str), json.loads(pred_str))
        scores.append(score)

    return is_valid_data, scores

# Calculate benchmarking results: one score column per benchmark CSV, plus
# validity/reason columns explaining why an output scored 0.
# (`benchmarks_dir` was previously referenced without being defined.)
benchmarks_dir = "benchmarks"
score_columns = {}
reason_columns = {}

for fname in sorted(os.listdir(benchmarks_dir)):
    benchmark_name, ext = os.path.splitext(fname)
    if ext != ".csv":
        continue

    is_valid_data, scores = calc_results(os.path.join(benchmarks_dir, fname))
    score_columns[f"{benchmark_name}_score"] = scores
    reason_columns[f"{benchmark_name}_valid"] = [v["is_valid"] for v in is_valid_data]
    reason_columns[f"{benchmark_name}_reason"] = [v["reason"] for v in is_valid_data]

# Build each frame once instead of growing an empty frame column by column.
results_df = pd.DataFrame(score_columns)
reasons_df = pd.DataFrame(reason_columns)

# Average Jaccard scores for each combination
results_df.mean(axis=0)