In [1]:
!pip install transformers torch pandas
!pip install sentencepiece
!pip install scikit-learn
# needed for training
! pip install -U accelerate


Defaulting to user installation because normal site-packages is not writeable


In [None]:
# load data and create mapping into new dataframe
# right now we are just using the user query 
### TODO files are static paths now.  we need to make them dynamic and maybe add a nice UI to select the file
import pandas as pd

fields_desc = pd.read_csv('fields_description.csv')
user_queries = pd.read_csv('user_queries.csv')

print(fields_desc.head())
print(user_queries.head())

# Create a dictionary mapping entity names to their field descriptions and properties
# This groups the data by entity_name and creates a nested dictionary structure for easy access to field information for each entity type
entity_to_field_mapping = fields_desc.groupby('entity_name').apply(lambda x: x[['field_name', 'description']].to_dict(orient='records')).to_dict()

print('test sample mapping')
print(entity_to_field_mapping.get('Phone', []))


  entity_name                   field_name field_type  \
0         CDR      ifc.ootb.CDR.callStatus     string   
1         CDR             ifc.CDR.caseCode     string   
2         CDR            ifc.CDR.chatTopic     string   
3         CDR  ifc.ootb.CDR.createDateTime       date   
4         CDR       ifc.ootb.CDR.direction     string   

                                         description  
0  Status of the call: "Successful", "Failed", "B...  
1            Unique code identifying a specific case  
2         Topic or subject of discussion in the chat  
3                  Date and time of record creation.  
4         Direction of the call (incoming, outgoing)  
                                            question  \
0           Find all calls made using 3G technology.   
1  List all Reddit comments posted yesterday with...   
2  Show me investigations that are either open or...   
3  Find all insights related to the witness Jane ...   
4  List all web activities updated in the last 

  entity_to_field_mapping = fields_desc.groupby('entity_name').apply(lambda x: x[['field_name', 'description']].to_dict(orient='records')).to_dict()


In [None]:
# prepare our data for training. we combine our user query with field description
import json
import re

def clean_json_string(json_string):
    # Remove any leading/trailing whitespace
    json_string = json_string.strip()
    
    # Ensure the string is enclosed in curly braces
    if not json_string.startswith('{'):
        json_string = '{' + json_string
    if not json_string.endswith('}'):
        json_string = json_string + '}'
    
    # Replace single quotes with double quotes, but not within values
    json_string = re.sub(r"(?<!\\)'", '"', json_string)
    
    # Remove any trailing commas before closing braces or brackets
    json_string = re.sub(r',\s*([\]}])', r'\1', json_string)
    
    return json_string

def prepare_data_for_training(user_query, fields_mapping):
    inputs, labels = [], []

    for _, row in user_queries.iterrows():
        query = row['question']
        cleaned_json_string = clean_json_string(row['json'])
        
        try:
            json_data = json.loads(cleaned_json_string)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON for query: {query}")
            print(f"Error: {e}")
            continue  # Skip this row and continue with the next one
        
        # extract entity types and relation target types
        entity_type = json_data.get('entityType', '')
        relation_type = json_data.get('relationTargetType', '')

        # get the description for each entity type
        fields = fields_mapping.get(entity_type, [])
        field_descriptions = ';'.join([f"{field['field_name']}: {field['description']}" for field in fields])

        #combine query with descriptions
        input_text = f"Query: {query}. Entity: {entity_type}. Fields: {field_descriptions}"
        inputs.append(input_text)
        labels.append(entity_type if not relation_type else f"{entity_type}|{relation_type}")

    return inputs, labels

# prepare data
inputs, labels = prepare_data_for_training(user_queries, entity_to_field_mapping)

Error decoding JSON for query: Which phones have been marked as suspicious?
Error: Expecting value: line 1 column 150 (char 149)
Error decoding JSON for query: What failed call attempts were made from target phones to numbers containing '1234'?
Error: Expecting value: line 1 column 228 (char 227)
Error decoding JSON for query: Which phones are set on Arabic and are marked as suspicious?
Error: Expecting value: line 1 column 280 (char 279)
Error decoding JSON for query: List emails sent to phones associated with the target Sarah Johnson
Error: Expecting value: line 1 column 337 (char 336)
Error decoding JSON for query: List all emails from john@company.com to jane@company.com with attachments in February 2024
Error: Expecting value: line 1 column 646 (char 645)
Error decoding JSON for query: Show me any insights related to the interview with the victim's family yesterday.
Error: Expecting ',' delimiter: line 1 column 152 (char 151)
Error decoding JSON for query: Find all emails about 'm

In [None]:
for i in range(3):
    print(f"Input {i+1}: {inputs[i]}")
    print(f"Label {i+1}: {labels[i]}")

