In [1]:
import pandas as pd

# Load the CSV file
input_file = "legal_aid_queries.csv"  # Replace with the path to your file
output_file = "legal_aid_queries_processed.csv"

# Read the file, skipping the first two lines
with open(input_file, 'r') as file:
    lines = file.readlines()[2:]  # Skip the first two lines

# Split each line into category and query
data = []
for line in lines:
    if line.strip():  # Skip empty lines
        split_idx = line.find(']')  # Find the closing bracket
        if split_idx != -1:
            category = line[1:split_idx].strip()  # Extract text within brackets
            query = line[split_idx + 1:].strip()  # Extract remaining text
            data.append({'category': category, 'query': query})

# Create a DataFrame
df = pd.DataFrame(data)

# Save to a new CSV file
df.to_csv(output_file, index=False)



In [2]:
df = pd.read_csv("legal_aid_queries_processed.csv")

In [3]:
df.head()

Unnamed: 0,category,query
0,Negative,How do I apply for a driver's license in Illin...
1,Negative,What's the best time to visit Navy Pier?
2,Misspelled,Can I get unemployement benefits in Illinois?
3,Misspellings,How do I get a paternit test?
4,Legal-Non-Legal,How do I evict a roomate who's not paying rent?


In [8]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch

# Load the dataset
df = pd.read_csv("legal_aid_queries_processed.csv")

# Filter categories with at least 2 examples
category_counts = df['category'].value_counts()
valid_categories = category_counts[category_counts >= 2].index
df = df[df['category'].isin(valid_categories)]

# Map categories to numerical labels
categories = sorted(df['category'].unique())  # Ensure consistent ordering
category_to_id = {cat: idx for idx, cat in enumerate(categories)}
id_to_category = {idx: cat for cat, idx in category_to_id.items()}
df['label'] = df['category'].map(category_to_id)

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    df['query'], df['label'], test_size=0.2, random_state=42
)

# Tokenize using DistilBERT tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

class QueryDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
        self.encodings = tokenizer(list(texts), truncation=True, padding=True, max_length=128)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

# Create datasets
train_dataset = QueryDataset(X_train, y_train.tolist())
test_dataset = QueryDataset(X_test, y_test.tolist())

# Load the DistilBERT model with classification head
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', num_labels=len(categories)
)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=2,
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()

# Evaluate the model
metrics = trainer.evaluate()
print("Evaluation Metrics:", metrics)

# Save the trained model
model_dir = "./distilbert_model"
trainer.save_model(model_dir)
print(f"Model saved to {model_dir}")

# Inference Example
def predict_category(query):
    encoding = tokenizer(query, return_tensors='pt', truncation=True, padding=True, max_length=128)
    outputs = model(**encoding)
    pred_label = torch.argmax(outputs.logits, dim=1).item()
    return id_to_category[pred_label]

# Test the prediction function
example_query = "How do I file for divorce?"
predicted_category = predict_category(example_query)
print(f"Predicted Category: {predicted_category}")





tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,1.7672,1.635587
2,0.9454,0.999612
3,0.7595,0.809981


Evaluation Metrics: {'eval_loss': 0.8099806904792786, 'eval_runtime': 9.4034, 'eval_samples_per_second': 64.977, 'eval_steps_per_second': 1.063, 'epoch': 3.0}
Model saved to ./distilbert_model
Predicted Category: Simple


In [13]:
def predict_category(query):
    encoding = tokenizer(query, return_tensors='pt', truncation=True, padding=True, max_length=128)
    outputs = model(**encoding)
    pred_label = torch.argmax(outputs.logits, dim=1).item()
    return id_to_category[pred_label]

# Test the prediction function
example_query = "What are my rights and my dog ran away"
predicted_category = predict_category(example_query)
print(f"Predicted Category: {predicted_category}")

Predicted Category: Mixed
