### Evaluation

In [18]:
import logging
from transformers import BitsAndBytesConfig
#Model configs

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

model_name = "mistralai/Mistral-7B-Instruct-v0.3"
checkpoint_dir = "../output/checkpoints"
model_output_dir = "../output/final_adapter"
training_data = "../data/combinations.jsonl"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_quant_type="nf4"
)

In [19]:
! nvidia-smi

Wed Jul 30 23:44:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 576.40                 Driver Version: 576.40         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5070 Ti   WDDM  |   00000000:01:00.0 Off |                  N/A |
|  0%   38C    P8              8W /  300W |    8767MiB /  16303MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [20]:
from datasets import load_dataset, Dataset, DatasetDict


# load the dataset and split it into train, validation and test sets
dataset = load_dataset("json", data_files=training_data, split='train')
shuffled_dataset = dataset.shuffle(seed=42)
train_temp_split = shuffled_dataset.train_test_split(test_size=0.3) #30% for validation and test
temp_dataset = train_temp_split['test']
validation_test_split = temp_dataset.train_test_split(test_size=1/3)# 10% for validation and 20% for test
split_datasets = DatasetDict({
    'train': train_temp_split['train'],
    'validation': validation_test_split['train'],
    'test': validation_test_split['test']
})

In [21]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_name, 
                                                    quantization_config=quantization_config , 
                                                    device_map="auto")
    
model = PeftModel.from_pretrained(base_model, model_output_dir)
tokenizer = AutoTokenizer.from_pretrained(model_output_dir)

INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
import json
import re
import numpy as np
import torch
from jsonschema import validate, ValidationError
from transformers import StoppingCriteria, StoppingCriteriaList


# function to extact JSON from a string
def find_and_parse_json(text: str) -> dict:

    try:
        # Find the first opening curly brace
        first_brace_index = text.find('{')
        if first_brace_index == -1:
            return None # No JSON object found

        # Start searching from the first brace
        brace_level = 1
        for i, char in enumerate(text[first_brace_index + 1:]):
            if char == '{':
                brace_level += 1
            elif char == '}':
                brace_level -= 1
            
            if brace_level == 0:
                # We found the matching closing brace
                last_brace_index = first_brace_index + i + 1
                json_string = text[first_brace_index : last_brace_index + 1]
                
                # Now, use the built-in json library to parse it
                return json.loads(json_string)
        
        return None # No complete JSON object found
    except (json.JSONDecodeError, IndexError):
        # Handle cases where the substring is not valid JSON or string is malformed
        return None


# Validate a JSON string against a schema
def is_valid_json_schema(json_generated, schema):
    try:
        validate(instance=json_generated, schema=schema)
        return True
    except (json.JSONDecodeError, ValidationError):
        return False
    
# Compare two JSON strings at field level and match field values
def compare_json_field_values(json_ref, json_gen):

    
    if not isinstance(json_ref, dict) or not isinstance(json_gen, dict):
        return 0.0
    
    score = 10
    # json_gen has all fields of json_ref, if not reduce the score proportional to no of fields missing
    if set(json_gen.keys()).issubset(set(json_ref.keys())):
        for key in json_ref.keys():
            if json_gen.get(key) is None:
                reduce_score = 3 / len(json_ref.keys())
                logger.debug(f"Key {key} is missing in json_gen, reducing score by {reduce_score}")
                score -= reduce_score
    else:
        reduce_score = (len(json_ref.keys()) - len(json_gen.keys())) * (3 / len(json_ref.keys()))
        logger.debug(f"Keys in json_ref not present in json_gen, reducing score by {reduce_score}")
        score -= reduce_score

    # Compare values of each field if they dont match reduce the score proportional to no of fields not matching in a scale of  0 to 5
    count_mismatch = 0
    for key in json_ref.keys():
        if key in json_gen:
            if json_ref[key] != json_gen[key]:
                count_mismatch += 1
        else:
            count_mismatch += 1
    count_mismatch/= len(json_ref.keys())
    logger.debug(f"Count of mismatched fields: {count_mismatch*2}")    
    score -= count_mismatch * 2

    # domainn specific checks
    if 'nm1_segment' in json_ref.keys():
        # Check if NM1 segment is present in both JSONs
        logger.debug(f"{json_ref['nm1_segment']}  {json_gen.get('nm1_segment')}")
        if 'nm1_segment' in json_gen.keys():
            #NM1*PR*2*Palmer*Darlene*M*Mr.*DVM*PI*SqLFzolMNdaVXE~
            pattern = r'^NM1(?:\*[^*]*){9}~$'
            value = json_gen[key]
            if not bool(re.match(pattern, value)):
                score -= 2.5
                logger.debug(f"NM1 segment {value} does not match the expected pattern, reducing score by 2.5")
            # check if all values in the json_gen are present in the value except for the nm1_segment
            for key in json_gen.keys():
                if key != 'nm1_segment' and json_gen[key] not in value:
                    score -= 2.5
                    logger.debug(f"Value {json_gen[key]} not found in NM1 segment {value}, reducing score by 2.5")
                    break
        else:
            score -= 5
            logger.debug(f"NM1 segment not found in json_gen, reducing score by 5")


    # If the score is less than 0, set it to 0
    if score < 0:
        score = 0.0

    logger.debug(f"Final score: {score}")

    return score