Input 1: Query: Find all calls made using 3G technology.. Entity: CDR. Fields: ifc.ootb.CDR.callStatus: Status of the call: "Successful", "Failed", "Blocked", or "Redirected";ifc.CDR.caseCode: Unique code identifying a specific case;ifc.CDR.chatTopic: Topic or subject of discussion in the chat;ifc.ootb.CDR.createDateTime: Date and time of record creation.;ifc.ootb.CDR.direction: Direction of the call (incoming, outgoing);ifc.ootb.CDR.duration: Duration of the communication in minutes. You can ask it For example: 1min -> 60;ifc.CDR.emailSubject: Subject of the email communication;ifc.ootb.CDR.endTime: Time when the communication ended;ifc.ootb.CDR.hasContent: Indicates if the communication has content;ifc.ootb.CDR.imei: this field is intended to store the IMEI number of the device who made a call.;ifc.ootb.CDR.imei2: this field is intended to store the IMEI number of the device that is receiving the call. It serves the same purposes as the caller's IMEI but for the receiving side of the

In [None]:
# train test split
from sklearn.model_selection import train_test_split

train_inputs, val_inputs, train_labels, val_labels = train_test_split(
    inputs, labels, test_size=0.2, stratify=labels, random_state=42
)

print(f"Training size: {len(train_inputs)}, Validation size: {len(val_inputs)}")

Training size: 553, Validation size: 139


In [None]:
from transformers import AlbertTokenizer
import torch

model_name = "twmkn9/albert-base-v2-squad2"
tokenizer = AlbertTokenizer.from_pretrained(model_name)
train_encodings = tokenizer(train_inputs, truncation=True, padding=True, return_tensors="pt")
val_encodings = tokenizer(val_inputs, truncation=True, padding=True, return_tensors="pt")

# labels to tensors. we can't have names only machine values
# remove duplicates and iterate through to assign a number
unique_labels = list(set(labels))
train_labels_tensor = torch.tensor([unique_labels.index(lbl) for lbl in train_labels])
val_labels_tensor = torch.tensor([unique_labels.index(lbl) for lbl in val_labels])


In [None]:
# do our dimensions match?
print(f"Training encodings: {train_encodings['input_ids'].shape}, Labels: {train_labels_tensor.shape}")
print(f"Validation encodings: {val_encodings['input_ids'].shape}, Labels: {val_labels_tensor.shape}")

Training encodings: torch.Size([553, 512]), Labels: torch.Size([553])
Validation encodings: torch.Size([139, 512]), Labels: torch.Size([139])


In [None]:
# prepare results folder
import os
from datetime import datetime

# Create a unique output directory
base_output_dir = "./results"
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(base_output_dir, f"run_{current_time}")

# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

In [None]:
# model and training run
from transformers import TrainingArguments, Trainer
from transformers import AlbertForSequenceClassification
import torch

model = AlbertForSequenceClassification.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset={
        'input_ids': train_encodings['input_ids'],
        'attention_mask': train_encodings['attention_mask'],
        'labels': train_labels_tensor['input_ids']    
    },
    eval_dataset={
        'input_ids': val_encodings['input_ids'],
        'attention_mask': val_encodings['attention_mask'],
        'labels': val_labels_tensor['input_ids']
    },
    tokenizer=tokenizer,
)

Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at twmkn9/albert-base-v2-squad2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


IndexError: too many indices for tensor of dimension 1

In [None]:
print(model)

AlbertForQuestionAnswering(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertSdpaAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, 

In [25]:
# running the big boy
trainer.train()

AttributeError: `AcceleratorState` object has no attribute `distributed_type`. This happens if `AcceleratorState._reset_state()` was called and an `Accelerator` or `PartialState` was not reinitialized.

In [5]:
# Sample JSON-like data (you'll replace this with your CSV data)
json_data = [
    {"entityType": "CDR", "relationTargetType": "Phone"},
    {"entityType": "Report", "relationTargetType": "Malware"}
]

# Example query from the user
query = "What SMS messages were sent from suspicious phones to 0549876543 containing 'urgent'?"


In [12]:
#Function to search for relevant entitties in teh JSON data
def find_matching_entities(query, json_data):
    matching_entities = []

    for record in json_data:
        entity_text = f"{record['entityType']} {record['relationTargetType']}"

        #encode inputs for model
        inputs = tokenizer(query, entity_text, return_tensors="pt")

        #run the model to get answer scores
        with torch.no_grad():
            outputs = model(**inputs)

        #get the start and end scores for the answer
        answer_start = torch.argmax(outputs.start_logits)
        answer_end = torch.argmax(outputs.end_logits) + 1

        #extract the answer
        predicted_entity = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
        )

        #if predicted entity is not empty, consider it a match
        if predicted_entity.strip():
            return record['entityType'], record['relationTargetType']

    return None
    # return list(set(matching_entities)) #remove duplicates

In [14]:
#Test example
matching_entities = list(set(find_matching_entities(query, json_data)))
print(matching_entities)

['Phone', 'CDR']
