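### Distilling step-by-step fine-tuning: testing an enhanced rationale with explicit reasoning for date conversion

Inference notebook: loads a fine-tuned Mistral-7B-Instruct checkpoint and converts natural-language queries into structured MQL output, together with the model's step-by-step rationale.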

In [1]:
!sudo pip install -q transformers --upgrade
!sudo pip install -q peft

In [2]:
import transformers

# Report which transformers release the kernel is actually running
# (the bare last expression is rendered as this cell's output).
getattr(transformers, "__version__")

In [3]:
import os
import torch
from datasets import load_dataset
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
import pandas as pd
import torch

In [4]:
# Hugging Face Hub id of the base model. In this notebook it is only used to load
# the tokenizer; the fine-tuned weights are loaded from the checkpoint directory.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
# Entity catalogue, apparently for an e-commerce dataset: measures, dimensions,
# filter values (with their parent dimension), derived measures and the date
# variable, each listed with its accepted synonyms ("other names").
# NOTE(review): not referenced by any cell shown here — prompts are built from
# `context` below — confirm whether this can be removed.
# NOTE(review): "section" and "divisions" are synonyms of more than one DIMENSION,
# and "growth"/"grown" map to both contribution_to_growth and Growth Rate.
context_ecom = """{
    "MEASURE": [{"ENTITY": "Discount", "other names": ["discount", "discount rate", "discount value", "deduction"]},
                {"ENTITY": "Purchase Vol", "other names": ["purchase", "purchase value", "purchase model"]},
                {"ENTITY": "Quantity", "other names": ["quantity", "volume"]},
                {"ENTITY": "Sales", "other names": ["sales", "sale"]}],
    "DIMENSION": [{"ENTITY": "Sub-Category", "other names": ["sub-category", "sub category", "categories", "section"]},
                  {"ENTITY": "Segment", "other names": ["segment", "segments", "units", "divisions"]},
                  {"ENTITY": "Parts", "other names": ["parts", "part", "section", "divisions"]},
                  {"ENTITY": "Country", "other names": ["country", "countries"]}],
    "FILTER": [{"ENTITY": "Consumer", "other names": ["consumers", "consumer"], "parent": "Segment"},
               {"ENTITY": "Phone", "other names": ["phone", "phones", "mobile phones"], "parent": "Sub-Category"},
               {"ENTITY": "Binder", "other names": ["binders", "binder"], "parent": "Sub-Category"},
               {"ENTITY": "Corporate", "other names": ["corporates", "corporate"], "parent": "Segment"},
               {"ENTITY": "India", "other names": ["india"], "parent": "Country"},
               {"ENTITY": "Dubai", "other names": ["dubai"], "parent": "Country"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Order Date", "other names": ["order date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""
# Entity catalogue (apparently a pharma prescription dataset) interpolated verbatim
# into the prompt as `{context}` of query_template_v1. Because the text is part of
# the model input, presumably it must match what the checkpoint was fine-tuned on —
# edit with care.
# NOTE(review): "Zip_Code" and "Zip Code" are both listed, with identical synonyms.
# NOTE(review): "Healdsburg" and "Brownsville" look like city names but have parent
# "State" — confirm against the training data.
# NOTE(review): several synonyms are shared across entities (e.g. "region",
# "district", "territory", "sector", "area").
context = """{
    "MEASURE": [{"ENTITY": "TRx", "other names": ["total_prescriptions", "overall_rx", "complete_rx_count", "full_prescription_volume", "entire_rx_number"]},
                {"ENTITY": "NRx", "other names": ["new_prescriptions", "fresh_rx", "recent_rx_count", "initial_prescription_volume", "first_rx_number"]},
                {"ENTITY": "NBRx", "other names": ["new_to_brand_prescriptions", "fresh_brand_rx", "recent_brand_rx_count", "initial_brand_prescription_volume", "first_brand_rx_number"]},
                {"ENTITY": "NTS", "other names": ["new_to_specialty", "fresh_specialty_patients", "recent_specialty_count", "initial_specialty_volume", "first_specialty_number"]},
                {"ENTITY": "Switch", "other names": ["transition", "change", "shift", "swap", "alteration"]}],
    "DIMENSION": [{"ENTITY": "Physician ID", "other names": ["doctor_identifier", "medical_practitioner_id", "healthcare_provider_id", "doc_id", "practitioner_code"]},
                  {"ENTITY": "IMS_ID", "other names": ["ims_identifier", "ims_code", "ims_number", "ims_reference", "ims_key"]},
                  {"ENTITY": "NPI ID", "other names": ["national_provider_id", "npi_number", "npi_code", "npi_identifier", "npi_key"]},
                  {"ENTITY": "Address", "other names": ["location", "street", "residence", "place", "site"]},
                  {"ENTITY": "State", "other names": ["province", "region", "territory", "district", "area"]},
                  {"ENTITY": "City", "other names": ["town", "municipality", "urban_area", "locality", "metropolis"]},
                  {"ENTITY": "Zip_Code", "other names": ["postal_code", "zipcode", "post_code", "mailing_code", "zip"]},
                  {"ENTITY": "Physician Name", "other names": ["doctor_name", "medical_practitioner", "healthcare_provider", "doc_fullname", "practitioner"]},
                  {"ENTITY": "Specialty", "other names": ["expertise", "medical_field", "healthcare_area", "practice_focus", "specialization"]},
                  {"ENTITY": "Specialty Group", "other names": ["expertise_group", "medical_field_category", "healthcare_area_group", "practice_focus_group", "specialization_group"]},
                  {"ENTITY": "Brand", "other names": ["product", "trademark", "label", "make", "marque"]},
                  {"ENTITY": "Therapy Area", "other names": ["treatment_field", "therapy_domain", "care_area", "intervention_zone", "healing_sector"]},
                  {"ENTITY": "Market", "other names": ["industry", "sector", "commerce", "trade", "business_area"]},
                  {"ENTITY": "Payer Channel", "other names": ["payment_method", "insurance_type", "reimbursement_channel", "coverage_mode", "financial_route"]},
                  {"ENTITY": "Payer", "other names": ["insurer", "payment_provider", "coverage_source", "financial_institution", "insurance_company"]},
                  {"ENTITY": "Zip Code", "other names": ["postal_code", "zipcode", "post_code", "mailing_code", "zip"]},
                  {"ENTITY": "Territory", "other names": ["domain", "area", "zone", "region", "expanse"]},
                  {"ENTITY": "Region", "other names": ["geographical_area", "locale", "district", "sector", "division"]},
                  {"ENTITY": "District", "other names": ["administrative_area", "territory", "region", "zone", "jurisdiction"]},
                  {"ENTITY": "Sales Force", "other names": ["sales_team", "sales_representatives", "sales_agents", "sales_personnel", "sales_staff"]}],
    "FILTER": [{"ENTITY": "Healdsburg", "other names": [], "parent": "State"},
               {"ENTITY": "Brownsville", "other names": [], "parent": "State"},
               {"ENTITY": "Oncology", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Pulmonary Disease", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Cardiovascular Diseases", "other names": [], "parent": "Specialty"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Data Date", "other names": ["data date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""

In [6]:
# Date range interpolated into the prompt as `{date_input}` (its dict repr).
# Day-first format: "15/09/2023" can only be DD/MM/YYYY.
date_input = dict(
    start_date="01/01/2020",
    end_date="15/09/2023",
)

In [7]:
# Sanity check before loading the model: the bf16 checkpoint is presumably meant
# to run on a GPU, so expect True here.
bool(torch.cuda.is_available())

In [8]:
# Load the tokenizer of the base Mistral model (`model_name`).
# NOTE(review): trust_remote_code=True executes Python code shipped in the Hub repo;
# Mistral is supported natively by recent transformers releases, so this is likely
# unnecessary — consider dropping it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True,
                                          use_fast=False)
# Reuse EOS as the padding token (presumably because the base tokenizer ships
# without a dedicated pad token) and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [9]:
from peft import PeftModel, PeftConfig

In [10]:
# Directory of the fine-tuned adapter checkpoint to evaluate. Override it with the
# MQL_CHECKPOINT_DIR environment variable instead of editing this absolute path.
new_model_name = os.environ.get(
    "MQL_CHECKPOINT_DIR",
    "/data/mistral/query-to-mql/exp-9/nov-01/checkpoint-4000",
)

In [11]:
from peft import AutoPeftModelForCausalLM

# Load the base model plus the PEFT adapter (presumably LoRA, given the LoraConfig
# import) from the checkpoint directory in bfloat16, letting device_map="auto"
# place the weights. Then fold the adapter into the base weights so inference runs
# on a plain model.
model = AutoPeftModelForCausalLM.from_pretrained(
    new_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = model.merge_and_unload()

In [12]:
# Prompt used at inference time. The exact wording (including the trailing
# "[MQL]" line the model continues from) presumably mirrors the fine-tuning
# format — do not reword it.
query_template_v1 = (
    "Given the context : {context} and date reference: {date_input}, "
    "the query: {user_query}, is converted into below shown structured output.\n"
    "[MQL]\n"
)

In [13]:
def predict_template_query_v1(user_query, max_length=1600):
    """Generate the structured MQL answer for a single natural-language query.

    Fills ``query_template_v1`` with the module-level ``context`` and
    ``date_input`` plus ``user_query``, runs ``model.generate`` and decodes the
    whole returned sequence.

    Args:
        user_query: Natural-language question to convert.
        max_length: Total token budget forwarded to ``model.generate``. It counts
            the prompt tokens as well as the generated ones, so a long ``context``
            leaves correspondingly fewer tokens for the answer.

    Returns:
        A ``(mql_text, raw_text)`` tuple. ``mql_text`` is the decoded text between
        the first ``"[MQL]\\n"`` marker and the next ``"\\n[/MQL]"`` marker (or the
        next ``"[MQL]\\n"``, whichever comes first). ``raw_text`` is the full decoded
        sequence, including the prompt and special tokens.

    Raises:
        ValueError: If the decoded text contains no ``"[MQL]\\n"`` marker.
    """
    prompt = query_template_v1.format(context=context,
                                      user_query=user_query,
                                      date_input=date_input)
    # Tokenize with an explicit attention mask: pad_token is set to eos_token in
    # this notebook, so generate() cannot reliably infer the mask on its own.
    encoded = tokenizer(prompt, return_tensors="pt")
    # Follow the model's device instead of a hard-coded 'cuda' (the model is
    # loaded with device_map="auto").
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    outputs = model.generate(**encoded, max_length=max_length,
                             pad_token_id=tokenizer.eos_token_id)
    raw_text = tokenizer.decode(outputs[0])

    segments = raw_text.split('[MQL]\n')
    if len(segments) < 2:
        raise ValueError(
            f"Decoded output contains no '[MQL]' marker; tail was: {raw_text[-300:]!r}"
        )
    mql_text = segments[1].split('\n[/MQL]')[0]
    return mql_text, raw_text

In [14]:
def inference(user_query):
    """Run the model on ``user_query`` and parse its answer.

    Returns:
        ``(mql, steps)``. ``mql`` is the MQL text produced by
        ``predict_template_query_v1``, evaluated as a Python expression.
        ``steps`` is the reasoning text starting at the first ``"\\nStep 1:"``
        in the raw output and ending before the next one (if any), re-prefixed
        with ``"Step 1:"``.

    Raises:
        IndexError: If the raw output contains no ``"\\nStep 1:"``.
    """
    mql_text, full_text = predict_template_query_v1(user_query=user_query)
    # SECURITY NOTE(review): eval() executes model-generated text as Python code;
    # ast.literal_eval would be safer if the MQL is always a plain literal — confirm.
    mql = eval(mql_text)
    step_marker = 'Step 1:'
    steps = step_marker + full_text.split('\n' + step_marker)[1]
    return mql, steps

In [15]:
# Natural-language queries to run through the model in the batch cell.
list_1 = [
    'how have the full_prescription_volume trended',
]

In [20]:
%%time 
data_fin = []
for user_query in list_1:
    print('user query: ', user_query)
    print('-'*100)
    output, raw = predict_template_query_v1(user_query=user_query)
    print(eval(output))
    print('-'*100)
    steps = 'Step 1:' +raw.split('\nStep 1:')[1]
    print('Step 1:' +raw.split('\nStep 1:')[1])
    print('-'*100)
    data_fin.append([user_query,eval(output)])
import csv
with open('data_1.csv', 'a', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Query", "Intermediate MQL"])

# Write data iteratively
    for row in data_fin[0:]:
        csvwriter.writerow(row)

In [17]:
# Single-query spot check, printing the MQL and reasoning steps.
# NOTE(review): "sales" is an entity of `context_ecom`, but the prompt is built from
# the pharma `context`, which has no Sales measure — confirm this query is intended.
user_query = "sales in jan 2020 versus year ago"
print('user query: ', user_query)
print('-'*100)
output, raw = predict_template_query_v1(user_query=user_query)
# SECURITY NOTE(review): eval() executes model-generated text as Python code.
print(eval(output))
print('-'*100)
steps = 'Step 1:' + raw.split('\nStep 1:')[1]
print(steps)
print('-'*100)

In [None]:
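# model.to("cuda") is not needed: the model is loaded with device_map="auto", which already places its weights on the available device(s).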