In [1]:
import os
os.chdir("../")

# 🏋️ PII Model Training Notebook

## 📦 Imports

From Packages

In [2]:
from itertools import chain
from functools import partial
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
)
import pandas as pd
from types import SimpleNamespace
import torch
import wandb
import spacy

From utility scripts

In [3]:
from src.metric import (
    compute_metrics,
    get_f5_at_different_thresholds,
)
from src.data import create_dataset
from src.utils import (
    get_reference_df_parquet,
    parse_predictions,
    filter_errors,
    generate_htmls_concurrently,
    visualize,
    convert_for_upload,
    CustomTrainer,
    parse_args,
)

## 🆕 Initialization

In [4]:
FIRST_PART = "first"
LAST_PART = "last"
MIDDLE_PART = "middle"
PART = MIDDLE_PART

In [5]:
MODEL_SIZE = "large"

In [6]:
MAX_LENGTH = 1024
WANDB_PROJECT = "Kaggle-PII"
USER_NAME = "shakleenishfar"
PROJECT_PATH = f"laplacesdemon43/{WANDB_PROJECT}"
EXPERIMENT = f"pii015_{PART}"
WANDB_NOTEBOOK_NAME = "pii-model-training.ipynb"
WANDB_NAME = f"DeBERTA-v3-{MODEL_SIZE}-{MAX_LENGTH}-{PART}"
WANDB_NOTES = f"""Training using DeBERTA-v3-{MODEL_SIZE}-{MAX_LENGTH} {PART} one-third negative samples. 
Included data from NBroad."""

In [7]:
config = SimpleNamespace(
    experiment=EXPERIMENT,
    threshold=0.95,
    o_weight=0.1,
    stride_artifact=f"{PROJECT_PATH}/processed_data:v0",
    raw_artifact=f"{PROJECT_PATH}/raw_data:v0",
    external_data_1="none",
    external_data_2="none",
    external_data_3="none",
    external_data_4="none",
    external_data_5="none",
    output_dir=f"model_dir/DeBERTA-V3-{MODEL_SIZE}-{MAX_LENGTH}-{PART}",
    inference_max_length=MAX_LENGTH,
    training_max_length=MAX_LENGTH,
    training_model_path=f"microsoft/deberta-v3-{MODEL_SIZE}",
    fp16=True,
    learning_rate=4e-5,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,
    report_to="wandb",
    evaluation_strategy="epoch",
    do_eval=True,
    save_total_limit=1,
    logging_steps=10,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    random_state=29,
)

