In [2]:
!pip install transformers accelerate jsonformer prettytable

Collecting accelerate
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jsonformer
  Downloading jsonformer-0.12.0-py3-none-any.whl (6.6 kB)
Installing collected packages: jsonformer, accelerate
Successfully installed accelerate-0.25.0 jsonformer-0.12.0


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from jsonformer import Jsonformer

print("Loading model and tokenizer...")
model_name = "databricks/dolly-v2-3b"
model = AutoModelForCausalLM.from_pretrained(model_name, use_cache=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, use_cache=True)
print("Loaded model and tokenizer")

Loading model and tokenizer...


config.json:   0%|          | 0.00/819 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/5.68G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/450 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Loaded model and tokenizer


In [24]:
import json
import time
from tqdm import tqdm
from prettytable import PrettyTable
import numpy as np
import csv


class JSONBenchmark:
    def __init__(self, dataset_file):
        self.dataset = self.load_dataset(dataset_file)

    def load_dataset(self, dataset_file):
        with open(dataset_file, "r") as f:
            dataset = [json.loads(line) for line in f.readlines()]
        return dataset

    def generate_prompt(self, passage, schema):
        user_message = f"""{passage}

From the above passage, extract the following schema:
{schema}

Only output JSON with the allowed types."""
        prompt = f"""<>You only respond in JSON. You do not add text before. You do not add text after. Only JSON.<>[INST] {user_message} [/INST]"""
        return prompt

    def run(self, generate, **kwargs):
        evals = []
        for data in tqdm(self.dataset):
            evaluation = {}

            prompt = self.generate_prompt(data["passage"], data["schema"])
            start_time = time.time()
            result = generate(prompt, **kwargs)[0]["generated_text"].strip()
            time_taken = round(time.time() - start_time, 3)

            evaluation["generation"] = result
            evaluation["time_taken"] = time_taken

            # check if result is valid JSON
            try:
                json_result = json.loads(result)
                evaluation["is_valid"] = True

                # check if result matches schema
                # JSON might have erroneous keys
                schema = data["extracted_data"]
                evaluation["matches_schema"] = json_result == schema
                evaluation["error_type"] = None
            except ValueError:
                evaluation["is_valid"] = False
                evaluation["matches_schema"] = False

                if result[0] != "{":
                    evaluation["error_type"] = "prefix"
                elif result[-1] != "}":
                    evaluation["error_type"] = "suffix"
                else:
                    evaluation["error_type"] = "invalid"

            evals.append(evaluation)

        return evals

    def print(self, results, show_generation=False):
        table = PrettyTable()

        # Define the table columns
        table.field_names = [
            "Valid (✅/❌)",
            "Matches Schema (✅/❌)",
            "Time (s)",
            "Error",
        ]
        if show_generation:
            table.add_column("Generation")

        valid_counter, schema_counter, total_time = 0, 0, 0

        for result in results:
            is_valid = "✅" if result["is_valid"] else "❌"
            matches_schema = "✅" if result["matches_schema"] else "❌"
            error_type = result["error_type"]

            valid_counter += result["is_valid"]
            schema_counter += result["matches_schema"]
            total_time += result["time_taken"]

            row = [is_valid, matches_schema, result["time_taken"], error_type]
            if show_generation:
                row.append(result["generation"])

            table.add_row(row)

        valid_accuracy = valid_counter / len(results)
        schema_accuracy = schema_counter / len(results)
        average_time = round(total_time / len(results), 3)

        table.add_row(["-", "-", "-", "-"])
        table.add_row(
            [
                f"Accuracy: {valid_accuracy}",
                f"Accuracy: {schema_accuracy}",
                f"Average: {average_time}",
                "-",
            ]
        )

        print(table)

