## Sequence Classification for anonymization

This notebook is an experiment to classify key-value fields from JSON as being anonymized or not for PII data.

The first step is to choose the model checkpoint and load its tokenizer, which we will use to tokenize our training data.

In [None]:
from transformers import AutoTokenizer

# Checkpoint name shared by the tokenizer here and the model loaded later on.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The Llama 3 tokenizer does not handle padding the way other models do, so
# reuse the end-of-sequence token (and its id) as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

## Import tab_exp module

We need the tab_exp module to generate the synthetic test data

In [None]:
from datasets import load_dataset

import sys
from pathlib import Path
# Put the parent of the current working directory on the import path;
# presumably that is where the tab_exp package lives — confirm.
sys.path.append(str(Path.cwd().parent))
from tab_exp.tab import generate_synth_data, PIIData

## Create dataset

We will use the generate_synth_data to create datasets for training, testing, and validation

In [None]:
from datasets import DatasetDict

# generate_synth_data returns a mapping whose "combined_path" entry is the
# JSON file handed to load_dataset.
#
# Every split is loaded with split="train": load_dataset exposes a bare
# data_files path under the default "train" split, and omitting `split`
# returns a DatasetDict rather than a Dataset. Without it, the "test" and
# "validate" entries below would be nested DatasetDicts, unlike "train".
sd_train = generate_synth_data(samples=1000, output="samples_train", clean=True)
ds_train = load_dataset("json", data_files=sd_train["combined_path"], split="train")

sd_test = generate_synth_data(samples=300, output="samples_test", clean=True)
ds_test = load_dataset("json", data_files=sd_test["combined_path"], split="train")

sd_validate = generate_synth_data(samples=100, output="samples_validate", clean=True)
ds_validate = load_dataset("json", data_files=sd_validate["combined_path"], split="train")

# Bundle the three splits under the names used by the rest of the notebook.
dataset = DatasetDict({
    "train": ds_train,
    "validate": ds_validate,
    "test": ds_test,
})

## Quantize for efficiency

For experimentation on a small computer, we need to quantize the weights to make them smaller to trade accuracy for
computational speed.  Without this, we either will not be able to fit the model weights at all into memory, or it will
take forever to finish

In [None]:
from transformers import BitsAndBytesConfig, AutoModelForSequenceClassification
import torch

# Fail fast when no CUDA GPU is present; the rest of the notebook needs one.
if not torch.cuda.is_available():
    raise RuntimeError("GPU must be available for training")

# Load the weights in 4-bit NF4 with double quantization, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Create a model for text classification. Llama 3 is normally used for causal
# language modeling (text generation); this adds a classification head instead.
# NOTE(review): the notebook intro describes a binary anonymized / not
# anonymized decision, yet four labels are configured here — confirm against
# the labels produced by generate_synth_data.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    num_labels=4,
    device_map='auto',
)

## Use LoRA to train only a subset of the weights

Fine-tuning all the weights of the checkpoint would be prohibitively expensive, so we will use LoRA to train only a small subset of weights.

In [None]:
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# LoRA adapter settings for sequence classification: rank-16 adapters attached
# to the four attention projection layers, with no bias terms trained.
lora_config = LoraConfig(
    task_type='SEQ_CLS',
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    bias='none',
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
)

# Prepare the quantized model for k-bit training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

In [None]:
# Tell the model which token id is padding, matching the tokenizer's pad token id.
model.config.pad_token_id = tokenizer.pad_token_id
# Turn off the model's cache setting.
model.config.use_cache = False
# NOTE(review): pretraining_tp is presumably the Llama tensor-parallelism
# setting, with 1 selecting the standard (non-sliced) linear path — confirm.
model.config.pretraining_tp = 1

## Tokenize the dataset

We need to convert the natural-language text in the dataset into the token IDs that the LLM takes as input.

In [None]:
from typing import cast
from datasets import Dataset
from transformers import DataCollatorWithPadding

# Function applied to the training data: it tokenizes the raw text for training.
def tokenize_fn(data: PIIData):
    """Tokenize the ``text`` column of a batch of examples.

    Sequences are truncated but deliberately left unpadded. Padding here would
    stretch every example to the longest one in the whole ``map`` batch; the
    ``DataCollatorWithPadding`` created below pads each training batch to its
    own longest sequence instead.

    Args:
        data: A batch from ``Dataset.map(..., batched=True)``; only its
            ``'text'`` entry is read.

    Returns:
        The tokenizer output for the batch.
    """
    return tokenizer(data['text'], truncation=True)

# Replace the raw "text" column with the tokenizer output and return torch tensors.
tokenized_train_ds: Dataset = cast(Dataset, ds_train.map(tokenize_fn, batched=True, remove_columns=["text"]))
tokenized_train_ds.set_format("torch")
# Pad each batch of inputs to the maximum input length in that batch.
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer)


In [None]:
import json
from itertools import islice

# Smoke test: run the classifier on a small sample of the training text.
sample_size = 128
with open(sd_train["combined_path"], "r") as cf:
    # islice stops reading after sample_size lines, so the rest of the file is
    # never loaded or parsed.
    text: list[str] = [json.loads(line)["text"] for line in islice(cf, sample_size)]

accel = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 16

# One logits tensor per batch, in input order.
all_outputs = []

for i in range(0, len(text), batch_size):
    batched_inputs = text[i:i + batch_size]

    # Pad within the batch so the texts stack into one tensor, then move the
    # tensors to the selected device.
    inputs = tokenizer(batched_inputs, truncation=True, padding=True, return_tensors="pt", max_length=512)
    inputs = {k: v.to(accel) for k, v in inputs.items()}

    # Inference only, so gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)
    all_outputs.append(outputs['logits'])
    # Print only the newest batch; printing the whole list on every iteration
    # repeats all earlier batches each time.
    print(all_outputs[-1])

In [None]:
# Display the collected per-batch logits as this cell's output.
all_outputs

## Evaluate performance

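We need a way to evaluate the performance.  By default, the Hugging Face `Trainer` reports only the loss during evaluation, so we define a `compute_metrics` function that also reports accuracy and balanced accuracy.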

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import balanced_accuracy_score, classification_report

def get_metrics_result(test_df):
    """Print a classification report and accuracy scores for a set of predictions.

    Args:
        test_df: Object exposing ``label`` (true classes) and ``predictions``
            (predicted classes) attributes, e.g. a DataFrame with those columns.
    """
    actual, predicted = test_df.label, test_df.predictions

    print("Classification Report:")
    print(classification_report(actual, predicted))

    balanced = balanced_accuracy_score(actual, predicted)
    print("Balanced Accuracy Score:", balanced)
    plain = accuracy_score(actual, predicted)
    print("Accuracy Score:", plain)

def compute_metrics(evaluations):
    """Compute accuracy metrics from raw evaluation outputs.

    Args:
        evaluations: A ``(predictions, labels)`` pair, where ``predictions``
            holds per-class scores with classes along axis 1 and ``labels``
            holds the true class ids.

    Returns:
        A dict with ``'balanced_accuracy'`` and ``'accuracy'`` entries.
    """
    predictions, labels = evaluations
    # Turn per-class scores into predicted class ids.
    predictions = np.argmax(predictions, axis=1)
    # sklearn metrics take (y_true, y_pred). Balanced accuracy averages recall
    # per *true* class, so swapping the arguments gives a different number.
    return {
        'balanced_accuracy': balanced_accuracy_score(labels, predictions),
        'accuracy': accuracy_score(labels, predictions),
    }

## Create a custom trainer

We will create a custom trainer whose loss is a cross-entropy that can optionally be weighted per class.

In [None]:
from transformers import Trainer
import torch.nn.functional as F

class CustomTrainer(Trainer):
    """Trainer whose loss is cross-entropy with optional per-class weights.

    Args:
        *args: Positional arguments forwarded unchanged to ``Trainer``.
        class_weights: Optional per-class weights (sequence, array or tensor)
            passed as ``weight`` to the cross-entropy loss. ``None`` means an
            unweighted loss.
        **kwargs: Keyword arguments forwarded unchanged to ``Trainer``.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            # as_tensor also accepts an existing tensor without the
            # copy-construct warning that torch.tensor() emits.
            self.class_weights = torch.as_tensor(class_weights, dtype=torch.float32).to(self.args.device)
        else:
            self.class_weights = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (optionally class-weighted) cross-entropy loss for a batch.

        Args:
            model: The model being trained.
            inputs: Batch mapping; its ``"labels"`` entry is removed and the
                remaining entries are passed to the model as keyword arguments.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted because newer Transformers versions
                pass it to ``compute_loss``; it is not used here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # Labels are removed so the model does not compute its own loss.
        labels = inputs.pop("labels").long()
        outputs = model(**inputs)
        logits = outputs.get('logits')

        weight = self.class_weights
        if weight is not None:
            # Under device_map='auto' the logits may sit on another device or
            # use another dtype than the stored weights; match them so
            # cross_entropy does not fail on a mismatch.
            weight = weight.to(device=logits.device, dtype=logits.dtype)

        loss = F.cross_entropy(logits, labels, weight=weight)

        return (loss, outputs) if return_outputs else loss