### Evaluation

In [1]:
import logging
from transformers import BitsAndBytesConfig
#Model configs

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
training_data = "../data/combinations.jsonl"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_quant_type="nf4"
)

In [2]:
! nvidia-smi

Tue Aug  5 09:26:09 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 576.40                 Driver Version: 576.40         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5070 Ti   WDDM  |   00000000:01:00.0 Off |                  N/A |
|  0%   37C    P8              7W /  300W |       0MiB /  16303MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
from datasets import load_dataset, Dataset, DatasetDict


# load the dataset and split it into train, validation and test sets
dataset = load_dataset("json", data_files=training_data, split='train')
shuffled_dataset = dataset.shuffle(seed=42)
train_temp_split = shuffled_dataset.train_test_split(test_size=0.3) #30% for validation and test
temp_dataset = train_temp_split['test']
validation_test_split = temp_dataset.train_test_split(test_size=1/3)# 10% for validation and 20% for test
split_datasets = DatasetDict({
    'train': train_temp_split['train'],
    'validation': validation_test_split['train'],
    'test': validation_test_split['test']
})

INFO - PyTorch version 2.7.0+cu128 available.


In [4]:
import os
import mlflow

def get_last_run_id():
    runs = mlflow.search_runs(order_by=["start_time desc"], max_results=1)
    if not runs.empty:
        return runs.iloc[0].run_id
    else:
        logger.warning("No runs found in MLflow.")
        return None

def get_run_artifact_uri(last_run_id):
    if last_run_id:
        run_info = mlflow.get_run(last_run_id)
        artifact_path = run_info.info.artifact_uri
        if artifact_path.startswith("file://"):
            artifact_path = artifact_path.replace("file://", "")
        elif artifact_path.startswith("file:"):
            artifact_path = artifact_path.replace("file:", "")
        if "/training/.." in artifact_path:
            artifact_path = artifact_path.replace("/training/..", "")
        artifact_path = artifact_path+ "/adapters"
        folders = [f for f in os.listdir(artifact_path) if os.path.isdir(os.path.join(artifact_path, f))]
        logger.info(f"Models available: {', '.join(folders)}")
        if folders:
            last_updated_folder = max(folders, key=lambda f: os.path.getmtime(os.path.join(artifact_path, f)))
            logger.info(f"Last updated model: {last_updated_folder}")
            return os.path.join(artifact_path, last_updated_folder)
        else:
            logger.warning("No folders found in the artifact URI.")
        return None
    else:
        logger.warning("No runs found in MLflow.")
        return None
    
mlflow.end_run()

mlflow_uri = "file:../mlruns"
mlflow_experiment_name = "Synthetic-data-generator-fine-tuning"

mlflow.set_tracking_uri(mlflow_uri)
mlflow.set_experiment(mlflow_experiment_name)

run_id = get_last_run_id()
mlflow.start_run(run_id=run_id)
adapter_dir = get_run_artifact_uri(run_id)

print(f"Last run ID: {run_id}")
print(f"Artifact URI: {adapter_dir}")

ValueError: invalid literal for int() with base 10: '1.0'

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_name, 
                                                    quantization_config=quantization_config , 
                                                    device_map="auto")
    
model = PeftModel.from_pretrained(base_model, adapter_dir)
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
import json
import re
import numpy as np
import torch
from jsonschema import validate, ValidationError
from transformers import StoppingCriteria, StoppingCriteriaList


# function to extact JSON from a string
def find_and_parse_json(text: str) -> dict:

    try:
        # Find the first opening curly brace
        first_brace_index = text.find('{')
        if first_brace_index == -1:
            return None # No JSON object found

        # Start searching from the first brace
        brace_level = 1
        for i, char in enumerate(text[first_brace_index + 1:]):
            if char == '{':
                brace_level += 1
            elif char == '}':
                brace_level -= 1
            
            if brace_level == 0:
                # We found the matching closing brace
                last_brace_index = first_brace_index + i + 1
                json_string = text[first_brace_index : last_brace_index + 1]
                
                # Now, use the built-in json library to parse it
                return json.loads(json_string)
        
        return None # No complete JSON object found
    except (json.JSONDecodeError, IndexError):
        # Handle cases where the substring is not valid JSON or string is malformed
        return None


# Validate a JSON string against a schema
def is_valid_json_schema(json_generated, schema):
    try:
        validate(instance=json_generated, schema=schema)
        return True
    except (json.JSONDecodeError, ValidationError):
        return False
    
