In [118]:
from mcli import predict
import pandas as pd
import json
import re
from transformers import AutoTokenizer

In [119]:
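## One-off preprocessing (kept for reference): convert each "output" label from a JSON list of dicts into a list of lists of the dict values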
# test_data = pd.read_csv("/Users/rui.chen@avalara.com/Downloads/test.tsv", sep="\t")
# test_data["output"] = test_data["output"].map(json.loads)
# test_data["output"] = test_data["output"].map(lambda x: [list(obj.values()) for obj in x])
# test_data.to_csv("test.tsv", sep="\t")

In [120]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")
finetuned_host = "https://llama2-13b-sut-nested-list-fhvj7c.inf.hosted-on.mosaicml.hosting"


In [133]:
test_data = pd.read_csv("test-list.tsv", sep="\t")

test_data.sample(3)

Unnamed: 0,document_url,input,output
89,Hotel-Motel-Tax-Increase_Hot-Topics-13ceff.pdf,"\nCITY AND BOROUGH OF JUNEAU, ALASKA\nFinance ...","[[""JUNEAU"", ""Hotel/Motel tax"", ""9.0"", ""2019-10..."
2,2020_Apr_GraysHarbor-bf4ccd.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""..."
16,2021_July_UnincorpPierceCoNonRTALodging-2fbb00...,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Nisqually Tribe \u2013 Unincorporated Pierc..."


## get maximum token count from test data

In [122]:
max(len(tokenizer.encode(text)) for text in test_data.input.tolist())

3601

In [123]:
len(tokenizer.encode("hello world"))

3

## try a sample

In [126]:
# Send a single short sentence to the fine-tuned endpoint and display the raw response.
_sample_request = {
    "inputs": ["Seattle lodging tax will increase the value to 2% by Sep 1, 2023"],
    "temperature": 0.,
    "max_length": 50,
}
_prediction = predict(finetuned_host, _sample_request, timeout=2000)
_prediction

{'outputs': ['. [["Seattle", "Lodging Tax", "2", "Sep 1, 2023", "nan"]]']}

## get predictions

In [128]:
def get_prediction(text, tokenizer, finetuned_host, context_window=4096):
    """Request a completion for ``text`` from the hosted fine-tuned model.

    The generation budget sent as ``max_length`` is whatever remains of
    ``context_window`` once the prompt has been tokenized.

    Args:
        text: Full prompt string; sent as the single element of ``inputs``.
        tokenizer: Object whose ``encode(text)`` returns a sized sequence of
            tokens. It is used only to count the prompt tokens.
        finetuned_host: Endpoint handle/URL passed straight to ``predict``.
        context_window: Total token budget shared by prompt and completion.
            Defaults to 4096, the value that used to be hard-coded.

    Returns:
        The value returned by ``predict``, or ``None`` when the prompt leaves
        no room for generation or when the request raises.
    """
    current_token_count = len(tokenizer.encode(text))
    remaining_tokens = context_window - current_token_count

    # A prompt that already fills the window would produce a zero or negative
    # ``max_length``; skip the request instead of sending an invalid budget.
    if remaining_tokens <= 0:
        print(
            f"Prediction skipped: prompt has {current_token_count} tokens, "
            f"leaving no room in the {context_window}-token window."
        )
        return None

    try:
        return predict(
            finetuned_host,
            {
                "inputs": [text],
                "temperature": 0.,
                "max_length": remaining_tokens,
            },
            timeout=6000,
        )
    except Exception as e:
        # Deliberately broad: the caller treats ``None`` as "this document
        # failed" and carries on with the next one.
        # Log only a short preview; prompts can be thousands of tokens long
        # and printing them in full buries the actual failure reason.
        preview = text if len(text) <= 200 else text[:200] + " [...]"
        print(f"Prediction failed for text: {preview}. Reason: {e}")
        return None


# Prompt framing is identical for every row, so it is defined once up front.
prefix = "generate templates: "
suffix = ""

# One entry is appended per row of test_data (a "[[]]" placeholder on failure),
# so the list stays aligned with the frame it is later attached to.
predictions_finetuned = []
for i, row in test_data.iterrows():
    inputs = prefix + row["input"] + suffix

    i_prediction = get_prediction(inputs, tokenizer, finetuned_host)

    # get_prediction returns None on failure. Also tolerate a response with a
    # missing or empty "outputs" entry, which would otherwise raise and abort
    # the whole run part-way through the test set.
    try:
        first_output = i_prediction["outputs"][0]
    except (TypeError, KeyError, IndexError):
        first_output = None

    if first_output is not None:
        print(i)
        predictions_finetuned.append(first_output)
    else:
        print(f"{i} failed")
        predictions_finetuned.append("[[]]")



len(predictions_finetuned)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
Prediction failed for text: generate templates: Acadia Parish. Effective 7/1/2018: **The rate table below includes the Louisiana State Rate decrease effective July 1, 2018.**. 
Jurisdiction  Name,Col,School  Brd,Police  Jury,Law  Enforcement,City  Town,Comb  Local  Rate,State  Rate,Total  Rate
Crowley,A,1.50%,1.00%,0.50%,2.50%,5.50%,4.45%,9.95%
Church Point,B,1.50%,1.25%,0.50%,2.00%,5.25%,4.45%,9.70%
Iota,C,1.50%,1.25%,0.50%,2.00%,5.25%,4.45%,9.70%
Estherwood,D,1.50%,1.25%,0.50%,1.00%,4.25%,4.45%,8.70%
Mermentau,E,1.50%,1.25%,0.50%,1.00%,4.25%,4.45%,8.70%
Rayne,F,1.50%,1.00%,0.50%,2.00%,5.00%,4.45%,9.45%
Morse,G,1.50%,1.25%,0.50%,1.00%,4.25%,4.45%,8.70%
Eunice city limits in Acadia Parish  (1),H,1.50%,1.25%,0.50%,–,3.25%  (1),4.45%,7.70%
Duson city limits in Acadia Parish  (2),H,1.50%,1.25%,0.50%,–,3.25%  (2),4.45%,7.70%
Basile city limits in Acadia Parish 

100

In [136]:
len(predictions_finetuned)

100

In [130]:
# Save the raw (not yet post-processed) model outputs to a local pickle file.
import pickle
with open('raw-llama2-13b-nested-list.pkl', 'wb') as f:
    pickle.dump(predictions_finetuned, f)

## post processing

In [138]:
# Collect the raw predictions whose text does not finish with "}]" or "]]".
todo_prediction = [
    pred
    for pred in predictions_finetuned
    if not pred.endswith(("}]", "]]"))
]
todo_prediction

['[["Unincorp. Areas", "Local Sales Tax Rate", "0.029", "2021-07-01 00:00:00", "nan"], ["Unincorp. Areas", "State Sales Tax Rate", "0.065", "2021-07-01 00:00:00", "nan"], ["Unincorp. Areas", "Total Sales Tax Rate", "0.094", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA", "Local Sales Tax Rate", "0.015", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA", "State Sales Tax Rate", "0.065", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA", "Total Sales Tax Rate", "0.08", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA HBZ", "Local Sales Tax Rate", "0.015", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA HBZ", "State Sales Tax Rate", "0.065", "2021-07-01 00:00:00", "nan"], ["Unincorp. Non-RTA HBZ", "Total Sales Tax Rate", "0.08", "2021-07-01 00:00:00", "nan"], ["Unincorp. PTBA", "Local Sales Tax Rate", "0.035", "2021-07-01 00:00:00", "nan"], ["Unincorp. PTBA", "State Sales Tax Rate", "0.065", "2021-07-01 00:00:00", "nan"], ["Unincorp. PTBA", "Total Sales Tax Rate", "0.1", "20

In [131]:
def get_labels(i, text):
    """Extract the nested-list label string from one raw model output.

    Everything before the last "[[" in ``text`` is discarded (completions can
    start with filler such as ". "). If the remainder already ends with "]]"
    it is returned unchanged. Otherwise the output is treated as truncated and
    is cut back to its last complete row, so that the result is valid JSON
    rather than a half-written row with brackets glued on.

    Args:
        i: Position of the prediction. Accepted for call-site compatibility;
            it is not used.
        text: Raw completion string.

    Returns:
        A string that starts with "[[" and ends with "]]". "[[]]" is returned
        when no complete row can be recovered.
    """
    # Trailing whitespace must not make a complete answer look truncated.
    res = "[[" + text.split("[[")[-1].rstrip()
    if res.endswith("]]"):
        return res

    # Walk the "]" characters from right to left and close the outer list
    # after the first one that yields parseable JSON. A "]" that sits inside a
    # string value fails to parse and is skipped.
    cut = res.rfind("]")
    while cut != -1:
        candidate = res[: cut + 1] + "]"
        try:
            json.loads(candidate)
            return candidate
        except json.JSONDecodeError:
            cut = res.rfind("]", 0, cut)

    # Not a single complete row survived the truncation.
    return "[[]]"



def clean_json_list(data_str):
    """
    Drop unparseable entries from a stringified list of JSON objects.

    The surrounding whitespace and outer square brackets are removed, the
    remainder is cut at every '},' boundary, and each piece is kept only if
    it decodes as JSON either unchanged or with a closing '}' appended.
    The surviving values are returned re-serialised as one JSON string.
    """
    def parse_fragment(fragment):
        # Returns (True, value) for the first variant that decodes, otherwise
        # (False, None). The flag is needed because a decoded value may
        # itself be None.
        for candidate in (fragment, fragment + "}"):
            try:
                return True, json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return False, None

    # Peel off the outer list brackets before splitting into objects.
    body = data_str.strip().lstrip('[').rstrip(']')

    kept = []
    for fragment in body.split('},'):
        decoded, value = parse_fragment(fragment)
        if decoded:
            kept.append(value)

    # Serialise the surviving objects back into a single JSON string.
    return json.dumps(kept)


    
# predictions = [clean_json_list(item) if item != '[[]]' else """[[]]""" for i, item in enumerate(predictions_finetuned) ]


    
# Apply get_labels to every raw model output, keeping the original order.
predictions = [get_labels(i, item) for i, item in enumerate(predictions_finetuned)]

In [132]:
# Repeat the suffix check on the post-processed predictions and count the
# entries that still do not finish with "}]" or "]]".
todo_prediction = [
    pred
    for pred in predictions
    if not pred.endswith("}]") and not pred.endswith("]]")
]
len(todo_prediction)

0

In [115]:
len(predictions_finetuned)

100

In [134]:
test_data["prediction"] = predictions
test_data

Unnamed: 0,document_url,input,output,prediction
0,2020_Apr_George-740d04.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""George TBD"", ""Local Sales Tax Rate"", ""0.019...","[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""..."
1,2020_Apr_GigHarbor-92d49d.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Gig Harbor"", ""Local Sales Tax Rate"", ""0.022...","[[00:00:00"", ""nan""], [""Unincorp. Areas PTBA HB..."
2,2020_Apr_GraysHarbor-bf4ccd.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""...","[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""..."
3,2020_Apr_GraysHarborLodging-d073f7.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Unincorp. Areas"", ""Combined Sales Tax Rate""...","[[""Unincorp. Areas"", ""Combined Sales Tax Rate""..."
4,2020_Apr_Yakima-561d81.pdf,WASHINGTON STATE DEPARTMENT OF REVENUE. LOCAL ...,"[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""...","[[""Unincorp. Areas"", ""Local Sales Tax Rate"", ""..."
...,...,...,...,...
95,local-tax-changes-10-01-2021-dd4828.pdf,"\n,Local Tax Changes\n,Effective October 1, 20...","[[""Hankinson"", ""City sales, use, and gross rec...","[[""Hankinson"", ""city sales, use, and gross rec..."
96,local-tax-changes-7-1-2020-b76cfe.pdf,"Effective July 1, 2020. \nThe following local ...","[[""Beulah"", ""city sales, use, and alcohol gros...","[[""Beulah"", ""city sales, use, and gross receip..."
97,tb-01-21-7a5d19.pdf,"\n,Changes in Utah laws or Tax\n,Commission ru...","[[""Orderville (Kane County)"", ""Transient Room ...","[[""Orderville (Kane County)"", ""Transient Room ..."
98,tb-02-22-a5b2d2.pdf,"\n,Changes in Utah laws or Tax\n,Commission ru...","[[""Eagle Mountain (Utah County)"", ""Transient R...","[[""Rich County"", ""Transient Room Tax"", ""4.25"",..."


In [135]:
test_data.to_csv("llama2-13b-list.tsv", sep="\t")

In [139]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")


In [140]:
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-13b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False)}, clean_up_tokenization_spaces=False)