def get_result(valid_json, valid_json_schema, comparison_results):
    return json.dumps({
        "valid_json": np.mean(valid_json) * 100,
        "valid_json_schema": np.mean(valid_json_schema) * 100,
        "domain_field_match": np.mean(comparison_results) * 10
    })


def test_model(model, tokenizer, test_dataset):

    torch.cuda.empty_cache()

    valid_json = []
    valid_json_schema = []
    comparison_results = []

    stop_words = ["[INST]"]
    stop_word_ids = [tokenizer(stop_word, add_special_tokens=False).input_ids[0] for stop_word in stop_words]
    stop_word_ids.append(tokenizer.eos_token_id)  # Add EOS token ID to the stop words
    print(f"Tokens to stop on: {stop_words}")
    print(f"Corresponding Token IDs: {stop_word_ids}")
        
    for i, row in enumerate(test_dataset):

        print(f"Processing row {i+1}/{len(test_dataset)}")

        prompt = f"<s>[INST] {row['messages'][0]['content']} [/INST]"
        # print(f"Prompt: \n {prompt} \n ---------------------------------")
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=1000,
                temperature=0.7,
                do_sample=True,                
                top_p=0.9,
                eos_token_id=stop_word_ids
            )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=False)

        # print(f"Generated text: \n {generated_text} \n ---------------------------------")
        
        if("[/INST]" in generated_text):
            generated_json = find_and_parse_json(generated_text.split("[/INST]")[-1].strip())
            gold_json = find_and_parse_json(row['messages'][1]['content'])
            schema = find_and_parse_json(generated_text.split("[/INST]")[-2].strip())
            # print(f"Generated JSON: {generated_json} \n Goled JSON: {gold_json} \n Schema: {schema}")
        
        # is Json valid?
        if generated_json is not None:
            valid_json.append(True)
        else:
            valid_json.append(False)

        # is JSON valid against schema?
        if schema is not None:
            if generated_json is not None and is_valid_json_schema(generated_json, schema):
                valid_json_schema.append(True)
            else:
                valid_json_schema.append(False)
        
        # compare JSON field values
        if gold_json is not None:
            if generated_json is not None:
                comparison_score = compare_json_field_values(gold_json, generated_json)
            else:
                comparison_score = 0.0
            comparison_results.append(comparison_score)

        logger.info(get_result(valid_json, valid_json_schema, comparison_results))

    
    return get_result(valid_json, valid_json_schema, comparison_results)
    

In [23]:
logger.setLevel(logging.DEBUG) 
test_model(model, tokenizer, split_datasets["test"])

Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Tokens to stop on: ['[INST]']
Corresponding Token IDs: [3, 2]
Processing row 1/10


DEBUG - Keys in json_ref not present in json_gen, reducing score by -0.375
DEBUG - Count of mismatched fields: 2.0
DEBUG - NM1*IL*1*Willis*Cheryl***Jr.*XX*DFtlAnEqvtROQMDLhkjELWnql~  NM1*03*2*Garcia*Jennifer*D*Mr.**PI*JJcQJHXhjtPWwUHhJJUjWbXnUJMbjDm~
DEBUG - NM1 segment NM1*03*2*Garcia*Jennifer*D*Mr.**PI*JJcQJHXhjtPWwUHhJJUjWbXnUJMbjDm~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 5.875
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 58.75}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 2/10


