In [None]:
!pip3 install seqeval evaluate rjieba wandb pytorch-lightning transformers==4.37.2 datasets==2.17.0 sentensepiece -U

In [None]:
import argparse
import json
from itertools import chain
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from datasets import Dataset as DS

from transformers import (
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    AutoTokenizer
)

In [None]:
class CFG:
    """Static settings read by the other notebook cells."""

    # Reproducibility
    seed = 42

    # Model and tokenisation
    pretrained_model_name = "microsoft/deberta-v3-large"
    training_max_length = 512  # passed as max_length when tokenizing
    batch_size = 4

    # Filesystem layout (plain string concatenation on base_path;
    # ds_path therefore evaluates to ".//train.json")
    base_path = "./"
    output_dir = f"{base_path}output"
    ds_path = f"{base_path}/train.json"

In [None]:
# Seed the random number generators through Lightning, using the seed from CFG.
pl.seed_everything(CFG.seed)

In [None]:

ds_path = "."


def _load_json(path):
    """Read a JSON file and return its parsed content, closing the handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


data = _load_json(f"{ds_path}/train.json")

# Downsample negative examples: keep every document that carries at least one
# non-"O" label, but only the first third of the documents that carry none.
p = []  # positive samples (at least one label other than "O")
n = []  # negative samples (every label is "O")
for d in data:
    # Plain Python comparison also handles a document with an empty label list.
    if any(label != "O" for label in d["labels"]):
        p.append(d)
    else:
        n.append(d)
print("original datapoints: ", len(data))

external = _load_json(f"{ds_path}/pii_dataset_fixed.json")
print("external datapoints: ", len(external))

moredata = _load_json(f"{ds_path}/moredata_dataset_fixed.json")
print("moredata datapoints: ", len(moredata))

# Final training pool: both extra datasets, all positives, one third of the negatives.
data = external + moredata + p + n[:len(n) // 3]
print("combined: ", len(data))

In [None]:
# Label vocabulary: every distinct tag found in the combined data, sorted so
# that the id assigned to each label is its position in the sorted list.
all_labels = sorted(set(chain.from_iterable(x["labels"] for x in data)))
id2label = dict(enumerate(all_labels))
label2id = {label: idx for idx, label in id2label.items()}

# Labels counted as targets by `tokenize` when it computes `target_num`.
target = [
    'B-EMAIL',
    'B-ID_NUM',
    'B-NAME_STUDENT',
    'B-PHONE_NUM',
    'B-STREET_ADDRESS',
    'B-URL_PERSONAL',
    'B-USERNAME',
    'I-ID_NUM',
    'I-NAME_STUDENT',
    'I-PHONE_NUM',
    'I-STREET_ADDRESS',
    'I-URL_PERSONAL',
]

print(id2label)

In [None]:
def tokenize(example, tokenizer, label2id):
    """Tokenize one document and project its word-level labels onto model tokens.

    The document text is rebuilt from ``example["tokens"]`` (inserting a single
    space wherever ``example["trailing_whitespace"]`` is truthy) while a parallel
    per-character label list is built from ``example["provided_labels"]``.
    Each model token is then given the label of the first character of its
    offset span, skipping one leading whitespace character if present.

    Args:
        example: mapping with ``tokens``, ``provided_labels`` and
            ``trailing_whitespace`` sequences of equal length.
        tokenizer: callable tokenizer that accepts ``return_offsets_mapping``,
            ``truncation`` and ``max_length`` and exposes ``offset_mapping``
            and ``input_ids`` on its result.
        label2id: mapping from label string to integer id; must contain "O".

    Returns:
        dict with every field of the tokenizer output plus:
        ``labels`` (one label id per model token), ``length`` (number of
        input ids), ``target_num`` (count of word-level tokens whose label is
        in the module-level ``target`` list) and ``group`` (1 if
        ``target_num`` > 0, otherwise 0).
    """
    pieces = []       # text fragments: each word-level token, plus " " where flagged
    char_labels = []  # one label per character of the rebuilt text
    target_num = 0

    for tok, label, has_space in zip(
        example["tokens"], example["provided_labels"], example["trailing_whitespace"]
    ):
        pieces.append(tok)
        char_labels.extend([label] * len(tok))

        if label in target:
            target_num += 1

        # Inserted whitespace never belongs to an entity.
        if has_space:
            pieces.append(" ")
            char_labels.append("O")

    # Join once and reuse the string for both tokenization and offset lookups.
    text = "".join(pieces)
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=CFG.training_max_length,
    )

    outside_id = label2id["O"]
    last_char = len(char_labels) - 1
    token_labels = []

    for start_idx, end_idx in tokenized.offset_mapping:
        # A (0, 0) offset marks a token with no span in the text (the CLS
        # token, per the original comment); it is labelled "O".
        if start_idx == 0 and end_idx == 0:
            token_labels.append(outside_id)
            continue

        # When the token span starts on whitespace, use the next character's
        # label instead - but never step past the final character, which
        # would raise IndexError for a whitespace-led token at the end of text.
        if text[start_idx].isspace() and start_idx < last_char:
            start_idx += 1

        token_labels.append(label2id[char_labels[start_idx]])

    return {
        **tokenized,
        "labels": token_labels,
        "length": len(tokenized.input_ids),
        "target_num": target_num,
        "group": 1 if target_num > 0 else 0,
    }


In [None]:
# Load the tokenizer for the pretrained checkpoint named in CFG.
tokenizer = AutoTokenizer.from_pretrained(CFG.pretrained_model_name)
# Optional: persist the tokenizer files next to the model outputs.
# tokenizer.save_pretrained("./outputs")

In [None]:
# Pivot the list of document records into column lists, then hand them to
# `datasets`. The word-level tags are stored under "provided_labels" so that the
# "labels" column produced later by `tokenize` does not collide with them.
raw_columns = {
    "full_text": [],
    "document": [],
    "tokens": [],
    "trailing_whitespace": [],
    "provided_labels": [],
}
for record in data:
    raw_columns["full_text"].append(record["full_text"])
    raw_columns["document"].append(str(record["document"]))  # ids stored as strings
    raw_columns["tokens"].append(record["tokens"])
    raw_columns["trailing_whitespace"].append(record["trailing_whitespace"])
    raw_columns["provided_labels"].append(record["labels"])

ds = DS.from_dict(raw_columns)

In [None]:
%%time

ds  = ds.map(tokenize, fn_kwargs={
    "tokenizer": tokenizer,
    "label2id": label2id
}, num_proc = 6)

ds.class_encode_column("group")

In [None]:
# Hold out 20% of the documents as the "test" split, which
# TokenFineTuner.val_dataloader uses for validation; the seed comes from CFG.
# NOTE(review): `ds` is rebound to the split result here, so this cell presumably
# cannot be re-run on its own without re-running the cells above - confirm.
ds = ds.train_test_split(test_size=0.2, seed=CFG.seed)

In [None]:
# Batch collator for token classification. pad_to_multiple_of=512 matches
# CFG.training_max_length, so (presumably) every batch is padded to the same
# sequence length - on_validation_epoch_end concatenates logits across batches
# along dim 0, which needs equal sequence lengths. TODO confirm.
collator_fn = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=512)

In [None]:
# Columns dropped before batching; they are not among the batch keys the
# model step reads (input_ids, attention_mask, labels).
cols_to_remove = [
    'full_text',
    'document',
    'tokens',
    'trailing_whitespace',
    'provided_labels',
    'offset_mapping',
    'length',
    'target_num',
    'group',
]


def get_dataset(dataset, data_type="train"):
    """Return one split of `dataset`, stripped of helper columns, in torch format.

    Args:
        dataset: mapping of split name to dataset (e.g. the train/test split result).
        data_type: key of the split to return; defaults to "train".

    Returns:
        The selected split after `remove_columns(cols_to_remove)` and
        `with_format("torch")`.
    """
    return dataset[data_type].remove_columns(cols_to_remove).with_format("torch")

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def _configure_optimizer(lr, epochs, weight_decay, params):
    """Build the Adam optimizer and cosine warm-restart schedule for Lightning.

    Args:
        lr: learning rate given to Adam.
        epochs: used as ``T_0``, the length of the first restart cycle.
        weight_decay: weight-decay coefficient given to Adam.
        params: iterable of parameters; only those with ``requires_grad`` set
            are optimized.

    Returns:
        dict with the ``optimizer`` and an ``lr_scheduler`` config (scheduler
        instance, ``interval`` "epoch", ``monitor`` "val_loss", ``frequency`` 1).
    """
    trainable = [param for param in params if param.requires_grad]
    model_optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=weight_decay)

    # T_mult=1 keeps every restart cycle the same length; the learning rate
    # anneals down to eta_min within each cycle.
    lr_scheduler = CosineAnnealingWarmRestarts(
        model_optimizer,
        T_0=epochs,
        T_mult=1,
        eta_min=1e-6,
        last_epoch=-1,
    )

    scheduler_config = {
        "scheduler": lr_scheduler,
        "interval": "epoch",
        "monitor": "val_loss",
        "frequency": 1,
    }
    return {"optimizer": model_optimizer, "lr_scheduler": scheduler_config}

In [None]:
k = 2  # number of best checkpoints to keep

# Keep the `k` checkpoints with the lowest "val_loss", plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{train_loss:.2f}',  # file name is built from epoch and train_loss
    monitor='val_loss',                   # metric used to rank checkpoints
    mode='min',                           # lower val_loss is better
    save_top_k=k,
    save_last=True,
)

In [None]:
from seqeval.metrics import recall_score, precision_score

def compute_metrics(p, all_labels):
    """Compute entity-level recall, precision and F-beta (beta=5) with seqeval.

    Args:
        p: pair ``(predictions, labels)`` of torch tensors. ``predictions``
            holds per-label scores with the label axis at index 2;
            ``labels`` holds label ids, with -100 marking positions to ignore.
        all_labels: sequence mapping a label id to its label string.

    Returns:
        dict with ``recall``, ``precision`` and ``f1``. The ``f1`` key holds
        the F-beta score with beta=5 (recall-weighted); it is 0.0 when
        precision and recall are both zero.
    """
    predictions, labels = p
    predictions = torch.argmax(predictions, axis=2).cpu().numpy()
    labels = labels.cpu().numpy()

    # Drop ignored positions (label id -100) and map ids back to label strings.
    true_predictions = [
        [all_labels[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [all_labels[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    recall = recall_score(true_labels, true_predictions)
    precision = precision_score(true_labels, true_predictions)

    # F-beta = (1 + beta^2) * P * R / (beta^2 * P + R). When the model finds no
    # correct entity, P and R are both 0 and the denominator vanishes; report
    # 0.0 instead of dividing by zero.
    beta_sq = 5 * 5
    denominator = beta_sq * precision + recall
    if denominator > 0:
        f1_score = (1 + beta_sq) * recall * precision / denominator
    else:
        f1_score = 0.0

    results = {
        'recall': recall,
        'precision': precision,
        'f1': f1_score
    }
    return results

In [None]:
# Pretrained encoder with a token-classification head sized to the label
# vocabulary; both label mappings are passed along so they are stored with it.
backbone = AutoModelForTokenClassification.from_pretrained(
    CFG.pretrained_model_name,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(all_labels),
)

In [None]:
# backbone.save_pretrained("./outputs")

In [None]:
class TokenFineTuner(pl.LightningModule):
    """LightningModule that fine-tunes a token-classification backbone.

    The module owns its dataloaders (built from the dataset passed in) and, at
    the end of every validation epoch, logs recall, precision and F5 computed
    over all validation batches.

    Args:
        model: token-classification model whose forward accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with
            ``loss`` and ``logits``.
        hparam: namespace with ``learning_rate``, ``num_train_epochs``,
            ``weight_decay``, ``train_batch_size`` and ``eval_batch_size``.
        dataset: mapping with "train" and "test" splits (see ``get_dataset``).
    """

    def __init__(self, model,  hparam, dataset):
        super().__init__()
        self.hparam = hparam
        self.num_labels = len(id2label)
        self.model = model
        self.ds = dataset
        # NOTE(review): this also records `model` and `dataset` as hyperparameters;
        # consider save_hyperparameters(ignore=[...]) if checkpoint size matters and
        # nothing relies on restoring them from the checkpoint - confirm first.
        self.save_hyperparameters()
        # Per-batch validation results, collected for epoch-level metrics.
        self.validation_step_outputs = []

    def forward(
        self, input_ids, attention_mask=None, lm_labels=None
    ):
        """Run the backbone; ``lm_labels`` is forwarded as its ``labels`` argument."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=lm_labels,
        )

    def _step(self, batch):
        """Shared train/validation step: returns the batch loss and logits."""
        labels = batch["labels"]
        output = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            lm_labels=labels
            )

        return {"loss": output.loss, "logits": output.logits}

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._step(batch=batch)["loss"]

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and stash the batch results for epoch-end metrics."""
        output = self._step(batch)
        loss = output["loss"]
        # Store detached CPU copies so the whole validation set's logits do not
        # accumulate in accelerator memory; cast to float32 so later CPU ops
        # do not depend on half-precision kernels.
        self.validation_step_outputs.append({
            "loss": loss.detach().float().cpu(),
            "logits": output["logits"].detach().float().cpu(),
            "targets": batch["labels"].detach().cpu(),
        })
        self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        return {"val_loss": loss}

    def configure_optimizers(self):
        """Delegate optimizer and scheduler construction to ``_configure_optimizer``."""
        return _configure_optimizer(lr=self.hparam.learning_rate,
                                    epochs=self.hparam.num_train_epochs,
                                    weight_decay=self.hparam.weight_decay,
                                    params=self.parameters())

    def train_dataloader(self):
        """Shuffled loader over the "train" split."""
        data = get_dataset(self.ds)
        dataloader = DataLoader(data, batch_size=self.hparam.train_batch_size, collate_fn=collator_fn,
                                shuffle=True, num_workers=4)
        return dataloader

    def val_dataloader(self):
        """Loader over the "test" split, used for validation (not shuffled)."""
        val_dataset = get_dataset(self.ds, data_type="test")
        # Validation data is not shuffled: the epoch-level metrics do not depend
        # on batch order, and a fixed order keeps per-step logs comparable.
        return DataLoader(val_dataset, batch_size=self.hparam.eval_batch_size, collate_fn=collator_fn,
                          num_workers=4, shuffle=False)

    def on_validation_epoch_end(self):
        """Aggregate the stored validation batches and log F5, precision and recall."""
        outputs = self.validation_step_outputs
        if not outputs:
            return None

        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        logits = torch.cat([x['logits'] for x in outputs], dim=0)
        # Normalise over the label axis (the last one). compute_metrics takes the
        # argmax over axis 2, so a softmax over the sequence axis (dim=1) could
        # change which label wins at a position.
        output_val = torch.softmax(logits, dim=-1)
        target_val = torch.cat([x['targets'] for x in outputs], dim=0)
        avg_score = compute_metrics((output_val, target_val), all_labels=list(label2id.keys()))
        self.log("val_f5", avg_score["f1"],on_epoch=True, prog_bar=True)
        self.log("val_precision", avg_score["precision"],on_epoch=True, prog_bar=True)
        self.log("val_recall", avg_score["recall"],on_epoch=True, prog_bar=True)

        # Reset the buffer so the next validation run starts empty; otherwise
        # results from earlier epochs leak into later metrics and memory grows.
        outputs.clear()
        return {'val_loss': avg_loss,'val_cmap':avg_score}

In [None]:
# Run settings; wrapped in an argparse.Namespace further down.
args_dict = {
    # Paths and model identifiers
    "output_dir": "./",
    "model_name_or_path": 'microsoft/deberta-v3-large',
    "tokenizer_name_or_path": 'microsoft/deberta-v3-large',
    "max_seq_length": 256,
    # Optimisation (learning_rate, weight_decay and num_train_epochs are read
    # by TokenFineTuner.configure_optimizers)
    "learning_rate": 3e-4,
    "weight_decay": 1e-2,
    "adam_epsilon": 1e-8,
    "warmup_steps": 0,
    # Batching (read by the module's dataloaders and by train_params)
    "train_batch_size": 8,
    "eval_batch_size": 4,
    "num_train_epochs": 100,
    "gradient_accumulation_steps": 16,
    # Hardware / precision (fp_16 selects precision 16 vs 32 in train_params;
    # max_grad_norm becomes the trainer's gradient_clip_val)
    "n_gpu": 1,
    "early_stop_callback": False,
    "fp_16": True,
    "opt_level": 'O1',
    "max_grad_norm": 1,
    "seed": 42,
}

In [None]:
# Expose the settings as attributes (e.g. args.learning_rate), then build the
# LightningModule around the backbone and the train/test dataset split.
args = argparse.Namespace(**args_dict)
model = TokenFineTuner(backbone, args, ds)

In [None]:
# Weights & Biases logger for the Trainer, configured with project "PIDD".
wandb_logger = WandbLogger(project="PIDD")

In [None]:
# Keyword arguments for pl.Trainer, derived from `args`.
train_params = {
    "accumulate_grad_batches": args.gradient_accumulation_steps,
    "devices": 1,
    "max_epochs": args.num_train_epochs,
    "precision": 16 if args.fp_16 else 32,  # 16-bit when fp_16 is set
    "gradient_clip_val": args.max_grad_norm,
    "logger": wandb_logger,
    "callbacks": [checkpoint_callback],
}

In [None]:
# Build the Lightning Trainer from the keyword arguments assembled in train_params.
trainer = pl.Trainer(**train_params)

In [None]:
# NOTE(review): 'medium' presumably lets float32 matmuls use lower-precision
# internals in exchange for speed - confirm against the PyTorch docs for the target GPU.
torch.set_float32_matmul_precision('medium')
# No dataloaders are passed here: the module defines train_dataloader / val_dataloader.
trainer.fit(model)

In [None]:
# trainer.save_checkpoint("./outputs/model.ckpt")