### Distill step by step finetuning approach - trying enhanced rationale with specific reasoning for date conversion

In [1]:
!sudo pip install -q transformers --upgrade
!sudo pip install -q peft

In [2]:
import transformers
transformers.__version__

'4.35.0'

In [3]:
import os
import torch
from datasets import load_dataset
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
import pandas as pd
import torch

In [4]:
# The model that you want to train from the Hugging Face hub
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

In [5]:
context = """{
    "MEASURE": [{"ENTITY": "Discount", "other names": ["discount", "discount rate", "discount value", "deduction"]},
                {"ENTITY": "Purchase Vol", "other names": ["purchase", "purchase value", "purchase model"]},
                {"ENTITY": "Quantity", "other names": ["quantity", "volume"]},
                {"ENTITY": "Sales", "other names": ["sales", "sale"]}],
    "DIMENSION": [{"ENTITY": "Sub-Category", "other names": ["sub-category", "sub category", "categories", "section"]},
                  {"ENTITY": "Segment", "other names": ["segment", "segments", "units", "divisions"]},
                  {"ENTITY": "Parts", "other names": ["parts", "part", "section", "divisions"]},
                  {"ENTITY": "Country", "other names": ["country", "countries"]}],
    "FILTER": [{"ENTITY": "Consumer", "other names": ["consumers", "consumer"], "parent": "Segment"},
               {"ENTITY": "Phone", "other names": ["phone", "phones", "mobile phones"], "parent": "Sub-Category"},
               {"ENTITY": "Binder", "other names": ["binders", "binder"], "parent": "Sub-Category"},
               {"ENTITY": "Corporate", "other names": ["corporates", "corporate"], "parent": "Segment"},
               {"ENTITY": "India", "other names": ["india"], "parent": "Country"},
               {"ENTITY": "Dubai", "other names": ["dubai"], "parent": "Country"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Order Date", "other names": ["order date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""

In [None]:
context = """{
    "MEASURE": [{"ENTITY": "Insurance covergae", "other names": ["insurance amount", "total insurance coverage", "coverage"]},
                {"ENTITY": "Hospital bill", "other names": ["bill", "hospital expenses", "expenses"]},
                {"ENTITY": "Count", "other names": ["quantity", "counts"]}],
    "DIMENSION": [{"ENTITY": "Disease", "other names": ["disease", "Diseases", "health issues"]},
                  {"ENTITY": "State", "other names": ["segment", "segments", "units", "divisions"]},
                  {"ENTITY": "Insurer", "other names": ["insurer", "insurance provider"]}],
    "FILTER": [{"ENTITY": "Covid", "other names": ["covid-19", "covid19","Covid 19"], "parent": "Disease"},
               {"ENTITY": "Cancer", "other names": ["cancers", "cancer", "tumour"], "parent": "Disease"},
               {"ENTITY": "Delhi", "other names": ["New Delhi", "delhi"], "parent": "State"},
               {"ENTITY": "Maharashtra", "other names": ["corporates", "corporate"], "parent": "State"},
               {"ENTITY": "HDFC ergo", "other names": ["hdfc","HDFC","HDFC health insurance","hdfc insurance], "parent": "Insurer"},
               {"ENTITY": "Aditya Birla", "other names": ["aditya birla health insurance","aditya birla insurance"], "parent": "Insurer"}],
    "DERIVED MEASURE": [{"ENTITY": "Ratio",
             "other names": ["ratio", "share", "contribution", "percentage", "proportion", "contributing"]},
            {"ENTITY": "Why", "other names": ["why", "cause of", "reason for", "diagnose"]},
            {"ENTITY": "contribution_to_growth", "other names": ["contribution to growth", "growth", "grown"]},
            {"ENTITY": "kda_transactional", "other names": ["kda", "key drivers", "key driver", "drivers", "driver"]},
            {"ENTITY": "Growth Rate", "other names": ["growth rate", "growth", "grown"]},
            {"ENTITY": "correlation",
             "other names": ["associate", "associated", "association", "associations", "correlate", "correlated",
                             "correlation", "correlations", "relate", "related", "relation", "relations",
                             "relationship",
                             "relationships"]}
            ],
    "DATE VARIABLE": [{"ENTITY": "Admit Date", "other names": ["admit date", "date", "trend", "time", "when", "mom", "yoy"]}]
    }"""

In [6]:
date_input = {
    "start_date": "01/01/2020",
    "end_date": "15/09/2023"
}

In [7]:
torch.cuda.is_available()

False

In [8]:
# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True,
                                          # add_eos_token=True,
                                          use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [9]:
from peft import PeftModel, PeftConfig

In [10]:
new_model_name = "/data/mistral/query-to-mql/exp-9/nov-01/checkpoint-4000"

In [11]:
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(new_model_name, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
query_template_v1 = """Given the context : {context} and date reference: {date_input}, the query: {user_query}, is converted into below shown structured output.
[MQL]
"""

In [14]:
def predict_template_query_v1(user_query):
    inp = query_template_v1.format(context=context,
                                   user_query=user_query,
                                  date_input=date_input)
    _inputs = tokenizer.encode(inp, return_tensors="pt")
    outputs = model.generate(input_ids=_inputs, max_length= 1700, pad_token_id=tokenizer.eos_token_id)
    output = tokenizer.decode(outputs[0])
    output_new = output.split('[MQL]\n')[1]
    return output_new.split('\n[/MQL]')[0], output
#     return output

In [15]:
%%time
user_query = 'why sales changed in last 2 weeks of Nov 2021'
print('user query: ', user_query)
print('-'*100)
output, raw = predict_template_query_v1(user_query=user_query)
print(eval(output))
print('-'*100)
print('Step 1:' +raw.split('\nStep 1:')[1])

user query:  why sales changed in last 2 weeks of Nov 2021
----------------------------------------------------------------------------------------------------
{'DATE VARIABLE': {'last 2 weeks of Nov 2021': [{'CONVERTED TIME ELEMENT': 'last 2 weeks of Nov 2021', 'DATE RANGE': '2021/11/08 - 2021/11/21', 'ENTITY': 'Order Date'}]}, 'DERIVED MEASURE': {'why': [{'ENTITY': 'Why'}]}, 'MEASURE': {'sales': [{'ENTITY': 'Sales'}]}}
----------------------------------------------------------------------------------------------------
Step 1: Identify the components in the query
- The query asks for the reason behind the change in sales.
- The time period mentioned is the last 2 weeks of Nov 2021.

Step 2: Match the components to the context
- The measure "sales" can be matched to the "Sales" entity in the context.
- The derived measure "why" can be matched to the "Why" entity in the context.
- The date component "last 2 weeks of Nov 2021" can be matched to the "Order Date" entity in the context.

St

In [14]:
def inference(user_query):
    output, raw = predict_template_query_v1(user_query=user_query)
    mql = eval(output)
    steps = 'Step 1:' +raw.split('\nStep 1:')[1]
    return mql, steps

In [None]:
%%time
user_query = 'why sales changed in last 2 weeks of Nov 2021'
inference(user_query)