# In [None]:
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score, f1_score, hamming_loss, matthews_corrcoef, cohen_kappa_score
from transformers import EvalPrediction, TrainingArguments, Trainer
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig

# Load and preprocess data
df = pd.read_csv('drive/MyDrive/preprocessed.csv')
rng_seed = 100
# Deterministic shuffle of all rows (frac=1 keeps every row, only reorders).
df_randomized_order = df.sample(frac=1, random_state=rng_seed)
# Keep only rows with a usable label string. Rows whose section_code is "-"
# (presumably a "no section" placeholder — confirm against the data source)
# are excluded, and so are missing values: a NaN would survive the split below
# as a float, which MultiLabelBinarizer cannot iterate.
# .copy() makes the filtered frame independent, so assigning the column below
# does not raise pandas' SettingWithCopyWarning.
has_code = (
    df_randomized_order["section_code"].notna()
    & (df_randomized_order["section_code"] != "-")
)
df_randomized_order = df_randomized_order[has_code].copy()
# Comma-separated codes -> list of codes per row.
df_randomized_order['section_code'] = df_randomized_order['section_code'].str.split(',')

# Multi-label binarization: one float32 0/1 indicator column per distinct code.
multilabel = MultiLabelBinarizer()
labels = multilabel.fit_transform(df_randomized_order['section_code']).astype('float32')
texts = df_randomized_order['abstracted_heading_plus_content'].tolist()

# Balancing the dataset
def balance_dataset(texts, labels, random_state=0):
    """Oversample so that every label column contributes exactly ``len(texts)`` rows.

    For each label column, the rows where that label is 1 (its "bucket") are
    repeated as many whole times as fit into ``len(texts)``. The remainder is
    filled by drawing bucket rows with replacement. A row that carries several
    labels is emitted once per bucket it belongs to. Buckets are concatenated
    in label-column order.

    Args:
        texts: Sequence of inputs indexable by integer position (e.g. a list).
        labels: 2-D array of 0/1 indicators with one row per text.
        random_state: Integer seed for the with-replacement draw. A fresh
            generator with this seed is created for every label column, which
            is the same draw ``sklearn.utils.resample(..., random_state=0)``
            made for each class.

    Returns:
        ``(balanced_texts, balanced_labels)``: a list and a 2-D array with the
        same number of rows, where row ``k`` of the array is the label row of
        ``balanced_texts[k]``. Label columns with no positive rows are skipped.
        If nothing is produced, the array has shape ``(0, n_labels)``.
    """
    labels = np.asarray(labels)
    total_instances = len(texts)
    n_labels = labels.shape[1]
    balanced_texts = []
    balanced_labels = []

    for col in range(n_labels):
        class_indices = np.flatnonzero(labels[:, col] == 1)
        n_pos = len(class_indices)
        if n_pos == 0:
            # Nothing to replicate, and the ratio below would divide by zero.
            continue
        class_texts = [texts[j] for j in class_indices]
        class_labels = labels[class_indices]

        # Whole copies of the bucket that fit into total_instances...
        num_dup, num_add = divmod(total_instances, n_pos)
        # ...plus num_add extra rows drawn with replacement. A single shared
        # index draw keeps the extra texts aligned with their label rows, and
        # size=0 simply yields no extra rows.
        extra = np.random.RandomState(random_state).randint(0, n_pos, size=num_add)

        balanced_texts.extend(class_texts * num_dup)
        balanced_texts.extend(class_texts[k] for k in extra)
        balanced_labels.append(np.tile(class_labels, (num_dup, 1)))
        balanced_labels.append(class_labels[extra])

    if not balanced_labels:
        return balanced_texts, np.empty((0, n_labels), dtype=labels.dtype)
    return balanced_texts, np.vstack(balanced_labels)

# Train/validation split on the ORIGINAL rows. Oversampling before splitting
# would put copies of the same text into both sets, so validation metrics
# would partly measure memorisation. Only the training split is balanced;
# the validation split keeps the real label distribution.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.3, random_state=42
)
balanced_texts, balanced_labels = balance_dataset(train_texts, train_labels)
train_texts, train_labels = balanced_texts, balanced_labels

# Load tokenizer and model
checkpoint = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
# Each example may carry several labels (float 0/1 indicator rows), so request
# the multi-label head explicitly rather than relying on dtype-based inference.
model = DistilBertForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=labels.shape[1],
    problem_type="multi_label_classification",
)

# Custom dataset
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one text each time an item is read.

    Args:
        texts: Sequence of inputs; each is passed through ``str()`` before
            tokenizing.
        labels: Indexable collection of label rows aligned with ``texts``;
            row ``idx`` is wrapped with ``torch.tensor``.
        tokenizer: Callable tokenizer invoked with truncation and
            ``padding="max_length"``.
        max_len: Length every encoding is truncated/padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = str(self.texts[idx])
        target = torch.tensor(self.labels[idx])
        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors='pt',
        )
        # return_tensors='pt' gives batch-shaped tensors even for one string;
        # flatten() turns each into a 1-D sequence.
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = target
        return sample