In [8]:
wandb.login(key="0bf204609ea345c7c595565d736a9d62ca69f838")
wandb.init(
    project=WANDB_PROJECT,
    name=WANDB_NAME,
    notes=WANDB_NOTES,
    save_code=True,
    job_type="train",
    config=config,
)
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshakleenishfar[0m ([33mlaplacesdemon43[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ishfar/.netrc


In [9]:
torch.manual_seed(config.random_state)

<torch._C.Generator at 0x7f2710300b10>

## 💾 Data Preparation

### Fetching Data

Getting data from Weights and Biases

In [10]:
stride_artifact = wandb.use_artifact(config.stride_artifact)
stride_artifact_dir = stride_artifact.download()
df = pd.read_parquet(stride_artifact_dir + "/stride_data.parquet")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


### Splitting Data

Into train and evaluation splits.

In [11]:
train_df = df[df.valid == False].reset_index(drop=True)
eval_df = df[df.valid == True].reset_index(drop=True)

print("Size of training dataset:", train_df.shape[0])
print("Size of validation dataset:", eval_df.shape[0])

Size of training dataset: 9516
Size of validation dataset: 1350


### Negative Sampling

Used to handle the extreme class imbalance in the data. Suggested by Valentin Warner.

* positive samples (contain relevant labels)

* negative samples (presumably contain entities that are possibly wrongly classified as entity)

In [12]:
negatives, positives = [], []

for _, row in train_df.iterrows():
    if any(row.labels != "O"):
        positives.append(row)
    else:
        negatives.append(row)
        
positives, negatives = pd.DataFrame(positives), pd.DataFrame(negatives)
print("Negative samples:", len(negatives))
print("Positive samples:", len(positives))

Negative samples: 5868
Positive samples: 3648


Take one third of the negative samples for downsampling.

In [13]:
if PART == FIRST_PART:
    negatives = negatives.iloc[: negatives.shape[0] // 3]
elif PART == MIDDLE_PART:
    negatives = negatives.iloc[negatives.shape[0] // 3 : 2 * negatives.shape[0] // 3]
elif PART == LAST_PART:
    negatives = negatives.iloc[2 * negatives.shape[0] // 3 :]
else:
    raise Exception(f"Undefined part: {PART}")

train_df = pd.concat([positives, negatives])
train_df = train_df.sample(frac=1, random_state=config.random_state)
print(f"Down sampled training: {len(train_df)}")
del positives, negatives

Down sampled training: 5604


### 🪙 Data Tokenization

In [14]:
reference_df = get_reference_df_parquet(config.raw_artifact)
all_labels = sorted(list(set(chain(*[x.tolist() for x in df.labels.values]))))
label2id = {l: i for i, l in enumerate(all_labels)}
id2label = {v: k for k, v in label2id.items()}
id2label

[34m[1mwandb[0m:   1 of 1 files downloaded.  


{0: 'B-EMAIL',
 1: 'B-ID_NUM',
 2: 'B-NAME_STUDENT',
 3: 'B-PHONE_NUM',
 4: 'B-STREET_ADDRESS',
 5: 'B-URL_PERSONAL',
 6: 'B-USERNAME',
 7: 'I-ID_NUM',
 8: 'I-NAME_STUDENT',
 9: 'I-PHONE_NUM',
 10: 'I-STREET_ADDRESS',
 11: 'I-URL_PERSONAL',
 12: 'O'}

In [15]:
tokenizer = AutoTokenizer.from_pretrained(config.training_model_path)
train_ds = create_dataset(train_df, tokenizer, config.training_max_length, label2id)
valid_ds = create_dataset(eval_df, tokenizer, config.inference_max_length, label2id)



Map (num_proc=6):   0%|          | 0/5604 [00:00<?, ? examples/s]

Map (num_proc=6):   0%|          | 0/1350 [00:00<?, ? examples/s]

## 🏋️ Training

In [16]:
model = AutoModelForTokenClassification.from_pretrained(
    config.training_model_path,
    num_labels=len(all_labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)
collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=16)

Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### PEFT (Parameter Efficient Finetuning)

In [17]:
# import peft
# from peft import (
#     get_peft_config,
#     PeftModel,
#     PeftConfig,
#     get_peft_model,
#     LoraConfig,
#     TaskType,
# )

In [18]:
# peft_config = LoraConfig(
#     r=128,  # Use larger 'r' value increase more parameters during training
#     bias='none',
#     inference_mode=False,
#     task_type=TaskType.SEQ_CLS,
#     # Only Use Output and Values Projection
#     target_modules=['query_proj', 'value_proj'],
# )

# # Load the PEFT model
# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

In [19]:
args = TrainingArguments(
    output_dir=config.output_dir,
    fp16=config.fp16,
    learning_rate=config.learning_rate,
    num_train_epochs=config.num_train_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    report_to=config.report_to,
    evaluation_strategy=config.evaluation_strategy,
    do_eval=config.do_eval,
    save_total_limit=config.save_total_limit,
    logging_steps=config.logging_steps,
    lr_scheduler_type=config.lr_scheduler_type,
    warmup_ratio=config.warmup_ratio,
    weight_decay=config.weight_decay,
)

Set "O" tokens to have a very small weight.

In [20]:
class_weights = torch.tensor([1.0] * 12 + [config.o_weight]).to("cuda")

In [21]:
trainer = CustomTrainer(
    model=model, 
    args=args, 
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=collator, 
    tokenizer=tokenizer,
    compute_metrics=partial(
        compute_metrics,
        id2label=id2label,
        valid_ds=valid_ds,
        valid_df=reference_df,
        threshold=config.threshold,
    ),
    class_weights=class_weights
)

In [22]:
trainer.train()

  0%|          | 0/7005 [00:00<?, ?it/s]

{'loss': 2.8497, 'grad_norm': 26.395414352416992, 'learning_rate': 3.423680456490728e-07, 'epoch': 0.01}
{'loss': 2.8166, 'grad_norm': 28.661951065063477, 'learning_rate': 9.129814550641941e-07, 'epoch': 0.01}
{'loss': 2.6071, 'grad_norm': 21.40492820739746, 'learning_rate': 1.4835948644793155e-06, 'epoch': 0.02}
{'loss': 2.0437, 'grad_norm': 19.140548706054688, 'learning_rate': 2.0542082738944367e-06, 'epoch': 0.03}
{'loss': 1.6041, 'grad_norm': 15.281765937805176, 'learning_rate': 2.624821683309558e-06, 'epoch': 0.04}
{'loss': 0.9034, 'grad_norm': 258.8652648925781, 'learning_rate': 3.1954350927246792e-06, 'epoch': 0.04}
{'loss': 0.7795, 'grad_norm': 6.876041889190674, 'learning_rate': 3.7660485021398007e-06, 'epoch': 0.05}
{'loss': 0.6483, 'grad_norm': 1.7667365074157715, 'learning_rate': 4.336661911554922e-06, 'epoch': 0.06}
{'loss': 0.596, 'grad_norm': 5.993288516998291, 'learning_rate': 4.907275320970043e-06, 'epoch': 0.06}
{'loss': 0.4284, 'grad_norm': 8.128890037536621, 'learni

  0%|          | 0/169 [00:00<?, ?it/s]

{'eval_loss': 0.005666923243552446, 'eval_ents_p': 0.5343680709534369, 'eval_ents_r': 0.964, 'eval_ents_f5': 0.935084315773765, 'eval_ents_per_type_EMAIL_p': 0.9583333333333334, 'eval_ents_per_type_EMAIL_r': 1.0, 'eval_ents_per_type_EMAIL_f5': 0.998330550918197, 'eval_ents_per_type_ID_NUM_p': 0.35443037974683544, 'eval_ents_per_type_ID_NUM_r': 0.9333333333333333, 'eval_ents_per_type_ID_NUM_f5': 0.8781664656212304, 'eval_ents_per_type_NAME_STUDENT_p': 0.5407523510971787, 'eval_ents_per_type_NAME_STUDENT_r': 0.9663865546218487, 'eval_ents_per_type_NAME_STUDENT_f5': 0.937990170448604, 'eval_ents_per_type_PHONE_NUM_p': 0.7, 'eval_ents_per_type_PHONE_NUM_r': 1.0, 'eval_ents_per_type_PHONE_NUM_f5': 0.9837837837837837, 'eval_ents_per_type_STREET_ADDRESS_p': 0.7407407407407407, 'eval_ents_per_type_STREET_ADDRESS_r': 0.9090909090909091, 'eval_ents_per_type_STREET_ADDRESS_f5': 0.901213171577123, 'eval_ents_per_type_URL_PERSONAL_p': 0.47674418604651164, 'eval_ents_per_type_URL_PERSONAL_r': 1.0, '

  0%|          | 0/169 [00:00<?, ?it/s]

{'eval_loss': 0.007795552723109722, 'eval_ents_p': 0.36623963828183875, 'eval_ents_r': 0.972, 'eval_ents_f5': 0.9138641787806465, 'eval_ents_per_type_EMAIL_p': 0.92, 'eval_ents_per_type_EMAIL_r': 1.0, 'eval_ents_per_type_EMAIL_f5': 0.9966666666666667, 'eval_ents_per_type_ID_NUM_p': 0.5102040816326531, 'eval_ents_per_type_ID_NUM_r': 0.8333333333333334, 'eval_ents_per_type_ID_NUM_f5': 0.8135168961201501, 'eval_ents_per_type_NAME_STUDENT_p': 0.34637964774951074, 'eval_ents_per_type_NAME_STUDENT_r': 0.9915966386554622, 'eval_ents_per_type_NAME_STUDENT_f5': 0.9253041117925004, 'eval_ents_per_type_PHONE_NUM_p': 0.7241379310344828, 'eval_ents_per_type_PHONE_NUM_r': 1.0, 'eval_ents_per_type_PHONE_NUM_f5': 0.9855595667870036, 'eval_ents_per_type_STREET_ADDRESS_p': 0.7692307692307693, 'eval_ents_per_type_STREET_ADDRESS_r': 0.9090909090909091, 'eval_ents_per_type_STREET_ADDRESS_f5': 0.9027777777777776, 'eval_ents_per_type_URL_PERSONAL_p': 0.24404761904761904, 'eval_ents_per_type_URL_PERSONAL_r': 

  0%|          | 0/169 [00:00<?, ?it/s]

{'eval_loss': 0.004864082671701908, 'eval_ents_p': 0.6376811594202898, 'eval_ents_r': 0.968, 'eval_ents_f5': 0.9490911833471604, 'eval_ents_per_type_EMAIL_p': 0.92, 'eval_ents_per_type_EMAIL_r': 1.0, 'eval_ents_per_type_EMAIL_f5': 0.9966666666666667, 'eval_ents_per_type_ID_NUM_p': 0.6511627906976745, 'eval_ents_per_type_ID_NUM_r': 0.9333333333333333, 'eval_ents_per_type_ID_NUM_f5': 0.9180327868852459, 'eval_ents_per_type_NAME_STUDENT_p': 0.6181172291296625, 'eval_ents_per_type_NAME_STUDENT_r': 0.9747899159663865, 'eval_ents_per_type_NAME_STUDENT_f5': 0.9536256323777403, 'eval_ents_per_type_PHONE_NUM_p': 0.7777777777777778, 'eval_ents_per_type_PHONE_NUM_r': 1.0, 'eval_ents_per_type_PHONE_NUM_f5': 0.9891304347826085, 'eval_ents_per_type_STREET_ADDRESS_p': 0.8695652173913043, 'eval_ents_per_type_STREET_ADDRESS_r': 0.9090909090909091, 'eval_ents_per_type_STREET_ADDRESS_f5': 0.9075043630017451, 'eval_ents_per_type_URL_PERSONAL_p': 0.5540540540540541, 'eval_ents_per_type_URL_PERSONAL_r': 1.0

  0%|          | 0/169 [00:00<?, ?it/s]

{'eval_loss': 0.00813595112413168, 'eval_ents_p': 0.6291989664082688, 'eval_ents_r': 0.974, 'eval_ents_f5': 0.9538948320024107, 'eval_ents_per_type_EMAIL_p': 0.9583333333333334, 'eval_ents_per_type_EMAIL_r': 1.0, 'eval_ents_per_type_EMAIL_f5': 0.998330550918197, 'eval_ents_per_type_ID_NUM_p': 0.8181818181818182, 'eval_ents_per_type_ID_NUM_r': 0.9, 'eval_ents_per_type_ID_NUM_f5': 0.8965517241379312, 'eval_ents_per_type_NAME_STUDENT_p': 0.5899159663865546, 'eval_ents_per_type_NAME_STUDENT_r': 0.9831932773109243, 'eval_ents_per_type_NAME_STUDENT_f5': 0.9586134453781512, 'eval_ents_per_type_PHONE_NUM_p': 0.6363636363636364, 'eval_ents_per_type_PHONE_NUM_r': 1.0, 'eval_ents_per_type_PHONE_NUM_f5': 0.9784946236559142, 'eval_ents_per_type_STREET_ADDRESS_p': 0.8695652173913043, 'eval_ents_per_type_STREET_ADDRESS_r': 0.9090909090909091, 'eval_ents_per_type_STREET_ADDRESS_f5': 0.9075043630017451, 'eval_ents_per_type_URL_PERSONAL_p': 0.7192982456140351, 'eval_ents_per_type_URL_PERSONAL_r': 1.0, '

  0%|          | 0/169 [00:00<?, ?it/s]

{'eval_loss': 0.006850867066532373, 'eval_ents_p': 0.6949640287769784, 'eval_ents_r': 0.966, 'eval_ents_f5': 0.9517241379310344, 'eval_ents_per_type_EMAIL_p': 0.9583333333333334, 'eval_ents_per_type_EMAIL_r': 1.0, 'eval_ents_per_type_EMAIL_f5': 0.998330550918197, 'eval_ents_per_type_ID_NUM_p': 0.7714285714285715, 'eval_ents_per_type_ID_NUM_r': 0.9, 'eval_ents_per_type_ID_NUM_f5': 0.8942675159235669, 'eval_ents_per_type_NAME_STUDENT_p': 0.6647509578544061, 'eval_ents_per_type_NAME_STUDENT_r': 0.9719887955182073, 'eval_ents_per_type_NAME_STUDENT_f5': 0.95501217317667, 'eval_ents_per_type_PHONE_NUM_p': 0.6363636363636364, 'eval_ents_per_type_PHONE_NUM_r': 1.0, 'eval_ents_per_type_PHONE_NUM_f5': 0.9784946236559142, 'eval_ents_per_type_STREET_ADDRESS_p': 0.8695652173913043, 'eval_ents_per_type_STREET_ADDRESS_r': 0.9090909090909091, 'eval_ents_per_type_STREET_ADDRESS_f5': 0.9075043630017451, 'eval_ents_per_type_URL_PERSONAL_p': 0.8367346938775511, 'eval_ents_per_type_URL_PERSONAL_r': 1.0, 'e

TrainOutput(global_step=7005, training_loss=0.029327643484873153, metrics={'train_runtime': 5341.0209, 'train_samples_per_second': 5.246, 'train_steps_per_second': 1.312, 'train_loss': 0.029327643484873153, 'epoch': 5.0})

### Saving Model and Metrics locally

In [23]:
trainer.save_model(config.output_dir)
tokenizer.save_pretrained(config.output_dir)

('model_dir/DeBERTA-V3-large-1024-middle/tokenizer_config.json',
 'model_dir/DeBERTA-V3-large-1024-middle/special_tokens_map.json',
 'model_dir/DeBERTA-V3-large-1024-middle/spm.model',
 'model_dir/DeBERTA-V3-large-1024-middle/added_tokens.json',
 'model_dir/DeBERTA-V3-large-1024-middle/tokenizer.json')

## Determine Best Threshold

In [24]:
del tokenizer, model, collator, args, trainer

In [25]:
tokenizer = AutoTokenizer.from_pretrained(config.output_dir)
model = AutoModelForTokenClassification.from_pretrained(config.output_dir)
collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=16)
args = TrainingArguments(
    ".",
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    report_to="none",
)
trainer = CustomTrainer(
    model=model,
    args=args,
    data_collator=collator,
    tokenizer=tokenizer,
)
preds = trainer.predict(valid_ds)

  0%|          | 0/169 [00:00<?, ?it/s]

In [26]:
print("Computing final metrics...")

final_metrics = get_f5_at_different_thresholds(preds, id2label, valid_ds, reference_df)

wandb.log(final_metrics)
print(final_metrics)

Computing final metrics...
{'final_f5_at_0.55': 0.9408090195779691, 'final_f5_at_0.5': 0.9390429747028345, 'final_f5_at_0.6': 0.9406657018813314, 'final_f5_at_0.7': 0.9477208736017048, 'final_f5_at_0.65': 0.9460302961102229, 'final_f5_at_0.75': 0.9472883547577395, 'final_f5_at_0.85': 0.9501595502203312, 'final_f5_at_0.8': 0.9469282238442823, 'final_f5_at_0.9': 0.9491499696417729, 'final_f5_at_0.97': 0.9509313948205359, 'final_f5_at_0.99': 0.9531568228105903, 'final_f5_at_0.95': 0.9517241379310344}


In [27]:
# pick the best threshold from the final metrics and use it to generate preds_df
best_threshold = float(max(final_metrics, key=final_metrics.get).split("_")[-1])
print("best_threshold:", best_threshold)
wandb.config.best_threshold = best_threshold
preds_df = parse_predictions(
    preds.predictions, id2label, valid_ds, threshold=best_threshold
)

best_threshold: 0.99


## 📊 Data Visualization

In [28]:
# Prepare data to visualize errors and log them as a Weights & Biases table
print("Visualizing errors...")
grouped_preds = preds_df.groupby("eval_row")[
    ["document", "token", "label", "token_str"]
].agg(list)
viz_df = pd.merge(
    eval_df.reset_index(),
    grouped_preds,
    how="left",
    left_on="index",
    right_on="eval_row",
)
viz_df = filter_errors(viz_df, preds_df)
viz_df["pred_viz"] = generate_htmls_concurrently(
    viz_df,
    tokenizer,
    preds.predictions,
    id2label,
    valid_ds,
    threshold=best_threshold,
)
nlp = spacy.blank("en")
htmls = [visualize(row, nlp) for _, row in viz_df.iterrows()]
wandb_htmls = [wandb.Html(html) for html in htmls]
viz_df["gt_viz"] = wandb_htmls
viz_df.fillna("", inplace=True)
viz_df = convert_for_upload(viz_df)
errors_table = wandb.Table(dataframe=viz_df)
wandb.log({"errors_table": errors_table})

print("Experiment finished, test it out on the inference notebook!")

Visualizing errors...


  0%|          | 0/112 [00:00<?, ?it/s]



Experiment finished, test it out on the inference notebook!


In [29]:
wandb.finish()

VBox(children=(Label(value='3.932 MB of 11.518 MB uploaded\r'), FloatProgress(value=0.34142129323343134, max=1…



0,1
eval/ents_f5,▅▁▇██
eval/ents_p,▅▁▇▇█
eval/ents_per_type_EMAIL_f5,█▁▁██
eval/ents_per_type_EMAIL_p,█▁▁██
eval/ents_per_type_EMAIL_r,▁▁▁▁▁
eval/ents_per_type_ID_NUM_f5,▅▁█▇▆
eval/ents_per_type_ID_NUM_p,▁▃▅█▇
eval/ents_per_type_ID_NUM_r,█▁█▆▆
eval/ents_per_type_NAME_STUDENT_f5,▄▁▇█▇
eval/ents_per_type_NAME_STUDENT_p,▅▁▇▆█

0,1
eval/ents_f5,0.95172
eval/ents_p,0.69496
eval/ents_per_type_EMAIL_f5,0.99833
eval/ents_per_type_EMAIL_p,0.95833
eval/ents_per_type_EMAIL_r,1.0
eval/ents_per_type_ID_NUM_f5,0.89427
eval/ents_per_type_ID_NUM_p,0.77143
eval/ents_per_type_ID_NUM_r,0.9
eval/ents_per_type_NAME_STUDENT_f5,0.95501
eval/ents_per_type_NAME_STUDENT_p,0.66475