# Compare two JSON strings at field level and match field values
def compare_json_field_values(json_ref, json_gen):

    
    if not isinstance(json_ref, dict) or not isinstance(json_gen, dict):
        return 0.0
    
    score = 10
    # json_gen has all fields of json_ref, if not reduce the score proportional to no of fields missing
    if set(json_gen.keys()).issubset(set(json_ref.keys())):
        for key in json_ref.keys():
            if json_gen.get(key) is None:
                reduce_score = 3 / len(json_ref.keys())
                logger.debug(f"Key {key} is missing in json_gen, reducing score by {reduce_score}")
                score -= reduce_score
    else:
        reduce_score = (len(json_ref.keys()) - len(json_gen.keys())) * (3 / len(json_ref.keys()))
        logger.debug(f"Keys in json_ref not present in json_gen, reducing score by {reduce_score}")
        score -= reduce_score

    # Compare values of each field if they dont match reduce the score proportional to no of fields not matching in a scale of  0 to 5
    count_mismatch = 0
    for key in json_ref.keys():
        if key in json_gen:
            if json_ref[key] != json_gen[key]:
                count_mismatch += 1
        else:
            count_mismatch += 1
    count_mismatch/= len(json_ref.keys())
    logger.debug(f"Count of mismatched fields: {count_mismatch*2}")    
    score -= count_mismatch * 2

    # domainn specific checks
    if 'nm1_segment' in json_ref.keys():
        # Check if NM1 segment is present in both JSONs
        logger.debug(f"{json_ref['nm1_segment']}  {json_gen.get('nm1_segment')}")
        if 'nm1_segment' in json_gen.keys():
            #NM1*PR*2*Palmer*Darlene*M*Mr.*DVM*PI*SqLFzolMNdaVXE~
            pattern = r'^NM1(?:\*[^*]*){9}~$'
            value = json_gen[key]
            if not bool(re.match(pattern, value)):
                score -= 2.5
                logger.debug(f"NM1 segment {value} does not match the expected pattern, reducing score by 2.5")
            # check if all values in the json_gen are present in the value except for the nm1_segment
            for key in json_gen.keys():
                if key != 'nm1_segment' and json_gen[key] not in value:
                    score -= 2.5
                    logger.debug(f"Value {json_gen[key]} not found in NM1 segment {value}, reducing score by 2.5")
                    break
        else:
            score -= 5
            logger.debug(f"NM1 segment not found in json_gen, reducing score by 5")


    # If the score is less than 0, set it to 0
    if score < 0:
        score = 0.0

    logger.debug(f"Final score: {score}")

    return score

def get_result(valid_json, valid_json_schema, comparison_results):
    return json.dumps({
        "valid_json": np.mean(valid_json) * 100,
        "valid_json_schema": np.mean(valid_json_schema) * 100,
        "domain_field_match": np.mean(comparison_results) * 10
    })


def test_model(model, tokenizer, test_dataset):

    torch.cuda.empty_cache()

    valid_json = []
    valid_json_schema = []
    comparison_results = []

    stop_words = ["[INST]"]
    stop_word_ids = [tokenizer(stop_word, add_special_tokens=False).input_ids[0] for stop_word in stop_words]
    stop_word_ids.append(tokenizer.eos_token_id)  # Add EOS token ID to the stop words
    print(f"Tokens to stop on: {stop_words}")
    print(f"Corresponding Token IDs: {stop_word_ids}")
        
    for i, row in enumerate(test_dataset):

        print(f"Processing row {i+1}/{len(test_dataset)}")

        prompt = f"<s>[INST] {row['messages'][0]['content']} [/INST]"
        # print(f"Prompt: \n {prompt} \n ---------------------------------")
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=1000,
                temperature=0.7,
                do_sample=True,                
                top_p=0.9,
                eos_token_id=stop_word_ids
            )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=False)

        print(f"Generated text: \n {generated_text} \n ---------------------------------")
        
        if("[/INST]" in generated_text):
            generated_json = find_and_parse_json(generated_text.split("[/INST]")[-1].strip())
            gold_json = find_and_parse_json(row['messages'][1]['content'])
            schema = find_and_parse_json(generated_text.split("[/INST]")[-2].strip())
            # print(f"Generated JSON: {generated_json} \n Goled JSON: {gold_json} \n Schema: {schema}")
        
        # is Json valid?
        if generated_json is not None:
            valid_json.append(True)
        else:
            valid_json.append(False)

        # is JSON valid against schema?
        if schema is not None:
            if generated_json is not None and is_valid_json_schema(generated_json, schema):
                valid_json_schema.append(True)
            else:
                valid_json_schema.append(False)
        
        # compare JSON field values
        if gold_json is not None:
            if generated_json is not None:
                comparison_score = compare_json_field_values(gold_json, generated_json)
            else:
                comparison_score = 0.0
            comparison_results.append(comparison_score)

        logger.info(get_result(valid_json, valid_json_schema, comparison_results))

    
    return get_result(valid_json, valid_json_schema, comparison_results)
    

