### Distill step by step finetuning approach - trying enhanced rationale with specific reasoning for date conversion

In [1]:
!sudo pip install -q transformers --upgrade
!sudo pip install -q peft

In [2]:
import transformers
transformers.__version__

'4.35.0'

In [3]:
import os
import torch
from datasets import load_dataset
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
import pandas as pd
import torch

In [4]:
# The model that you want to train from the Hugging Face hub
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
context_ecom = """{
    "MEASURE": [{"ENTITY": "Discount", "other names": ["discount", "discount rate", "discount value", "deduction"]},
                {"ENTITY": "Purchase Vol", "other names": ["purchase", "purchase value", "purchase model"]},
                {"ENTITY": "Quantity", "other names": ["quantity", "volume"]},
                {"ENTITY": "Sales", "other names": ["sales", "sale"]}],
    "DIMENSION": [{"ENTITY": "Sub-Category", "other names": ["sub-category", "sub category", "categories", "section"]},
                  {"ENTITY": "Segment", "other names": ["segment", "segments", "units", "divisions"]},
                  {"ENTITY": "Parts", "other names": ["parts", "part", "section", "divisions"]},
                  {"ENTITY": "Country", "other names": ["country", "countries"]}],
    "FILTER": [{"ENTITY": "Consumer", "other names": ["consumers", "consumer"], "parent": "Segment"},
               {"ENTITY": "Phone", "other names": ["phone", "phones", "mobile phones"], "parent": "Sub-Category"},
               {"ENTITY": "Binder", "other names": ["binders", "binder"], "parent": "Sub-Category"},
               {"ENTITY": "Corporate", "other names": ["corporates", "corporate"], "parent": "Segment"},
               {"ENTITY": "India", "other names": ["india"], "parent": "Country"},
               {"ENTITY": "Dubai", "other names": ["dubai"], "parent": "Country"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Order Date", "other names": ["order date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""
context = """{
    "MEASURE": [{"ENTITY": "TRx", "other names": ["total_prescriptions", "overall_rx", "complete_rx_count", "full_prescription_volume", "entire_rx_number"]},
                {"ENTITY": "NRx", "other names": ["new_prescriptions", "fresh_rx", "recent_rx_count", "initial_prescription_volume", "first_rx_number"]},
                {"ENTITY": "NBRx", "other names": ["new_to_brand_prescriptions", "fresh_brand_rx", "recent_brand_rx_count", "initial_brand_prescription_volume", "first_brand_rx_number"]},
                {"ENTITY": "NTS", "other names": ["new_to_specialty", "fresh_specialty_patients", "recent_specialty_count", "initial_specialty_volume", "first_specialty_number"]},
                {"ENTITY": "Switch", "other names": ["transition", "change", "shift", "swap", "alteration"]}],
    "DIMENSION": [{"ENTITY": "Physician ID", "other names": ["doctor_identifier", "medical_practitioner_id", "healthcare_provider_id", "doc_id", "practitioner_code"]},
                  {"ENTITY": "IMS_ID", "other names": ["ims_identifier", "ims_code", "ims_number", "ims_reference", "ims_key"]},
                  {"ENTITY": "NPI ID", "other names": ["national_provider_id", "npi_number", "npi_code", "npi_identifier", "npi_key"]},
                  {"ENTITY": "Address", "other names": ["location", "street", "residence", "place", "site"]},
                  {"ENTITY": "State", "other names": ["province", "region", "territory", "district", "area"]},
                  {"ENTITY": "City", "other names": ["town", "municipality", "urban_area", "locality", "metropolis"]},
                  {"ENTITY": "Zip_Code", "other names": ["postal_code", "zipcode", "post_code", "mailing_code", "zip"]},
                  {"ENTITY": "Physician Name", "other names": ["doctor_name", "medical_practitioner", "healthcare_provider", "doc_fullname", "practitioner"]},
                  {"ENTITY": "Specialty", "other names": ["expertise", "medical_field", "healthcare_area", "practice_focus", "specialization"]},
                  {"ENTITY": "Specialty Group", "other names": ["expertise_group", "medical_field_category", "healthcare_area_group", "practice_focus_group", "specialization_group"]},
                  {"ENTITY": "Brand", "other names": ["product", "trademark", "label", "make", "marque"]},
                  {"ENTITY": "Therapy Area", "other names": ["treatment_field", "therapy_domain", "care_area", "intervention_zone", "healing_sector"]},
                  {"ENTITY": "Market", "other names": ["industry", "sector", "commerce", "trade", "business_area"]},
                  {"ENTITY": "Payer Channel", "other names": ["payment_method", "insurance_type", "reimbursement_channel", "coverage_mode", "financial_route"]},
                  {"ENTITY": "Payer", "other names": ["insurer", "payment_provider", "coverage_source", "financial_institution", "insurance_company"]},
                  {"ENTITY": "Zip Code", "other names": ["postal_code", "zipcode", "post_code", "mailing_code", "zip"]},
                  {"ENTITY": "Territory", "other names": ["domain", "area", "zone", "region", "expanse"]},
                  {"ENTITY": "Region", "other names": ["geographical_area", "locale", "district", "sector", "division"]},
                  {"ENTITY": "District", "other names": ["administrative_area", "territory", "region", "zone", "jurisdiction"]},
                  {"ENTITY": "Sales Force", "other names": ["sales_team", "sales_representatives", "sales_agents", "sales_personnel", "sales_staff"]}],
    "FILTER": [{"ENTITY": "Healdsburg", "other names": [], "parent": "State"},
               {"ENTITY": "Brownsville", "other names": [], "parent": "State"},
               {"ENTITY": "Oncology", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Pulmonary Disease", "other names": [], "parent": "Specialty"},
               {"ENTITY": "Cardiovascular Diseases", "other names": [], "parent": "Specialty"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Data Date", "other names": ["data date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""

In [6]:
date_input = {
    "start_date": "01/01/2020",
    "end_date": "15/09/2023"
}

In [7]:
torch.cuda.is_available()

True

In [8]:
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True,
                                          # add_eos_token=True,
                                          use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [9]:
from peft import PeftModel, PeftConfig

In [10]:
new_model_name = "/data/mistral/query-to-mql/exp-9/nov-01/checkpoint-4000"

In [11]:
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(new_model_name, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
query_template_v1 = """Given the context : {context} and date reference: {date_input}, the query: {user_query}, is converted into below shown structured output.
[MQL]
"""

In [13]:
def predict_template_query_v1(user_query):
    inp = query_template_v1.format(context=context,
                                   user_query=user_query,
                                  date_input=date_input)
    _inputs = tokenizer.encode(inp, return_tensors="pt")
    outputs = model.generate(input_ids=_inputs.to('cuda'), max_length= 1700, pad_token_id=tokenizer.eos_token_id)
    output = tokenizer.decode(outputs[0])
    output_new = output.split('[MQL]\n')[1]
    return output_new.split('\n[/MQL]')[0], output
#     return output

In [14]:
def inference(user_query):
    output, raw = predict_template_query_v1(user_query=user_query)
    mql = eval(output)
    steps = 'Step 1:' +raw.split('\nStep 1:')[1]
    return mql, steps

In [18]:
user_query_list = ["how have the full_prescription_volume trended", "what is the monthly trend of TRx and full_prescription_volume", "what is the average overall_rx offered across change", "complete_rx_count vs entire_rx_number over the last 6 months", "Which are the top 5 entire_rx_number making therapy_domain", "Which TRx has given most percentage Sales Force", "first_brand_rx_number in march 2020 ", "Zip Code in 20 April  2021", "territory in 20 April 2020 ", "Territory of art MOM ", "initial_prescription_volume of art YOY ", "What is the growth rate of specialization_group", "what Zip_Code are contributing to growth for initial_brand_prescription_volume in p3m", "mtd growth rate", "ytd growth rate", "How many initial_specialty_volume have  new_prescriptions more than 10k", "what is the contribution of expanse by Physician ID in US", "zipcode in jan 2022 vs  year ago", "first_brand_rx_number in 2020 compared to prior year", "total_prescriptions in 2021 vs 2020", "growth rate of Zip Code MOM in 2020", "Switch in 2021", "npi_key in 2022", "what is th growth rate of make across years", "what is the growth of overall_rx across years", "What is the growth rate of fresh_rx", "what complete_rx_count are contributing to growth for NTS in p3m", "growth rate of ims_key across month", "what Switch are contributing to growth for recent_specialty_count in p3m for top 50%", "How many Switch are there", "How many NTS are there with doctor_identifier more than 35k", "How many entire_rx_number are contributing to 50% of npi_key", "How many municipality are contributing to growth to 50% of domain in p3m vs pp", "ytd Payer Channel by month", "ytd complete_rx_count  by quarter", "mtd Switch", "full_prescription_volume in 2020 YTD vs YA", "How have the District trended by Specialty ?", "what is the monthly trend of entire_rx_number and recent_specialty_count by state?", "domain vs healthcare_provider_id over the last 6 months by sector", "What is the complete_rx_count trend of treatment_field", "What is the first_rx_number trend of swap and geographical_area", "what is the ims_key across months", "how have doc_fullname trended", "Which full_prescription_volume has site greater than fifty thousand in 2020", "Which metropolis has ims_identifier more than 50000 in 2020", "Which zone has first_specialty_number above 50K in 2020", "Which specialization_group has make greater than 400K in 2019", "Which Switch has ims_number more than 20k in February 2019", "Which zipcode has Physician ID above 1K in india in 2019", "Which post_code has residence greater than 55.82 K in 2020", "Which total_prescriptions has medical_practitioner_id more than 4k in December 2019 in india", "Which doctor_identifier has fresh_rx greater than 3K in 2019 in india", "which month has zip greater than five hundred thousand in 2020", "Which region has change greater than 2K for territory February 2019", "Which month has fresh_rx above 50K for locality in 2019", "Which month has Therapy Area less than 200K in india in 2019", "which medical_practitioner has first_rx_number more than 10k in March 2019", "which month has ims_code trend for new_to_specialty greater than 50K in 2019", "Which jurisdiction has change above 20K in february 2019", "Which place has Switch less than 50K in previous  month", "Which coverage_mode has recent_specialty_count more  than  300000 in p6m", "Which practitioner_code has  initial_specialty_volume above 20k in february 2019", "which coverage_mode has sales_personnel below four thousand in february 2019", "Which Specialty has Switch below 10k in 2019", "Which new_to_specialty has TRx between 300k and 400k", "Which TRx has medical_field between 300k & 400k", "Which Specialty has healthcare_provider_id from 300k to 400k", "Which insurance_type has jurisdiction from 300k-400k", "when was npi_identifier the highest ", "when was the NRx of total_prescriptions was highest", "when was the Switch the highest", "when was growth in Switch the highest", "Top 3 Brand by alteration", "Top 3 mailing_code by new_to_brand_prescriptions across overall_rx", "Top 3 months by trade", "Top 3 practitioner_code by fresh_specialty_patients across months", "Lindia profiswap postal_code across healthcare_provider_id", "Top 3 first_brand_rx_number by doc_id across first_brand_rx_number and months", "Most profidomain initial_prescription_volume across months", "Lindia profioverall_rx initial_prescription_volume across initial_prescription_volume", "Bottom 3 initial_prescription_volume by sales_personnel", "Bottom 3 complete_rx_count by doctor_name across Brownsville", "Bottom 3 months by first_brand_rx_number", "Bottom 3 recent_specialty_count by shift across months", "Bottom 3 entire_rx_number by doctor_identifier across Physician ID and months", "top 3 total_prescriptions basis  growth in rate", "top 5 contibuting Brand to rate", "top 3 initial_prescription_volume contributing to growth of rate in rolling 3", "when was the first time complete_rx_count of NRx was more than 40K", "when was the last time fresh_rx of alteration was less than 40K", "When was the NBRx of total_prescriptionss was more than 600", "in which month healthcare_provider_id of province in india was more than 1.5K", "In which quarter practitioner_code of swap in india was above 2.5 k", "In which year commerce of change was more than 450 K", "when was the full_prescription_volume of first_specialty_number was less than 10K first time", "when was the swap of treatment_field was more than 35K last time", "when was the first time Physician ID of complete_rx_count was more than 20K in 2019", "In which month first_brand_rx_number of specialization in india was more than 800 in 2018", "In which quarter full_prescription_volume of change in india was above 2.5 k in 2019", "when was the last time trend of City of new_prescriptions was more than 5K in last 3 months", "when was the first time trend of region of npi_key was less than 10K in last 3 months", "correlation between municipality and medical_practitioner_id", "correlation between ims_key and NTS in india", "correlation between zone and Physician Name in india and dubai", "correlation between ims_identifier of ims_identifier and place", "What is the correlation between municipality volume of all new_to_specialty", "What is the division and division correlation across Physician Name", "What is the correlation of transition and Payer across years", "Correlation between medical_practitioner_id and region in medical_field and india across Switch", "Correlation between doc_id and change in india and india across NBRx", "correlation between complete_rx_count and Territory  across financial_institution", "Correlation between alteration and NRx  across doctor_name exclude US", "Correlation between area and reimbursement_channel in india and india across territory", "medical_practitioner in rolling 3 months", "which District has most specialization in rolling 5 months", "top 3 total_prescriptions by Therapy Area in rolling 7 months in canada", "new_prescriptions of entire_rx_number in rolling 13 months across marque", "Which expertise has the most percentage full_prescription_volume in rolling 10 months", "when was the trend of Switch of industry was more than 40K in rolling 6 months", "when was the first time trend of Sales Force of medical_field was more than 40K in rolling 6 months", "municipality of locality in last 10 weeks", "coverage_source by NTS  in last 2 weeks", "which geographical_area has most trade in india in last 20 weeks", "ims_identifier of all fresh_specialty_patients in basic in last 35 weeks", "zipcode of ims_key in Canada in last 13 weeks", "What is the Cardiovascular Diseases of fresh_specialty_patients geographical_area", "place of doctor_identifier in last month", "What is the NBRx of initial_brand_prescription_volume in india zone", "what is the doctor_identifier of doctor_identifier fresh_brand_rx", "which state has most intervention_zone in this year", "industry per State in this year", "which product has the most fresh_brand_rx in this year", "expertise_group by doc_id in this year", "doc_id in this year", "transition by product more than 500 in US in this year", "City of recent_brand_rx_count in this quarter", "Brand of entire_rx_number in india zone in current quarter", "overall_rx in previous quarter", "town by NBRx in quarter2", "which care_area has highest marque in q3 2018", "top 3 healing_sector by metropolis  in q-4 2017", "bottom 2 TRx by full_prescription_volume  in q 1 2019", "which entire_rx_number has healthcare_provider in india  above 2.5 k in quarter 4 2019", "practitioner by product in quarter-1 2017", "doc_id across shift in present quarter", "what will be the npi_key of make  in q1 2020", "mailing_code in q1 2020 vs q2 2020", "healthcare_area_group of specialization_group in this month vs last month", "State by Payer in last 2 months", "new_prescriptions across sales_team in this month", "which NBRx has most percentage practice_focus_group in last 1 month", "ims_code in this month", "product by percentage complete_rx_count in last 5 months", "Top 3 sales_team by shift in last 13 months", "Forecast of healthcare_provider_id intervention_zone", "forecast of urban_area for 2021", "Forecast of entire_rx_number of financial_route in next 3 months", "Forecast of fresh_brand_rx for the next quarter", "What would be the initial_prescription_volume of territory in 2022", "what will be the trademark in next 3 months", "what will the overall_rx in next 3 months", "Address in next 2 quarters", "expertise in next year", "Zip_Code of NBRx", "fresh_specialty_patients of new_prescriptions", "kda of insurance_type in india", "kda of sales_representatives in india", "why Payer of zipcode changed", "why is commerce intervention_zone changing", "why is my sector changing in p3m", "diagnose the increase in shift zipcode in last month", "why has the NBRx of sales_staff increased in 2019 over 2018", "Sales Force in Q1 2020 vs Q1 2019"]

In [20]:
%%time 
data_fin = []
list_1 =["sales"]
for user_query in list_1:
    print('user query: ', user_query)
    print('-'*100)
    output, raw = predict_template_query_v1(user_query=user_query)
    print(eval(output))
    print('-'*100)
    steps = 'Step 1:' +raw.split('\nStep 1:')[1]
    print('Step 1:' +raw.split('\nStep 1:')[1])
    print('-'*100)
    data_fin.append([user_query,eval(output)])
import csv
with open('data_pharma.csv', 'a', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Query", "Intermediate MQL"])

# Write data iteratively
    for row in data_fin[0:]:
        csvwriter.writerow(row)

user query:  sales
----------------------------------------------------------------------------------------------------


OutOfMemoryError: CUDA out of memory. Tried to allocate 164.00 MiB (GPU 0; 15.60 GiB total capacity; 14.54 GiB already allocated; 54.94 MiB free; 14.87 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [17]:
user_query = "sales in jan 2020 versus year ago"
print('user query: ', user_query)
print('-'*100)
output, raw = predict_template_query_v1(user_query=user_query)
print(eval(outut))print('-'*100)
steps = 'Step 1:' +raw.split('\nStep 1:')[1]
print('Step 1:' +raw.split('\nStep 1:')[1])
print('-'*100)

SyntaxError: invalid syntax (<ipython-input-17-c3698c5997bd>, line 5)

In [None]:
#model.to("cuda")