DEBUG - Count of mismatched fields: 1.4
DEBUG - NM1*IL*1*Moss*Sean*B*Miss*PhD*PI*OCUrWoTpVGZHDEkEZgfTidcqAksi~  NM1*IL*1*Harris*Jeremy*R*Mrs.*Jr.*PI*KIUQKxTWcZDgBpqmEGdqwVyHrUvLKbW~
DEBUG - Final score: 8.6
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 72.375}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 3/10


DEBUG - Count of mismatched fields: 0.0
DEBUG - NM1*PR*2*Davis*Edward*V**DDS*PI*PCJkdlcPdwLVynEsUchRDDIyre~  NM1*PR*2*Davis*Edward*V**DDS*PI*PCJkdlcPdwLVynEsUchRDDIyre~
DEBUG - NM1 segment NM1*PR*2*Davis*Edward*V**DDS*PI*PCJkdlcPdwLVynEsUchRDDIyre~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 7.5
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.25}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 4/10


DEBUG - Count of mismatched fields: 0.0
DEBUG - NM1*IL*1*Stokes*Alejandro*U**DDS**~  NM1*IL*1*Stokes*Alejandro*U**DDS**~
DEBUG - NM1 segment NM1*IL*1*Stokes*Alejandro*U**DDS**~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 7.5
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.6875}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 5/10


DEBUG - Count of mismatched fields: 1.4
DEBUG - NM1*IL*2*Valencia*Brian*Y*Mx.*MD*MI*BVUAfvezYTDvNARKXtuGxTGTsVmeWIwSqftz~  NM1*IL*2*Rivera*Matthew*K*Mx.*DDS*PI*jNHrIWwXWYDGYzHwFqBtDmQv~
DEBUG - Final score: 8.6
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 76.15}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 6/10


DEBUG - Count of mismatched fields: 1.3333333333333333
DEBUG - NM1*03*1*Lambert****MD*XX*~  NM1*03*1*Caldwell**DDS*SV*~
DEBUG - NM1 segment NM1*03*1*Caldwell**DDS*SV*~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 6.166666666666666
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.73611111111111}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 7/10


DEBUG - Count of mismatched fields: 0.0
DEBUG - NM1*1P*1*Barrett**J*Mx.*MD**BTopyVU~  NM1*1P*1*Barrett**J*Mx.*MD**BTopyVU~
DEBUG - NM1 segment NM1*1P*1*Barrett**J*Mx.*MD**BTopyVU~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 7.5
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.91666666666667}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 8/10


DEBUG - Keys in json_ref not present in json_gen, reducing score by 0.0
DEBUG - Count of mismatched fields: 1.7777777777777777
DEBUG - NM1*PR*1*Evans*Jeffery*F**II*SV*DUvFMrOCGkdtxsAYMJhcWGSdaiunCBaZdaxEvcGuhJLMUrSm~  NM1*03*1*Murphy*Joseph*L*Ms.*DDS**GKOjcFwCeMzQRKgLQTfzUFQrPXKtQxYKwjUWwZHUHvzCgWxK~
DEBUG - NM1 segment NM1*03*1*Murphy*Joseph*L*Ms.*DDS**GKOjcFwCeMzQRKgLQTfzUFQrPXKtQxYKwjUWwZHUHvzCgWxK~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 5.722222222222221
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 71.82986111111111}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 9/10


DEBUG - Count of mismatched fields: 0.0
DEBUG - NM1*PR*2*Tucker*Michael*M*Mx.*MD*PI*~  NM1*PR*2*Tucker*Michael*M*Mx.*MD*PI*~
DEBUG - NM1 segment NM1*PR*2*Tucker*Michael*M*Mx.*MD*PI*~ does not match the expected pattern, reducing score by 2.5
DEBUG - Final score: 7.5
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 72.1820987654321}
Setting `pad_token_id` to `eos_token_id`:3 for open-end generation.


Processing row 10/10


DEBUG - Keys in json_ref not present in json_gen, reducing score by -0.3333333333333333
DEBUG - Count of mismatched fields: 1.7777777777777777
DEBUG - NM1*PR*2*Howell*Andrea*A*Mx.**PI*rhoxepNgGoHi~  NM1*PR*1*Hill*Jennifer*F*Mr.*Jr.*SV*NrFmYOYMnLUHlzcFUmJbIgfqoJZKtQpJwzRjrJJtXgmUJpMfO~
DEBUG - Final score: 8.555555555555557
INFO - {"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.51944444444445}


'{"valid_json": 100.0, "valid_json_schema": 100.0, "domain_field_match": 73.51944444444445}'