# Wrap each split in a dataset that tokenizes lazily, item by item.
train_dataset = CustomDataset(texts=train_texts, labels=train_labels, tokenizer=tokenizer)
val_dataset = CustomDataset(texts=val_texts, labels=val_labels, tokenizer=tokenizer)

def _roc_auc_or_nan(y_true, y_score, average):
    """Return ``roc_auc_score`` for the given averaging, or NaN when undefined.

    sklearn raises ValueError when a scored label column (or, for
    ``average='micro'``, the pooled labels) contains only one class. One such
    label should not abort an entire evaluation pass.
    """
    try:
        return roc_auc_score(y_true, y_score, average=average)
    except ValueError:
        return float('nan')


def multi_labels_metrics(predictions, labels, threshold=0.3):
    """Compute multi-label classification metrics from raw logits.

    Args:
        predictions: Array-like of raw model outputs, one column per label.
            A sigmoid is applied here to get per-label probabilities.
        labels: 2-D array of 0/1 ground-truth indicators with the same shape
            as ``predictions``.
        threshold: Probability at or above which a label is predicted positive.
            Used by every metric except ROC AUC, which is threshold-free.

    Returns:
        dict with keys ``roc_auc_macro``, ``roc_auc_micro``, ``hamming_loss``,
        ``f1_macro``, ``f1_micro``, ``mcc_weighted`` and ``kappa_weighted``.
        The ROC AUC entries are NaN when sklearn cannot define them.
    """
    # Logits -> independent per-label probabilities.
    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()
    # Hard 0/1 decisions for the threshold-based metrics.
    y_pred = (probs >= threshold).astype(np.float64)
    y_true = np.asarray(labels)

    f1_macro = f1_score(y_true, y_pred, average='macro')
    f1_micro = f1_score(y_true, y_pred, average='micro')
    # ROC AUC ranks examples by score, so it needs the continuous
    # probabilities. Thresholded 0/1 predictions collapse it to a single
    # operating point.
    roc_auc_macro = _roc_auc_or_nan(y_true, probs, average='macro')
    roc_auc_micro = _roc_auc_or_nan(y_true, probs, average='micro')
    hamming = hamming_loss(y_true, y_pred)

    # Label-wise MCC, averaged with weights equal to each label's fraction of
    # positive rows. Labels with no positives get weight 0.
    mcc_weighted = 0.0
    total_weight = 0.0
    for y_true_label, y_pred_label in zip(y_true.T, y_pred.T):
        weight = float(y_true_label.sum()) / y_true_label.shape[0]
        total_weight += weight
        mcc_weighted += weight * matthews_corrcoef(y_true_label, y_pred_label)
    if total_weight > 0:
        mcc_weighted /= total_weight

    # Cohen's kappa over all (row, label) cells pooled together. With only the
    # two categories 0/1, quadratic weighting yields the same value as
    # unweighted kappa.
    kappa_weighted = cohen_kappa_score(y_true.flatten(), y_pred.flatten(), weights="quadratic")

    return {
        "roc_auc_macro": roc_auc_macro,
        "roc_auc_micro": roc_auc_micro,
        "hamming_loss": hamming,
        "f1_macro": f1_macro,
        "f1_micro": f1_micro,
        "mcc_weighted": mcc_weighted,
        "kappa_weighted": kappa_weighted,
    }

def compute_metrics(p: EvalPrediction):
    """Adapt a Trainer ``EvalPrediction`` to ``multi_labels_metrics``.

    If ``p.predictions`` is a tuple (several model outputs), its first element
    is treated as the logits. Otherwise it is used as-is.
    """
    logits = p.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    return multi_labels_metrics(predictions=logits, labels=p.label_ids)

# LoRA adapters for sequence classification. DistilBERT names its attention
# projections q_lin / k_lin / v_lin / out_lin; unlike BERT, it has no module
# called "query", and PEFT rejects target names that match no module. Target
# the query projection by its DistilBERT name.
peft_config = LoraConfig(task_type="SEQ_CLS",
                         r=4,
                         lora_alpha=32,
                         lora_dropout=0.01,
                         target_modules=['q_lin'])

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Training arguments: evaluate on the validation set and save a checkpoint
# under output_dir at the end of every epoch.
args = TrainingArguments(
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    output_dir='./drive/MyDrive/fit',
    num_train_epochs=6,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # switch if the installed version rejects `evaluation_strategy`.
    evaluation_strategy="epoch",
    logging_dir='./drive/MyDrive/logs',
    save_strategy="epoch"
)

# Trainer initialization
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
