In [8]:
import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import DataLoader, TensorDataset
import torch

def preprocess_for_bert(texts, labels=None, max_length=128, batch_size=16,
                        model_name="bert-base-uncased"):
    """Tokenize texts with a BERT tokenizer and wrap the result in a DataLoader.

    Args:
        texts: The text(s) to encode; passed straight to the tokenizer.
        labels: Optional labels, one per text. When given, they are converted
            with ``torch.tensor`` and added as a third tensor in the dataset.
        max_length: Length every sequence is padded or truncated to.
        batch_size: Batch size of the returned DataLoader.
        model_name: Name or path of the pretrained tokenizer to load.

    Returns:
        A 4-tuple ``(dataloader, input_ids, attention_masks, labels)``.
        ``labels`` is the argument exactly as it was passed in (``None`` when
        no labels were supplied), not the tensor stored in the dataset.

    Raises:
        ValueError: If ``labels`` does not hold exactly one entry per text.
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)

    # padding="max_length" plus truncation gives every row the same length,
    # so the tokenizer can return a single stacked tensor per field.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    input_ids = encoded["input_ids"]
    attention_masks = encoded["attention_mask"]

    if labels is None:
        dataset = TensorDataset(input_ids, attention_masks)
    else:
        labels_tensor = torch.tensor(labels)
        # Check explicitly instead of relying on TensorDataset's internal
        # assert, which is skipped when Python runs with -O.
        num_texts = input_ids.size(0)
        if labels_tensor.dim() == 0 or labels_tensor.size(0) != num_texts:
            raise ValueError(
                f"labels must contain one entry per text: got "
                f"{tuple(labels_tensor.shape)} labels for {num_texts} texts"
            )
        dataset = TensorDataset(input_ids, attention_masks, labels_tensor)

    # shuffle=False keeps batches in the same order as the input texts.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader, input_ids, attention_masks, labels

# 1. Load data from CSV
file_path = "most_frequent_severity_with_def_text.csv"
data = pd.read_csv(file_path)

# 2. Extract columns
texts = data["def_text"].tolist()
labels = data["Most_Frequent_Severity"].tolist()

# 3. Label-encode the string labels (Low -> 0, Medium -> 1, High -> 2).
#    Any value outside unique_labels raises KeyError here.
unique_labels = ['Low', 'Medium', 'High']
label2id = {label: idx for idx, label in enumerate(unique_labels)}
labels = [label2id[label] for label in labels]

# 4. Preprocess data
dataloader, input_ids, attention_masks, labels = preprocess_for_bert(texts, labels)

# 5. Save processed data to a CSV file.
#    Tensor.tolist() yields nested lists of plain Python ints. Going through
#    list(tensor.numpy()) would leave NumPy scalars in each cell, which are
#    written as "np.int64(...)" on NumPy >= 2 and make the CSV hard to parse.
processed_data = {
    "input_ids": input_ids.tolist(),
    "attention_masks": attention_masks.tolist(),
    "labels": labels,
}

# Convert to DataFrame: one row per text, each token column holding a list
processed_df = pd.DataFrame(processed_data)

# Save to CSV
processed_file_path = "processed_severity_data.csv"
processed_df.to_csv(processed_file_path, index=False)

print(f"Processed data saved to {processed_file_path}")

Processed data saved to processed_severity_data.csv