In [None]:
logger.setLevel(logging.DEBUG) 
result = test_model(model, tokenizer, split_datasets["test"])

result = json.loads(result)

epoch = re.search(r'(\d+)$', adapter_dir).group(1)
mlflow.log_metric("syntactic-accuracy", result['valid_json'], step=int(epoch))
mlflow.log_metric("schema-accuracy", result['valid_json_schema'], step=int(epoch))
mlflow.log_metric("domain-field-match", result['domain_field_match'], step=int(epoch))

mlflow.end_run()
print("Evaluation completed and results logged to MLflow.")

Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Tokens to stop on: ['[INST]']
Corresponding Token IDs: [3, 2]
Processing row 1/2


DEBUG - Keys in json_ref not present in json_gen, reducing score by -0.3333333333333333
DEBUG - Count of mismatched fields: 1.7777777777777777
DEBUG - NM1*03*2*Craig*Devon*L*Mr.**SV*YeIeNPmlhuDlJFAyOmPgYMJWYGMRgUGFIUcoEogcYnDgDAzrLhjyNOVTjyGkmIIBuMfcjiuEHQyCr~  NM1*IL*1*Smith*John M*Mr.**MI*123456789~
DEBUG - NM1 segment NM1*IL*1*Smith*John M*Mr.**MI*123456789~ does not match the expected pattern, reducing score by 2.5
DEBUG - Value Jr. not found in NM1 segment NM1*IL*1*Smith*John M*Mr.**MI*123456789~, reducing score by 2.5
DEBUG - Final score: 3.555555555555557
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 35.55555555555557}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Generated text: 
 <s><s>[INST] You are a Json test data generator generating valid json data based on a json schema. 
Given the Schema: {"$schema": "http://json-schema.org/draft-07/schema#", "title": "X12 NM1 Segment", "description": "Schema for X12 NM1 segment, which contains subscriber's or organization name information.", "type": "object", "properties": {"entity_identifier_code": {"type": "string", "description": "Entity Identifier Code (e.g., IL for Insured or 03 for dependent or PR for Payer or 1P for Provider)", "enum": ["IL", "03", "PR", "1P"]}, "entity_type_qualifier": {"type": "string", "description": "Entity Type Qualifier (1 = Person, 2 = Non-Person Entity)", "enum": ["1", "2"]}, "name_last_or_organizationName": {"type": "string", "description": "Name Last or Organization Name", "minLength": 1, "maxLength": 60}, "name_first": {"type": "string", "description": "Name First", "nullable": true, "minLength": 1, "maxLength": 35}, "name_middle": {"type": "string", "description": "N

DEBUG - Keys in json_ref not present in json_gen, reducing score by -0.75
DEBUG - Count of mismatched fields: 2.0
DEBUG - NM1*03*1*Reeves*Bruce*G***XX*LfxtCTLZh~  NM1*PR*2*Johnson*John M*Dr*MD*PI*123456789~
DEBUG - NM1 segment NM1*PR*2*Johnson*John M*Dr*MD*PI*123456789~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 6.25
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 49.027777777777786}


Generated text: 
 <s><s>[INST] You are a sample data generator that creates valid JSON data based on a schema.
Given the Schema: {"$schema": "http://json-schema.org/draft-07/schema#", "title": "X12 NM1 Segment", "description": "Schema for X12 NM1 segment, which contains subscriber's or organization name information.", "type": "object", "properties": {"entity_identifier_code": {"type": "string", "description": "Entity Identifier Code (e.g., IL for Insured or 03 for dependent or PR for Payer or 1P for Provider)", "enum": ["IL", "03", "PR", "1P"]}, "entity_type_qualifier": {"type": "string", "description": "Entity Type Qualifier (1 = Person, 2 = Non-Person Entity)", "enum": ["1", "2"]}, "name_last_or_organizationName": {"type": "string", "description": "Name Last or Organization Name", "minLength": 1, "maxLength": 60}, "name_first": {"type": "string", "description": "Name First", "nullable": true, "minLength": 1, "maxLength": 35}, "name_middle": {"type": "string", "description": "Name Mid

ValueError: invalid literal for int() with base 10: '1.0'