### Distilling step-by-step fine-tuning approach — trying enhanced rationales with specific reasoning for date conversion

In [1]:
!sudo pip install -q transformers --upgrade
!sudo pip install -q peft

In [2]:
import transformers
# Display the installed transformers version so runs can be compared.
transformers.__version__

'4.35.0'

In [3]:
import ast
import os
import torch
from datasets import load_dataset
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
import pandas as pd
import torch

In [4]:
# Base model ID on the Hugging Face Hub. In this notebook it is only used to load the
# tokenizer; the fine-tuned weights are loaded from the local checkpoint in `new_model_name`.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
# Schema context pasted verbatim into the prompt via `query_template_v1.format(...)`.
# It is a Python-style literal that mixes single and double quotes, so it is NOT valid
# JSON; it is only ever used as text. Any edit here changes the model's prompt.
# NOTE(review): "growth"/"grown" are listed as other names for both
# contribution_to_growth and Growth Rate, so those words are ambiguous to the model.
# NOTE(review): "Healdsburg" and "Brownsville" have parent "State" although a "City"
# dimension exists — confirm this matches the training data before changing it.
context2 = """{
    "MEASURE": [{'ENTITY': 'TRx', 'other names': ['total_prescriptions', 'overall_rx', 'complete_rx_count', 'full_prescription_volume', 'entire_rx_number']},
                {'ENTITY': 'NRx', 'other names': ['new_prescriptions', 'fresh_rx', 'recent_rx_count', 'initial_prescription_volume', 'first_rx_number']},
                {'ENTITY': 'NBRx', 'other names': ['new_to_brand_prescriptions', 'fresh_brand_rx', 'recent_brand_rx_count', 'initial_brand_prescription_volume', 'first_brand_rx_number']}],
    "DIMENSION": [{'ENTITY': 'Address', 'other names': ['location', 'street', 'residence', 'place', 'site']},
                  {'ENTITY': 'State', 'other names': ['province', 'region', 'territory', 'district', 'area']},
                  {'ENTITY': 'City', 'other names': ['town', 'municipality', 'urban_area', 'locality', 'metropolis']},
                  {'ENTITY': 'Zip_Code', 'other names': ['postal_code', 'zipcode', 'post_code', 'mailing_code', 'zip']},
                  {'ENTITY': 'Physician Name', 'other names': ['doctor_name', 'medical_practitioner', 'healthcare_provider', 'doc_fullname', 'practitioner']},
                  {'ENTITY': 'Specialty', 'other names': ['expertise', 'medical_field', 'healthcare_area', 'practice_focus', 'specialization']},
                  {'ENTITY': 'Specialty Group', 'other names': ['expertise_group', 'medical_field_category', 'healthcare_area_group', 'practice_focus_group', 'specialization_group']}],
    "FILTER": [{"ENTITY": "Healdsburg", "other names": [], "parent": "State"},
               {"ENTITY": "Brownsville", "other names": [], "parent": "State"},
               {"ENTITY": "Oncology", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Pulmonary Disease", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Cardiovascular Diseases", "other names": [], "parent": "Specialty"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Data Date", "other names": ["data date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""

In [6]:
# Date reference for the prompt. The dict is formatted into the prompt through its repr,
# so key order and the exact date strings are part of the model input.
# NOTE(review): confirm whether these are MM/DD/YYYY or DD/MM/YYYY.
date_input = dict(
    start_date="01/01/2019",
    end_date="10/11/2023",
)

In [7]:
# Sanity check: confirm a CUDA GPU is visible before loading the 7B model.
torch.cuda.is_available()

True

In [8]:
# Load the Mistral tokenizer (slow implementation, use_fast=False).
# The EOS token is reused as the padding token, and padding is applied on the right.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
    trust_remote_code=True,
    # add_eos_token=True,
)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [9]:
from peft import PeftModel, PeftConfig

In [10]:
# Local fine-tuned PEFT checkpoint loaded below (absolute path on the training host; adjust per machine).
new_model_name = "/data/mistral/query-to-mql/exp-9/nov-01/checkpoint-4000"

In [11]:
from peft import AutoPeftModelForCausalLM

# Load the PEFT checkpoint in bfloat16; device_map="auto" lets the library place the
# weights on the available devices.
model = AutoPeftModelForCausalLM.from_pretrained(new_model_name, device_map="auto", torch_dtype=torch.bfloat16)
# Merge the adapter weights into the base model so generation runs on a plain model.
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
#model.to('cuda')

In [14]:
# Prompt template. It must end with the "[MQL]\n" marker: predict_template_query_v1
# splits the decoded model output on that marker to recover the generated MQL.
# NOTE(review): presumably this mirrors the fine-tuning prompt format — keep it byte-identical.
query_template_v1 = """Given the context : {context} and date reference: {date_input}, the query: {user_query}, is converted into below shown structured output.
[MQL]
"""

In [15]:
def predict_template_query_v1(user_query):
    inp = query_template_v1.format(context=context2,
                                   user_query=user_query,
                                  date_input=date_input)
    _inputs = tokenizer.encode(inp, return_tensors="pt")
    outputs = model.generate(input_ids=_inputs.to('cuda'), max_length= 1600, pad_token_id=tokenizer.eos_token_id)
    output = tokenizer.decode(outputs[0])
    output_new = output.split('[MQL]\n')[1]
    return output_new.split('\n[/MQL]')[0], output
#     return output

In [17]:
# Test queries for experiment 9 (part 2); only the 'Query' column is used below.
test_df = pd.read_csv('/data/mistral/query-to-mql/exp-9/testing-data-exp-9-part2.csv')

In [18]:
# Preview the first two test rows.
test_df.head(2)

Unnamed: 0,Query
0,what is growth rate of new prescription in car...
1,what is correlation of new prescription for pu...


In [20]:
# Materialize the 'Query' column as a plain list to iterate over in the batch loop.
user_query_list = list(test_df['Query'])

In [21]:
import csv

In [22]:
from tqdm import tqdm

In [23]:
%%time 
data_fin = []
for user_query in tqdm(user_query_list):
    output, raw = predict_template_query_v1(user_query=user_query)
    steps = 'Step 1:' +raw.split('\nStep 1:')[1]
    data_fin.append([user_query,eval(output), steps])

with open('prediction_on_pharma_context-02.csv', 'a', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Query", "Intermediate MQL", "Reasoning"])

# Write data iteratively
    for row in data_fin[0:]:
        csvwriter.writerow(row)

100%|██████████| 24/24 [17:20<00:00, 43.35s/it]

CPU times: user 16min 50s, sys: 30.5 s, total: 17min 20s
Wall time: 17min 20s





In [16]:
%%time
user_query = 'what are total prescription in brownsville'
print('user query: ', user_query)
print('-'*100)
output, raw = predict_template_query_v1(user_query=user_query)
print(eval(output))
print('-'*100)
print('Step 1:' +raw.split('\nStep 1:')[1])

user query:  what are total prescription in brownsville
----------------------------------------------------------------------------------------------------


  next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)


{'MEASURE': {'total_prescriptions': [{'ENTITY': 'TRx', 'MEASURE CONSTRAINT': [{'COMPARISON VALUE': '', 'COMPARSION OPERATOR': ''}]}]}, 'FILTER': {'brownsville': [{'ENTITY': 'Brownsville', 'PARENT': 'State'}]}}
----------------------------------------------------------------------------------------------------
Step 1: Identify the components in the query
- The query asks for "total prescriptions" in "brownsville".

Step 2: Match the components to the context
- "total prescriptions" can be matched to the "TRx" entity in the MEASURE context.
- "brownsville" can be matched to the "Brownsville" entity in the FILTER context.

Step 3: Convert the query into a structured output
- Add the "TRx" entity from the MEASURE context to the structured output.
- Add the "Brownsville" entity from the FILTER context to the structured output.

Step 4: Check for date components
- The query does not have any date components, so the date reference is not utilized in this case.

Step 5: Review and validate the

In [7]:
# context1 = """{
#     "MEASURE": [{"ENTITY": "Insurance covergae", "other names": ["insurance amount", "total insurance coverage", "coverage"]},
#                 {"ENTITY": "Hospital bill", "other names": ["bill", "hospital expenses", "expenses"]},
#                 {"ENTITY": "Count", "other names": ["quantity", "counts"]}],
#     "DIMENSION": [{"ENTITY": "Disease", "other names": ["disease", "Diseases", "health issues"]},
#                   {"ENTITY": "State", "other names": ["segment", "segments", "units", "divisions"]},
#                   {"ENTITY": "Insurer", "other names": ["insurer", "insurance provider"]}],
#     "FILTER": [{"ENTITY": "Covid", "other names": ["covid-19", "covid19","Covid 19"], "parent": "Disease"},
#                {"ENTITY": "Cancer", "other names": ["cancers", "cancer", "tumour"], "parent": "Disease"},
#                {"ENTITY": "Delhi", "other names": ["New Delhi", "delhi"], "parent": "State"},
#                {"ENTITY": "Maharashtra", "other names": ["corporates", "corporate"], "parent": "State"},
#                {"ENTITY": "HDFC ergo", "other names": ["hdfc","HDFC","HDFC health insurance","hdfc insurance], "parent": "Insurer"},
#                {"ENTITY": "Aditya Birla", "other names": ["aditya birla health insurance","aditya birla insurance"], "parent": "Insurer"}],
#     "DERIVED MEASURE": [{"ENTITY": "Ratio",
#              "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
#             {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
#             {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
#             {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
#             {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
#             {"ENTITY": "correlation",
#              "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
#                              "correlation", "correlations", "relate", "related", "relation", "relations",
#                              "relationship",
#                              "relationships"]}
#             ],
#     "DATE VARIABLE": [{"ENTITY": "Admit Date", "other names": ["admit date", "date", "trend", "time", "when", "mom", "yoy"]}]
#     }"""

In [14]:
def inference(user_query):
    """Run the fine-tuned model on one query and parse its answer.

    Args:
        user_query: natural-language question to convert.

    Returns:
        (mql, steps): ``mql`` is the MQL text parsed as a Python literal (a dict in
        the sample output above); ``steps`` is the model's reasoning, starting at
        "Step 1:".

    Raises:
        ValueError, SyntaxError: if the MQL text is not a valid Python literal.
        ValueError: if the model output has no "Step 1:" reasoning section.
    """
    output, raw = predict_template_query_v1(user_query=user_query)
    # literal_eval only accepts Python literals; eval on model output could run arbitrary code.
    mql = ast.literal_eval(output)
    step_parts = raw.split('\nStep 1:')
    if len(step_parts) < 2:
        raise ValueError(f"model output for {user_query!r} has no 'Step 1:' reasoning section")
    steps = 'Step 1:' + step_parts[1]
    return mql, steps

In [None]:
%%time
# NOTE(review): "sales" is not an entity in context2 (its MEASUREs are TRx, NRx, NBRx);
# this query looks like it targets the retail context commented out below — confirm intent.
user_query = 'why sales changed in last 2 weeks of Nov 2021'
inference(user_query)

In [5]:
# context = """{
#     "MEASURE": [{"ENTITY": "Discount", "other names": ["discount", "discount rate", "discount value", "deduction"]},
#                 {"ENTITY": "Purchase Vol", "other names": ["purchase", "purchase value", "purchase model"]},
#                 {"ENTITY": "Quantity", "other names": ["quantity", "volume"]},
#                 {"ENTITY": "Sales", "other names": ["sales", "sale"]}],
#     "DIMENSION": [{"ENTITY": "Sub-Category", "other names": ["sub-category", "sub category", "categories", "section"]},
#                   {"ENTITY": "Segment", "other names": ["segment", "segments", "units", "divisions"]},
#                   {"ENTITY": "Parts", "other names": ["parts", "part", "section", "divisions"]},
#                   {"ENTITY": "Country", "other names": ["country", "countries"]}],
#     "FILTER": [{"ENTITY": "Consumer", "other names": ["consumers", "consumer"], "parent": "Segment"},
#                {"ENTITY": "Phone", "other names": ["phone", "phones", "mobile phones"], "parent": "Sub-Category"},
#                {"ENTITY": "Binder", "other names": ["binders", "binder"], "parent": "Sub-Category"},
#                {"ENTITY": "Corporate", "other names": ["corporates", "corporate"], "parent": "Segment"},
#                {"ENTITY": "India", "other names": ["india"], "parent": "Country"},
#                {"ENTITY": "Dubai", "other names": ["dubai"], "parent": "Country"}],
#     "DERIVED MEASURE": [{"ENTITY": "Ratio",
#              "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
#             {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
#             {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
#             {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
#             {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
#             {"ENTITY": "correlation",
#              "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
#                              "correlation", "correlations", "relate", "related", "relation", "relations",
#                              "relationship",
#                              "relationships"]}
#             ],
#     "DATE VARIABLE": [{"ENTITY": "Order Date", "other names": ["order date", "date", "trend", "time", "when", "mom", "yoy"]}]
#     }"""