Keep an eye on the progress in Notion.

@October 14 https://www.notion.so/adrianmohnacs/3dcf69ad6da44496ab9f890044158553?pvs=4

@October 15

## Pipeline Steps:

- Dependencies and file loading
- Getting familiar with the data
- Data preparation: map each entity type to its fields and combine the user query with the field descriptions (the field description provides the semantic context needed to pick the correct JSON property)
- JSON cleaning and error handling
- Labeling: label the data using the entities extracted from the JSON, which adds context to each example
- Map dictionaries to the dataset (this will change)
- Training/validation split
- Save the dataset locally
- Load an existing model or train a new one
- Save the model and tokenizer locally
- Evaluation
- Inference

## Merged dataset for processing
**For training, we drop the `json` and `field_name` columns and use them only for labeling.**
| entity_name | json | field_name | field_type | description |
|-------------|------|------------|------------|--------------------------------------------------------------|
| CDR         | {'entityType': 'CDR', 'statements': [{'type': 'technology', 'value': '3G'}]} | ifc.ootb.CDR.callStatus | string | Status of the call: "Successful", "Failed", "Busy", etc. |
| CDR         | {'entityType': 'Web Activity', 'statements': [{'type': 'platform', 'value': 'Reddit'}, {'type': 'time', 'value': 'yesterday'}, {'type': 'keyword', 'value': 'funny'}]} | ifc.CDR.caseCode | string | Unique code identifying a specific case |
| CDR         | {'entityType': 'Investigation', 'statements': [{'type': 'status', 'value': ['open', 'closed']}]} | ifc.CDR.chatTopic | string | Topic or subject of discussion in the chat |
| CDR         | {'entityType': 'Insight', 'statements': [{'type': 'relatedTo', 'value': 'Jane Doe'}]} | ifc.ootb.CDR.createDateTime | date | Date and time of record creation. |
| CDR         | {'entityType': 'Web Activity', 'statements': [{'type': 'time', 'value': 'last day'}]} | ifc.ootb.CDR.direction | string | Direction of the call (incoming, outgoing) |
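
In the sample rows above, `entity_name` is CDR on every row while the `entityType` inside `json` varies (Web Activity, Investigation, Insight). The rows line up with the first five rows of each CSV, which suggests the two files were paired by position rather than joined on entity type. Below is a minimal sketch of a keyed join in plain Python, assuming each query's entity type has already been parsed out of its JSON; `join_queries_to_fields` and the toy rows are illustrative, not part of the notebook.

```python
from collections import defaultdict

def join_queries_to_fields(queries, fields):
    """Pair each query with every field that belongs to its own entity type.

    queries: list of dicts with 'question' and 'entityType' (parsed from the json column)
    fields:  list of dicts with 'entity_name', 'field_name', 'description'
    """
    # Group field rows by entity so each lookup is a single dict access
    fields_by_entity = defaultdict(list)
    for f in fields:
        fields_by_entity[f["entity_name"]].append(f)

    joined = []
    for q in queries:
        # An entity with no known fields gets an empty list instead of a wrong pairing
        joined.append({**q, "fields": fields_by_entity.get(q["entityType"], [])})
    return joined

# Toy rows for illustration only
queries = [
    {"question": "Find all calls made using 3G technology.", "entityType": "CDR"},
    {"question": "Show recent web activity.", "entityType": "Web Activity"},
]
fields = [
    {"entity_name": "CDR", "field_name": "ifc.ootb.CDR.callStatus", "description": "Status of the call"},
    {"entity_name": "CDR", "field_name": "ifc.ootb.CDR.direction", "description": "Direction of the call (incoming, outgoing)"},
]
joined = join_queries_to_fields(queries, fields)
print([len(j["fields"]) for j in joined])  # [2, 0]
```

In pandas, a comparable result comes from `merge` with `left_on`/`right_on` once the queries have an entity-type column, which yields one row per (query, field) pair rather than a nested list.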

## Model Selection
#### The winner is ALBERT

| Model Variant | Number of Parameters | Model Size on Disk | Layers |
|---------------|----------------------|--------------------|--------|
| albert-base-v2 | 11M | ~46 MB | 12 |

**We then move to training**
By combining the query with field descriptions, the model can better understand the semantic meaning of the entities involved, improving its ability to map queries to the correct JSON labels.

Example training input:

```
Query: Find all calls made using 3G technology. 
Entity (Label): CDR. 
Fields: callStatus: Status of the call; createDateTime: Date and time of record creation.
```
Example label:
```
"CDR"
```
If there is a relation target:
```
"CDR|Phone"
```
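
The label format above can be produced from a parsed query JSON roughly as follows. This sketch assumes the relation target sits under a `relationTargetType` key (as in the sample data further down) and that a record without one yields only the entity type; `make_label` and `split_label` are hypothetical helper names.

```python
def make_label(record):
    """Return 'Entity' or 'Entity|RelationTarget' for a parsed query JSON."""
    entity = record["entityType"]
    # The relation target is optional; append it only when present and non-empty
    target = record.get("relationTargetType")
    return f"{entity}|{target}" if target else entity

def split_label(label):
    """Inverse of make_label: returns (entity, relation_target or None)."""
    entity, _, target = label.partition("|")
    return entity, (target or None)

print(make_label({"entityType": "CDR"}))                                # CDR
print(make_label({"entityType": "CDR", "relationTargetType": "Phone"}))  # CDR|Phone
print(split_label("CDR|Phone"))                                         # ('CDR', 'Phone')
```

Treating `CDR|Phone` as a single class multiplies the number of classes (one per entity/target combination); predicting the entity and the relation target separately is the alternative raised in the questions below.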

![SegmentLocal](B_DSKq.gif "segment")
#### Not helping?

### TLDR

We now have a local model pipeline in place, including evaluation and inference. No prompt engineering is needed, and the implementation is much simpler than LangChain, vector stores, or RAG.

This is a good start. We get some great results, but the dataset is small (692 labeled examples after cleaning), so overfitting is a real risk.

That said, I will need another day of deep work, and A LOT OF TESTING, to get the pipeline to production quality. This will also let me add some niceties to the notebook so you can just "plug and play".

To improve our output, I'd like to get clear on what quality outputs look like and discuss the data architecture and the relations between all the features. I'd also like to see more example outputs to get a better sense of the model's desired behavior.

### Questions

- Can we clearly break down the relation between the user query and the fields? I'd like to hear it entirely from your perspective.
- Do we want to predict BOTH the entity type and the relation target type?
- What do you consider a good output here for the prediction?

### What's left to do?

- Test cases aligned with the examples you'd like to see
- Clean up the notebook and add comments
- Look into ways to get a more realistic estimate of the model's accuracy
- Allow a new CSV to be uploaded through a simple UI in the notebook and assigned to a path variable that the data step processes
- Simplify the model training configuration: make sure everything is optimized to run on the local machine and uses steps only, not a mix of steps and epochs
- Optimize local storage and add defensive code to conserve resources and make sure we save and load properly
- TEST CASES


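```python
# !python -m venv env
# !source env/bin/activate   (each ! command runs in its own subshell, so this would not persist)
%pip install torch transformers scikit-learn pandas sentencepiece
# Pin NumPy 1.26.4; restart the kernel afterwards so the pinned version is the one imported
%pip install numpy==1.26.4 --force-reinstall
# Needed for training (the Trainer depends on accelerate)
%pip install -U accelerate

import numpy as np
print(np.__version__)
```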


```
Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl (20.6 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
Successfully installed numpy-1.26.4
```

### Sample Case for Getting Familiar with Data and Model

```python
# Sample JSON-like data (you'll replace this with your CSV data)
json_data = [
    {"entityType": "CDR", "relationTargetType": "Phone"},
    {"entityType": "Report", "relationTargetType": "Malware"}
]

# Example query from the user
query = "What SMS messages were sent from suspicious phones to 0549876543 containing 'urgent'?"
```


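```python
import torch
from transformers import AlbertForQuestionAnswering, AlbertTokenizer

# SQuAD2-fine-tuned ALBERT (the same checkpoint used later), loaded here so this cell runs on its own
qa_model_name = "twmkn9/albert-base-v2-squad2"
qa_tokenizer = AlbertTokenizer.from_pretrained(qa_model_name)
qa_model = AlbertForQuestionAnswering.from_pretrained(qa_model_name)
qa_model.eval()

# Search for relevant entities in the JSON data
def find_matching_entities(query, json_data):
    matching_entities = []

    for record in json_data:
        entity_text = f"{record['entityType']} {record['relationTargetType']}"

        # Encode the (question, context) pair for the model
        inputs = qa_tokenizer(query, entity_text, return_tensors="pt")

        # Run the model to get answer scores
        with torch.no_grad():
            outputs = qa_model(**inputs)

        # Most likely start and end positions of the answer span
        answer_start = torch.argmax(outputs.start_logits).item()
        answer_end = torch.argmax(outputs.end_logits).item() + 1

        # SQuAD2 models point at [CLS] (position 0) when there is no answer
        if answer_start == 0 or answer_end <= answer_start:
            continue

        # Extract the answer text
        predicted_entity = qa_tokenizer.convert_tokens_to_string(
            qa_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
        )

        # If the predicted entity is not empty, consider it a match
        if predicted_entity.strip():
            matching_entities.append((record['entityType'], record['relationTargetType']))

    # Remove duplicates while keeping the original order
    return list(dict.fromkeys(matching_entities))
```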

```python
# Test example
matching_entities = find_matching_entities(query, json_data)
print(matching_entities)
```

### Data Exploration


```python
# Load the data and create the mapping into a new DataFrame.
# Right now we are just using the user query.
# TODO: file paths are static for now; make them dynamic, maybe with a small UI to select the file.
import pandas as pd
import json

fields_desc = pd.read_csv('fields_description.csv')
user_queries = pd.read_csv('user_queries.csv')

print('fields_desc')
fields_desc.head()
```



```
fields_desc
```

|   | entity_name | field_name | field_type | description |
|---|-------------|------------|------------|-------------|
| 0 | CDR | ifc.ootb.CDR.callStatus | string | Status of the call: "Successful", "Failed", "B... |
| 1 | CDR | ifc.CDR.caseCode | string | Unique code identifying a specific case |
| 2 | CDR | ifc.CDR.chatTopic | string | Topic or subject of discussion in the chat |
| 3 | CDR | ifc.ootb.CDR.createDateTime | date | Date and time of record creation. |
| 4 | CDR | ifc.ootb.CDR.direction | string | Direction of the call (incoming, outgoing) |


```python
print('\nuser_queries')
user_queries.head()
```


```
user_queries
```

|   | question | json |
|---|----------|------|
| 0 | Find all calls made using 3G technology. | {'entityType': 'CDR', 'statements': [{'type': ... |
| 1 | List all Reddit comments posted yesterday with... | {'entityType': 'Web Activity', 'statements': [... |
| 2 | Show me investigations that are either open or... | {'entityType': 'Investigation', 'statements': ... |
| 3 | Find all insights related to the witness Jane ... | {'entityType': 'Insight', 'statements': [{'typ... |
| 4 | List all web activities updated in the last da... | {'entityType': 'Web Activity', 'statements': [... |


```python
fields_desc.describe(include="all")
```

Unnamed: 0,question,json
count,744,744
unique,742,721
top,Show me insights where the text includes 'witn...,"{'entityType': 'Phone', 'statements': [{'type'..."
freq,2,3


```python
user_queries.describe(include="all")
```

|        | question | json |
|--------|----------|------|
| count  | 744 | 744 |
| unique | 742 | 721 |
| top    | Show me insights where the text includes 'witn... | {'entityType': 'Phone', 'statements': [{'type'... |
| freq   | 2 | 3 |


```python
print('value distribution for field entities')
print(fields_desc['entity_name'].value_counts())
```


```
value distribution
entity_name
CDR              36
EVisa Request    29
Web Activity     25
Web Actor        23
Phone            15
Person            6
Investigation     5
Report            3
Insight           2
Name: count, dtype: int64
```


```python
# Create a dictionary mapping field_name → description for quick lookup.
# Let's just look at this like a dictionary.
field_mapping = fields_desc.set_index('field_name')['description'].to_dict()
print(field_mapping)
```


```
{'ifc.ootb.CDR.callStatus': 'Status of the call: "Successful", "Failed", "Blocked", or "Redirected"', 'ifc.CDR.caseCode': 'Unique code identifying a specific case', 'ifc.CDR.chatTopic': 'Topic or subject of discussion in the chat', 'ifc.ootb.CDR.createDateTime': 'Date and time of record creation.', 'ifc.ootb.CDR.direction': 'Direction of the call (incoming, outgoing)', 'ifc.ootb.CDR.duration': 'Duration of the communication in minutes. You can ask it For example: 1min -> 60', 'ifc.CDR.emailSubject': 'Subject of the email communication', 'ifc.ootb.CDR.endTime': 'Time when the communication ended', 'ifc.ootb.CDR.hasContent': 'Indicates if the communication has content', 'ifc.ootb.CDR.imei': 'this field is intended to store the IMEI number of the device who made a call.', 'ifc.ootb.CDR.imei2': "this field is intended to store the IMEI number of the device that is receiving the call. It serves the same purposes as the caller's IMEI but for the receiving side of the communication.", 'ifc.oo
```
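
The example training input near the top lists fields by short name (`callStatus: Status of the call`), while the generated inputs use full dotted names (`ifc.ootb.CDR.callStatus: ...`). The encodings further down are all 512 tokens wide, so the longest inputs are being cut at the model's limit, and every field repeats the same `ifc.ootb.CDR.`-style prefix. A sketch of building the `Fields:` text with short names, assuming the last dotted segment is unique within one entity (it is for the CDR fields shown, e.g. `imei` vs `imei2`); `fields_text` is a hypothetical helper.

```python
def fields_text(fields):
    """Build the 'Fields:' part of an input from (full_field_name, description) pairs for ONE entity."""
    parts = []
    for full_name, desc in fields:
        # 'ifc.ootb.CDR.callStatus' -> 'callStatus'
        short_name = full_name.rsplit(".", 1)[-1]
        parts.append(f"{short_name}: {desc.strip()}")
    return "; ".join(parts)

cdr_fields = [
    ("ifc.ootb.CDR.callStatus", "Status of the call"),
    ("ifc.ootb.CDR.createDateTime", "Date and time of record creation."),
]
print(fields_text(cdr_fields))
# callStatus: Status of the call; createDateTime: Date and time of record creation.
```

With the real data, the pairs for one entity could come from `fields_desc.loc[fields_desc["entity_name"] == "CDR", ["field_name", "description"]].itertuples(index=False)`.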

```python
# JSON cleaning: the `json` column stores single-quoted, dict-like strings,
# so normalize them into valid JSON before parsing
import json
import re

def clean_json_string(json_string):
    """Normalize a dict-like string into JSON and return the parsed dict."""
    # Remove any leading/trailing whitespace
    json_string = json_string.strip()

    # Ensure the string is enclosed in curly braces
    if not json_string.startswith('{'):
        json_string = '{' + json_string
    if not json_string.endswith('}'):
        json_string = json_string + '}'

    # Replace every unescaped single quote with a double quote.
    # Caveat: this also rewrites apostrophes inside values (e.g. "O'Brien"), which breaks the JSON.
    json_string = re.sub(r"(?<!\\)'", '"', json_string)

    # Remove any trailing commas before closing braces or brackets
    json_string = re.sub(r',\s*([\]}])', r'\1', json_string)

    # Returns a dict, not a string, so callers must not call json.loads on the result again
    return json.loads(json_string)
```
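
The `json` column appears to hold Python dict reprs (single-quoted keys and strings) rather than JSON. If that holds for all rows, the standard library's `ast.literal_eval` can parse them directly and copes with apostrophes inside values, which the blanket quote replacement above turns into invalid JSON. A sketch with a strict-JSON fallback; `parse_query_json` is a hypothetical name, and the brace repair is carried over from `clean_json_string`.

```python
import ast
import json

def parse_query_json(raw):
    """Parse a dict-like string (Python repr or JSON) into a dict; return {} on failure."""
    text = raw.strip()
    # Same brace repair as clean_json_string: make sure the text is wrapped in braces
    if not text.startswith("{"):
        text = "{" + text
    if not text.endswith("}"):
        text = text + "}"
    try:
        # literal_eval only evaluates literals (dicts, lists, strings, numbers, True/False/None)
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        try:
            # Fall back to strict JSON, e.g. for rows written with true/false/null
            value = json.loads(text)
        except json.JSONDecodeError:
            return {}
    return value if isinstance(value, dict) else {}

# An apostrophe inside a value: replacing every single quote would produce invalid JSON here
row = "{'entityType': 'Insight', 'statements': [{'type': 'relatedTo', 'value': \"O'Brien\"}]}"
print(parse_query_json(row)["statements"][0]["value"])  # O'Brien
```

Returning `{}` instead of raising matches how the lookup function below treats unparseable rows.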


```python
def find_matching_json_for_field_name(field_name, user_queries):
    """
    Finds the first user query whose JSON contains the given field_name
    as a parameter in one of its filter statements.

    Parameters:
    - field_name: The field name to search for.
    - user_queries: DataFrame containing the user queries and their JSON objects.

    Returns:
    - matching_row: The row from user_queries where the field_name is found, or None.
    """
    for _, row in user_queries.iterrows():
        try:
            # clean_json_string already returns a parsed dict
            json_data = clean_json_string(row['json'])
        except json.JSONDecodeError:
            # Rows that cannot be repaired contribute no statements
            json_data = {}

        # Check if any statement in the JSON has the field_name as a parameter;
        # not every statement has 'parameters', so default to an empty dict
        for statement in json_data.get('statements', []):
            param_name = statement.get('parameters', {}).get('name', '')
            if param_name == field_name:
                return row  # Return the row if a match is found

    return None  # Return None if no matching JSON is found

print(find_matching_json_for_field_name('ifc.ootb.CDR.imei', user_queries))
```

```
question    Show me all non-voice communications from IMEI...
json        {'entityType': 'CDR', 'statements': [{'type': ...
Name: 30, dtype: object
```
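
`find_matching_json_for_field_name` rescans and re-parses every query for each field it is asked about, so covering all 144 fields means on the order of 144 × 744 parses. A single pass can instead build a lookup from parameter name to the queries that use it. The sketch assumes, as the function above does, that some statements carry `parameters.name`; it takes already-parsed dicts, and `index_queries_by_field` plus the toy rows (including the `"filter"` statement type) are illustrative.

```python
from collections import defaultdict

def index_queries_by_field(parsed_queries):
    """Map field name -> list of questions whose JSON uses that field as a parameter.

    parsed_queries: iterable of (question, parsed_json_dict) pairs.
    """
    index = defaultdict(list)
    for question, data in parsed_queries:
        for statement in data.get("statements", []):
            # Not every statement has 'parameters' (e.g. {'type': 'technology', 'value': '3G'})
            name = statement.get("parameters", {}).get("name")
            if name:
                index[name].append(question)
    return index

# Toy rows for illustration only
parsed = [
    ("Find all calls made using 3G technology.", {"statements": [{"type": "technology", "value": "3G"}]}),
    ("Calls from this IMEI", {"statements": [{"type": "filter", "parameters": {"name": "ifc.ootb.CDR.imei"}}]}),
]
index = index_queries_by_field(parsed)
print(dict(index))  # {'ifc.ootb.CDR.imei': ['Calls from this IMEI']}
```

Because rows are visited in order, `index[name][0]` corresponds to the first match the original function would return.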


```python
# Inspect the first few training inputs and their labels (built during data preparation)
for i in range(3):
    print(f"Input {i+1}: {inputs[i]}")
    print(f"Label {i+1}: {labels[i]}")
```

```
Input 1: Query: Find all calls made using 3G technology.. Entity: CDR. Fields: ifc.ootb.CDR.callStatus: Status of the call: "Successful", "Failed", "Blocked", or "Redirected";ifc.CDR.caseCode: Unique code identifying a specific case;ifc.CDR.chatTopic: Topic or subject of discussion in the chat;ifc.ootb.CDR.createDateTime: Date and time of record creation.;ifc.ootb.CDR.direction: Direction of the call (incoming, outgoing);ifc.ootb.CDR.duration: Duration of the communication in minutes. You can ask it For example: 1min -> 60;ifc.CDR.emailSubject: Subject of the email communication;ifc.ootb.CDR.endTime: Time when the communication ended;ifc.ootb.CDR.hasContent: Indicates if the communication has content;ifc.ootb.CDR.imei: this field is intended to store the IMEI number of the device who made a call.;ifc.ootb.CDR.imei2: this field is intended to store the IMEI number of the device that is receiving the call. It serves the same purposes as the caller's IMEI but for the receiving side of the
```
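
The printed input contains `Entity: CDR.`, which is also its label, and its field list is selected by that same entity. If every training input carries its label this way, the classifier can read the answer off the input, which would be one likely explanation for the perfect validation scores further down. A quick check, assuming inputs follow the `Entity: <entity>.` pattern shown above; `label_leak_rate` and the second sample input are hypothetical.

```python
def label_leak_rate(inputs, labels):
    """Fraction of inputs that spell out their own entity as 'Entity: <entity>.'."""
    if not inputs:
        return 0.0
    leaks = 0
    for text, label in zip(inputs, labels):
        entity = label.partition("|")[0]  # 'CDR|Phone' -> 'CDR'
        if f"Entity: {entity}." in text:
            leaks += 1
    return leaks / len(inputs)

sample_inputs = [
    "Query: Find all calls made using 3G technology.. Entity: CDR. Fields: ...",
    "Query: Which SMS were rejected?. Fields: ...",  # toy input without the entity
]
print(label_leak_rate(sample_inputs, ["CDR", "Phone"]))  # 0.5
```

Running `label_leak_rate(inputs, labels)` on the real lists and getting a value near 1.0 would point toward dropping the entity from the input. Keep in mind that `infer` below only receives the raw query, so anything the model relies on must also be available at inference time.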

```python
# Train/validation split, stratified to keep label proportions similar in both sets
from sklearn.model_selection import train_test_split

train_inputs, val_inputs, train_labels, val_labels = train_test_split(
    inputs, labels, test_size=0.2, stratify=labels, random_state=42
)

print(f"Training size: {len(train_inputs)}, Validation size: {len(val_inputs)}")
```

```
Training size: 553, Validation size: 139
```
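
The query file has 742 unique questions in 744 rows (see `describe` above), so a duplicate can end up on both sides of the split and inflate validation scores. A quick overlap check in plain Python; `split_overlap` and the toy strings are illustrative.

```python
def split_overlap(train_texts, val_texts):
    """Validation texts that also occur verbatim in the training set."""
    train_set = set(train_texts)
    # Keep validation order and report each duplicate once
    return list(dict.fromkeys(t for t in val_texts if t in train_set))

# Toy strings for illustration only
train_texts = ["Query: A.. Entity: CDR.", "Query: B.. Entity: Phone."]
val_texts = ["Query: B.. Entity: Phone.", "Query: C.. Entity: Report."]
print(split_overlap(train_texts, val_texts))  # ['Query: B.. Entity: Phone.']
```

Calling `split_overlap(train_inputs, val_inputs)` on the real split shows whether this matters here. Deduplicating `(input, label)` pairs before `train_test_split` would rule it out.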


```python
from transformers import AlbertTokenizer
import torch

model_name = "twmkn9/albert-base-v2-squad2"
tokenizer = AlbertTokenizer.from_pretrained(model_name)
# padding=True pads to the longest input; truncation=True cuts inputs at the model's maximum length (512 here)
train_encodings = tokenizer(train_inputs, truncation=True, padding=True, return_tensors="pt")
val_encodings = tokenizer(val_inputs, truncation=True, padding=True, return_tensors="pt")

# Convert labels to tensors: the model needs numeric ids, not names.
# Deduplicate the labels and use each label's position as its id.
unique_labels = list(set(labels))
train_labels_tensor = torch.tensor([unique_labels.index(lbl) for lbl in train_labels])
val_labels_tensor = torch.tensor([unique_labels.index(lbl) for lbl in val_labels])
```
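
Building label ids from `list(set(labels))` ties them to set iteration order, which for strings can change between Python sessions because string hashing is randomized per process. That matters once a saved model is reloaded in a new session: the same id could then map to a different entity name. Sorting makes the order deterministic, and saving the mapping next to the model removes the need to recompute it. A sketch; the `label_map.json` file name and the helper names are hypothetical.

```python
import json
import os

def build_label_maps(labels):
    """Deterministic label <-> id maps (sorted, so the order never depends on hashing)."""
    id2label = dict(enumerate(sorted(set(labels))))
    label2id = {label: i for i, label in id2label.items()}
    return label2id, id2label

def save_label_maps(id2label, directory):
    # JSON object keys must be strings, so ids are written as "0", "1", ...
    with open(os.path.join(directory, "label_map.json"), "w") as f:
        json.dump({str(i): label for i, label in id2label.items()}, f)

def load_label_maps(directory):
    with open(os.path.join(directory, "label_map.json")) as f:
        id2label = {int(i): label for i, label in json.load(f).items()}
    return {label: i for i, label in id2label.items()}, id2label

label2id, id2label = build_label_maps(["Phone", "CDR", "Report", "CDR"])
print(id2label)  # {0: 'CDR', 1: 'Phone', 2: 'Report'}
```

The same dicts can also be passed as `id2label=` and `label2id=` to `AlbertForSequenceClassification.from_pretrained`, which stores them in the model config so they are saved along with the model.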


```python
# Do our dimensions match?
print(f"Training encodings: {train_encodings['input_ids'].shape}, Labels: {train_labels_tensor.shape}")
print(f"Validation encodings: {val_encodings['input_ids'].shape}, Labels: {val_labels_tensor.shape}")
```

```
Training encodings: torch.Size([553, 512]), Labels: torch.Size([553])
Validation encodings: torch.Size([139, 512]), Labels: torch.Size([139])
```


```python
from torch.utils.data import Dataset

class EntityDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings['input_ids'][idx],
            'attention_mask': self.encodings['attention_mask'][idx],
            'labels': self.labels[idx]
        }
```


```python
# Create dataset objects for training and validation
train_dataset = EntityDataset(train_encodings, train_labels_tensor)
val_dataset = EntityDataset(val_encodings, val_labels_tensor)

# Check the first sample from the training dataset
print(train_dataset[0])
```


```
{'input_ids': tensor([    2, 25597,    45,   298,    55,    65,  3029,    29,    21,  9403,
           16,   712,   902,     9,     9,  9252,    45,  1745,   139,     9,
         2861,    45,   100,   150,     9,  4328, 11872,     9,   150,  3807,
            9,  9200, 10631,   267,    45,  1782,    16,    14,   645,    45,
           13,     7, 29245,  1566,     7,    15,    13,     7, 24910,    69,
            7,    15,    13,     7, 12048,    69,     7,    15,    54,    13,
            7,    99, 14147,     7,    73,   821,   150,     9,   150,  3807,
            9, 10325,  9375,    45,  2619,  1797, 13785,    21,  1903,   610,
           73,   821,   150,     9,   150,  3807,     9, 13409,  3880,   596,
           45,  8303,    54,  1550,    16,  5460,    19,    14,  6615,    73,
          821,   150,     9,  4328, 11872,     9,   150,  3807,     9, 18475,
         1373,   891,    45,  1231,    17,    85,    16,   571,  2502,     9,
           73,   821,   150,     9,  4328, 11872, 
```

```python
# Prepare the results folder
import os
from datetime import datetime

# Create a unique output directory for this run
base_output_dir = "./results"
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(base_output_dir, f"run_{current_time}")

# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
```
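
Because `output_dir` gets a fresh timestamp on every run, the next cell's check for a saved `spiece.model` looks inside a folder that was just created and is still empty, so it will always train from scratch unless `output_dir` is pointed at an earlier run by hand. A sketch of picking the newest earlier run that actually contains a saved tokenizer; it assumes the `run_YYYYmmdd_HHMMSS` naming above, and `latest_saved_run` is a hypothetical helper.

```python
import os

def latest_saved_run(base_output_dir="./results", marker="spiece.model"):
    """Return the newest run_* directory that contains `marker`, or None."""
    if not os.path.isdir(base_output_dir):
        return None
    # run_YYYYmmdd_HHMMSS names sort chronologically as plain strings
    runs = sorted(
        (d for d in os.listdir(base_output_dir) if d.startswith("run_")),
        reverse=True,
    )
    for run in runs:
        path = os.path.join(base_output_dir, run)
        if os.path.isfile(os.path.join(path, marker)):
            return path
    return None

saved = latest_saved_run()
print(saved or "no saved run found; a new model will be trained")
```

For example, `output_dir = latest_saved_run(base_output_dir) or output_dir` falls back to the fresh folder when nothing has been saved yet. Nothing will be found while `should_save` stays `False` further down.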

```python
# Model and training run: reuse a locally saved model if one exists, otherwise train
from transformers import AlbertForSequenceClassification, AlbertTokenizer
from transformers import TrainingArguments, Trainer
import pickle
import torch
import os

# output_dir is a brand-new folder from the previous cell; to reuse a saved model,
# point it at an earlier run instead, e.g.:
# output_dir = "./results/run_20241014_163059"

training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    gradient_accumulation_steps=4,
    save_total_limit=2,
    load_best_model_at_end=True,
    weight_decay=0.01,
    no_cuda=True,  # CPU only; newer releases prefer use_cpu=True
)

# A saved tokenizer (spiece.model) in output_dir means training already ran there
model_is_saved = os.path.exists(os.path.join(output_dir, "spiece.model"))

if model_is_saved:
    print(f"Loading model from {output_dir}")
    model = AlbertForSequenceClassification.from_pretrained(output_dir)
    tokenizer = AlbertTokenizer.from_pretrained(output_dir)
    print(f"Model loaded from {output_dir}")
    print(f"Model: {model}")

    # Reload both datasets from the run folder they were saved to
    with open(os.path.join(output_dir, "train_dataset.pkl"), "rb") as f:
        train_dataset = pickle.load(f)
    with open(os.path.join(output_dir, "val_dataset.pkl"), "rb") as f:
        val_dataset = pickle.load(f)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
else:
    print('saving training dataset')
    # Save both datasets into the run folder so a later run can reload them
    with open(os.path.join(output_dir, "train_dataset.pkl"), "wb") as f:
        pickle.dump(train_dataset, f)
    with open(os.path.join(output_dir, "val_dataset.pkl"), "wb") as f:
        pickle.dump(val_dataset, f)

    print("Training new model")
    model = AlbertForSequenceClassification.from_pretrained(model_name, num_labels=len(unique_labels))

# Both branches need a Trainer (evaluation below uses it)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

if not model_is_saved:
    trainer.train()

print(model)
```

```
saving training dataset
Training new model

Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at twmkn9/albert-base-v2-squad2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 207/207 [48:01<00:00, 13.92s/it]

{'train_runtime': 2881.6198, 'train_samples_per_second': 0.576, 'train_steps_per_second': 0.072, 'train_loss': 0.4922458814538043, 'epoch': 2.99}
AlbertForSequenceClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertSdpaAttention(
                (query): Linear(in_features=768, out_features=768, bia
```
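
The 207 steps in the log follow from the settings above. With 553 training samples and a per-device batch size of 2, one epoch is 277 batches. Accumulating gradients over 4 batches gives 69 optimizer steps per epoch, and 3 epochs give 207. The reported `'epoch': 2.99` is 207 × 4 / 277. The helper below reproduces this arithmetic. It is a sketch that appears to match how the installed `transformers` counted steps here, not a copy of its internals, and it may help when switching to a pure step budget (`max_steps`) as planned in the to-do list.

```python
import math

def optimizer_steps(n_samples, batch_size, grad_accum, epochs):
    """Optimizer (update) steps for a run, plus the fraction of epochs actually seen."""
    batches_per_epoch = math.ceil(n_samples / batch_size)      # 553 / 2 -> 277
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)  # 277 // 4 -> 69
    total_steps = math.ceil(epochs * steps_per_epoch)          # 3 * 69 -> 207
    # Batches consumed by those steps, expressed in epochs (what the log reports as 'epoch')
    epochs_seen = total_steps * grad_accum / batches_per_epoch
    return total_steps, round(epochs_seen, 2)

print(optimizer_steps(553, 2, 4, 3))  # (207, 2.99)
```

Going the other way, `max_steps=207` would reproduce this run. If `eval_steps` and `save_steps` were left at their defaults (500 in recent releases, with `eval_steps` falling back to `logging_steps`), no evaluation or checkpoint happened during these 207 steps, so `load_best_model_at_end` had nothing to choose from. Setting both well below the step budget avoids that.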




```python
# Check the number of unique labels in the dataset
print(f"Unique Labels: {unique_labels}")
print(f"Number of Classes (num_labels): {len(unique_labels)}")
# Convert labels to integers from 0 to len(unique_labels) - 1
label_ids = torch.tensor([unique_labels.index(lbl) for lbl in labels])

# Verify label IDs are within range
print(f"Label IDs: {label_ids}")
print(f"Max Label ID: {label_ids.max()}, Expected: {len(unique_labels) - 1}")
```


```
Unique Labels: ['CDR', 'Investigation', 'Insight', 'Phone', 'Report', 'Web Actor', 'Person', 'EVisa Request', 'Web Activity']
Number of Classes (num_labels): 9
Label IDs: tensor([0, 8, 1, 2, 8, 3, 8, 2, 3, 4, 4, 0, 8, 0, 8, 6, 3, 4, 4, 1, 4, 6, 0, 1,
        5, 4, 6, 3, 0, 7, 0, 8, 5, 0, 6, 0, 6, 2, 2, 5, 2, 4, 3, 8, 6, 5, 0, 2,
        3, 1, 0, 1, 1, 0, 0, 4, 4, 2, 5, 2, 8, 3, 0, 7, 8, 1, 0, 4, 3, 0, 8, 0,
        0, 0, 5, 6, 0, 1, 1, 0, 0, 4, 8, 8, 8, 5, 8, 5, 0, 5, 3, 8, 3, 4, 0, 1,
        1, 0, 3, 2, 8, 8, 0, 1, 6, 3, 8, 8, 4, 0, 0, 7, 5, 8, 7, 0, 4, 1, 6, 8,
        3, 5, 0, 3, 5, 6, 8, 8, 3, 2, 8, 2, 6, 8, 6, 8, 8, 0, 0, 8, 0, 0, 5, 5,
        5, 8, 4, 6, 4, 3, 0, 0, 4, 8, 8, 0, 5, 0, 8, 1, 8, 6, 0, 5, 5, 2, 0, 0,
        8, 0, 6, 4, 8, 8, 0, 0, 3, 1, 1, 8, 3, 5, 5, 1, 3, 1, 5, 2, 0, 1, 0, 0,
        4, 7, 3, 4, 0, 8, 2, 0, 3, 3, 0, 0, 0, 0, 8, 0, 4, 8, 6, 8, 4, 5, 6, 2,
        8, 1, 8, 0, 4, 2, 5, 6, 8, 1, 1, 6, 4, 8, 7, 0, 1, 0, 0, 0, 0, 5, 5, 4,
        3, 0, 3, 5, 0, 6, 0, 
```

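```python
# Saving is off by default; set should_save = True to persist the trained model.
# Without saving, a later run has no model to reload.
should_save = False
if should_save:
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model and tokenizer saved to {output_dir}")
else:
    print("should_save is False; skipping save")
```

```
should_save is False; skipping save
```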


```python
# Evaluation
import numpy as np
from sklearn.metrics import classification_report, f1_score

def evaluate_model(trainer, val_dataset, unique_labels):
    # Get predictions (logits) together with the true label ids
    preds = trainer.predict(val_dataset)

    # Convert logits to predicted class ids
    preds_labels = np.argmax(preds.predictions, axis=1)

    # True labels come back alongside the predictions
    true_labels = preds.label_ids

    # Compute the weighted F1 score
    f1 = f1_score(true_labels, preds_labels, average='weighted')

    # Print the report; `labels` keeps target_names aligned even if a class is missing from val
    print(f"weighted f1 score: {f1}")
    print("Classification Report:\n")
    print(classification_report(
        true_labels,
        preds_labels,
        labels=list(range(len(unique_labels))),
        target_names=unique_labels,
        zero_division=0,
    ))
```


```python
evaluate_model(trainer, val_dataset, unique_labels)
```

```
77it [01:44,  1.36s/it]

weighted f1 score: 1.0
Classification Report:

               precision    recall  f1-score   support

          CDR       1.00      1.00      1.00        37
Investigation       1.00      1.00      1.00        15
      Insight       1.00      1.00      1.00        11
        Phone       1.00      1.00      1.00        16
       Report       1.00      1.00      1.00        14
    Web Actor       1.00      1.00      1.00        14
       Person       1.00      1.00      1.00        10
EVisa Request       1.00      1.00      1.00         2
 Web Activity       1.00      1.00      1.00        20

     accuracy                           1.00       139
    macro avg       1.00      1.00      1.00       139
 weighted avg       1.00      1.00      1.00       139
```
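
A perfect score on 139 validation samples is worth comparing against a trivial baseline. Using the support column above, always predicting the most frequent class (CDR, 37 samples) would reach about 0.27 accuracy. Given that the inputs also spell out the entity name, 1.00 is probably more optimistic than what unseen, raw queries would get. The arithmetic as a small helper (`majority_baseline` is hypothetical):

```python
def majority_baseline(support):
    """Accuracy of always predicting the most frequent class, given per-class counts."""
    total = sum(support.values())
    label, count = max(support.items(), key=lambda kv: kv[1])
    return label, count / total

# Per-class support from the classification report above
support = {
    "CDR": 37, "Investigation": 15, "Insight": 11, "Phone": 16, "Report": 14,
    "Web Actor": 14, "Person": 10, "EVisa Request": 2, "Web Activity": 20,
}
label, acc = majority_baseline(support)
print(label, round(acc, 2))  # CDR 0.27
```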






```python
# Inference on the trained model
import torch

def infer(model, tokenizer, query, unique_labels):
    # Tokenize the input query
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)

    # Run the model in evaluation mode without tracking gradients
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the predicted label ID
    predicted_label_id = torch.argmax(outputs.logits, dim=1).item()

    # Convert the label ID back to the original label name
    predicted_label = unique_labels[predicted_label_id]

    return predicted_label
```
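
Two things to keep in mind about `infer`. First, it passes the bare query, while the training inputs had the form `Query: ... Entity: ... Fields: ...`, so the model sees a differently shaped input at inference time. Second, `argmax` hides how confident the model is. For the open question of what a good output looks like, returning the top few labels with softmax probabilities may be more informative. A stdlib sketch working on a plain list of logits (for example `outputs.logits[0].tolist()`); `top_k_labels` is a hypothetical helper.

```python
import math

def top_k_labels(logits, labels, k=3):
    """Return the k most likely (label, probability) pairs from raw logits."""
    # Numerically stable softmax: subtract the max logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(labels, probs), key=lambda lp: lp[1], reverse=True)
    return ranked[:k]

for label, prob in top_k_labels([2.0, 0.5, 1.0], ["CDR", "Phone", "Report"], k=2):
    print(f"{label}: {prob:.2f}")
```

A low top probability, or two labels close together, would be a natural signal to ask the user for clarification rather than commit to one entity.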


```python
# Test the inference function with a sample query
sample_query = "Which SMS were rejected?"
predicted_entity = infer(model, tokenizer, sample_query, unique_labels)

print(f"Predicted Entity for Query: {sample_query}")
print(f"Predicted Entity Type: {predicted_entity}")
```

```
Predicted Entity for Query: Which SMS were rejected?
Predicted Entity Type: Phone
```