In [19]:
class JSONFormerBenchmark(JSONBenchmark):
  def __init__(self, dataset_file):
    super().__init__(dataset_file)
    self.SUPPORTED_FIELD_TYPES = [
        "string",
        "number",
        "boolean",
        "array",
        "object"
    ]

  def convert_schema_to_jsonformer_format(self, schema):
    jsonformer_schema = {}
    jsonformer_schema["type"] = "object"
    jsonformer_schema["properties"] = {}
    for key in schema.keys():
      if type(schema[key]) is dict:
          jsonformer_schema["properties"][key] = self.convert_schema_to_jsonformer_format(schema[key])
      else:
        if schema[key] not in self.SUPPORTED_FIELD_TYPES:
          jsonformer_schema["properties"][key] = {"type" : "number"}
        else:
          jsonformer_schema["properties"][key] = {"type" : schema[key]}

    return jsonformer_schema

  def filter_schemas(self, schema):
    new_schema = {}
    for key in schema.keys():
      if key == "type":
        if schema[key] == "int":
          new_schema[key] = "number"
        elif schema[key] == "str":
           new_schema[key] = "string"
        else:
          new_schema[key] = schema[key]
      else:
        new_schema[key] = self.filter_schemas(schema[key])

    return new_schema

  def has_matching_schema(self, output, target):
    output_keys = output.keys()
    target_keys = output.keys()

    if output_keys != target.keys():
      return False

    else:
      for key in output_keys:
        if type(output[key]) is dict:
          if not self.has_matching_schema(output[key], target[key]):
            return False

    return True

  def generate_prompt(self, passage):
        prompt = f"""{passage}

From the above passage, generate information based on the following schema:"""
        return prompt

  def run(self, **kwargs):
    evals = []
    for data in tqdm(self.dataset):
        evaluation = {}

        prompt = self.generate_prompt(data["passage"])
        #json_schema = self.convert_schema_to_jsonformer_format(data["schema"])
        json_schema = self.filter_schemas(data["schema"])
        start_time = time.time()
        builder = Jsonformer(
            model=model,
            tokenizer=tokenizer,
            json_schema=json_schema,
            prompt=prompt
        )

        result_dict = builder()
        result = str(result_dict).replace("\'", "\"")

        time_taken = round(time.time() - start_time, 3)

        evaluation["generation"] = result
        evaluation["time_taken"] = time_taken

        # check if result is valid JSON
        try:
            json_result = json.loads(result)
            evaluation["is_valid"] = True

            # check if result matches schema
            # JSON might have erroneous keys
            schema = data["schema"]
            jsonformat_result_dict = self.convert_schema_to_jsonformer_format(result_dict)
            evaluation["matches_schema"] = self.has_matching_schema(jsonformat_result_dict, schema)
            evaluation["error_type"] = None
        except ValueError:
            evaluation["is_valid"] = False
            evaluation["matches_schema"] = False

            if result[0] != "{":
                evaluation["error_type"] = "prefix"
            elif result[-1] != "}":
                evaluation["error_type"] = "suffix"
            else:
                evaluation["error_type"] = "invalid"

        evals.append(evaluation)

    print(evals)
    return evals


In [20]:
sampling_params = {
    "do_sample": True,
    "top_k": 10,
    "num_return_sequences": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "max_length": 512,
    "return_full_text": False
}

In [21]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/MyDrive/madlibs_test_v2/madlibs/example-jsons

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
[Errno 2] No such file or directory: 'gdrive/MyDrive/madlibs_test_v2/madlibs/example-jsons'
/content/gdrive/MyDrive/madlibs_test_v2/madlibs/example-jsons


In [22]:
def run_iters(num_iters, eval_harness, out_file):

  all_outputs = []
  for _ in tqdm(range(num_iters)):
    outputs = eval_harness.run(**sampling_params)
    all_outputs.append(outputs)

  with open(out_file, 'w') as f:
    writer = csv.writer(f)
    writer.writerow(list(all_outputs[0][0].keys()))

    for iteration in all_outputs:
      for output in iteration:
        writer.writerow(output.values())
    f.close()

In [23]:
num_iters = 1
out_file = f'jsonformer-dolly-{num_iters}_iters.csv'
eval_harness = JSONFormerBenchmark("jsonbench.jsonl")
run_iters(num_iters, eval_harness, out_file)

UnicodeDecodeError: ignored