In [302]:
import pandas as pd
import numpy as np
import dspy
from dotenv import load_dotenv
import ast
import torch

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import TrainingArguments
from transformers import Trainer
from typing import Any


In [303]:
def get_label_list(dataset: Dataset):
    label_set = set()
    for data in dataset:
        labels = data["ner_tags"]
        label_set.update(labels)
    return list(label_set)

In [304]:
dataset_name = "darrow-ai/LegalLensNER-SharedTask"

In [305]:
from datasets import load_dataset
dataset = load_dataset(dataset_name)

In [306]:
def safe_literal_eval(value):
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Return None
        return None

In [307]:
dataset = dataset.map(
    lambda x: {
        "tokens": safe_literal_eval(x["tokens"]),
        "ner_tags": safe_literal_eval(x["ner_tags"]),
    }
)

In [308]:
def is_not_none(example):
    # Check that 'tokens' and 'ner_tags' fields are not None
    return example["tokens"] is not None and example["ner_tags"] is not None

# Filter the dataset to remove examples with None values in 'tokens' or 'ner_tags'
dataset = dataset.filter(is_not_none)

In [309]:
def non_matching_len(example):
    # Check that the len of tokens matches the len of NER tags
    return len(example["tokens"]) == len(example["ner_tags"])

dataset = dataset.filter(non_matching_len)

In [310]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 970
    })
})

In [311]:
dataset = dataset["train"].train_test_split(test_size=0.3, seed=1234)

In [336]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 679
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 291
    })
})

In [312]:
label_list = get_label_list(
    dataset["train"]
)  # Assuming 'train' split exists and contains the labels

In [313]:
label_list

['B-VIOLATED BY',
 'O',
 'I-VIOLATION',
 'I-VIOLATED BY',
 'B-LAW',
 'I-VIOLATED ON',
 'I-LAW',
 'B-VIOLATED ON',
 'B-VIOLATION']

In [314]:
label_to_id = {label: i for i, label in enumerate(label_list)}
id_to_label = {i: label for i, label in enumerate(label_list)}

In [315]:
test_df = pd.read_excel("NER_test_set.xlsx")
test_set = Dataset.from_pandas(test_df)

In [316]:
test_set

Dataset({
    features: ['id', 'tokens'],
    num_rows: 380
})

In [317]:
import evaluate


metric = evaluate.load("seqeval")

In [318]:
def tokenize_and_align_labels(examples, tokenizer):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        padding="max_length",
        is_split_into_words=True,
    )
    labels = []
    for i, example_labels in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        last_word_id = None
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            elif word_id != last_word_id:
                label_id = label_to_id.get(example_labels[word_id], -100)
                label_ids.append(label_id)
            else:
                label_ids.append(label_id)
            last_word_id = word_id
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [319]:
def tokenize_dataset(model_checkpoint: str, dataset: Dataset):
    
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space = True, use_fast = True)
    
    # Tokenization and alignment of labels
    tokenized_dataset = dataset.map(
        lambda x: tokenize_and_align_labels(x, tokenizer), batched=True
    )
    
    return tokenizer, tokenized_dataset

In [320]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    # Unpack nested dictionaries
    final_results = {}
    for key, value in results.items():
        if isinstance(value, dict):
            for n, v in value.items():
                final_results[f"{key}_{n}"] = v
        else:
            final_results[key] = value
    return final_results


In [321]:
def train_model(model_checkpoint: str, tokenized_dataset: Any, label_list: list):
    model = AutoModelForTokenClassification.from_pretrained(
        model_checkpoint, num_labels=len(label_list), id2label = id_to_label, label2id = label_to_id
    )
    
    training_args = TrainingArguments(model_checkpoint,
                         evaluation_strategy = "epoch",
                         save_strategy="epoch",
                         learning_rate = 2e-5,
                         num_train_epochs=10,
                         warmup_steps=500,
                         weight_decay=0.01)
    
    
    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        compute_metrics=compute_metrics,
    )
    
    trainer.train()
    
    return trainer
    

In [322]:
def evaluate_trained_model(trainer: Trainer):
    eval_result = trainer.evaluate()
    
    evaluation_results = {k: v for k, v in eval_result.items()}
    
    return evaluation_results

In [323]:
# Function to run inference on a single example
def predict_tags(example, finetuned_tokenizer, finetuned_model):
    tokens = ast.literal_eval(example["tokens"])

    # Tokenize the input tokens
    inputs = finetuned_tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, padding=True)

    # Run inference
    with torch.no_grad():
        outputs = finetuned_model(**inputs)
        logits = outputs.logits

    # Convert logits to predicted tags
    predictions = torch.argmax(logits, dim=-1)
    predicted_tags = [finetuned_model.config.id2label[p.item()] for p in predictions[0]]

    # Align the predicted tags with the input tokens
    word_ids = inputs.word_ids()
    aligned_labels = []
    for i, word_id in enumerate(word_ids):
        if word_id is None:
            continue
        if word_id != word_ids[i - 1]:
            aligned_labels.append(predicted_tags[i])

    # Return the predicted tags
    example["predicted_tags"] = aligned_labels
    return example

## Distilbert Cased

In [199]:
model_checkpoint = "distilbert/distilbert-base-uncased"

In [200]:
tokenizer, tokenized_dataset = tokenize_dataset(model_checkpoint, dataset)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 291/291 [00:00<00:00, 3468.83 examples/s]


In [201]:
trained_model = train_model(model_checkpoint, tokenized_dataset, label_list)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))

 10%|â–ˆ         | 85/850 [00:45<05:40,  2.25it/s]

{'eval_loss': 1.09114670753479, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.0, 'eval_VIOLATION_recall': 0.0, 'eval_VIOLATION_f1': 0.0, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.0, 'eval_overall_recall': 0.0, 'eval_overall_f1': 0.0, 'eval_overall_accuracy': 0.7714780767194493, 'eval_runtime': 5.593, 'eval_samples_per_second': 52.029, 'eval_steps_per_second': 6.615, 'epoch': 1.0}


  _warn_prf(average, modifier, msg_start, len(result))

 20%|â–ˆâ–ˆ        | 170/850 [01:32<05:03,  2.24it/s]

{'eval_loss': 0.399384081363678, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.16806722689075632, 'eval_VIOLATION_recall': 0.19753086419753085, 'eval_VIOLATION_f1': 0.1816118047673099, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.16806722689075632, 'eval_overall_recall': 0.13029315960912052, 'eval_overall_f1': 0.14678899082568808, 'eval_overall_accuracy': 0.8886749771381817, 'eval_runtime': 5.5505, 'eval_samples_per_second': 52.428, 'eval_steps_per_second': 6.666, 'epoch': 2.0}


  _warn_prf(average, modifier, msg_start, len(result))

 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [02:18<04:23,  2.25it/s]

{'eval_loss': 0.28697431087493896, 'eval_LAW_precision': 0.13793103448275862, 'eval_LAW_recall': 0.04040404040404041, 'eval_LAW_f1': 0.0625, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.3829787234042553, 'eval_VIOLATION_recall': 0.4, 'eval_VIOLATION_f1': 0.3913043478260869, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.3672566371681416, 'eval_overall_recall': 0.2703583061889251, 'eval_overall_f1': 0.31144465290806755, 'eval_overall_accuracy': 0.9100447610338355, 'eval_runtime': 5.5616, 'eval_samples_per_second': 52.323, 'eval_steps_per_second': 6.653, 'epoch': 3.0}


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [02:58<03:46,  2.25it/s]
 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [03:04<03:46,  2.25it/s]

{'eval_loss': 0.22157905995845795, 'eval_LAW_precision': 0.7051282051282052, 'eval_LAW_recall': 0.5555555555555556, 'eval_LAW_f1': 0.6214689265536725, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.4, 'eval_VIOLATED BY_recall': 0.14035087719298245, 'eval_VIOLATED BY_f1': 0.20779220779220778, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.39285714285714285, 'eval_VIOLATED ON_recall': 0.20754716981132076, 'eval_VIOLATED ON_f1': 0.271604938271605, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.4737903225806452, 'eval_VIOLATION_recall': 0.5802469135802469, 'eval_VIOLATION_f1': 0.5216426193118757, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.4967845659163987, 'eval_overall_recall': 0.503257328990228, 'eval_overall_f1': 0.5, 'eval_overall_accuracy': 0.9292005583096693, 'eval_runtime': 5.537, 'eval_samples_per_second': 52.555, 'eval_steps_per_second': 6.682, 'epoch': 4.0}


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [03:44<03:08,  2.25it/s]
 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [03:50<03:08,  2.25it/s]

{'eval_loss': 0.20696575939655304, 'eval_LAW_precision': 0.7093023255813954, 'eval_LAW_recall': 0.6161616161616161, 'eval_LAW_f1': 0.6594594594594594, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.4339622641509434, 'eval_VIOLATED BY_recall': 0.40350877192982454, 'eval_VIOLATED BY_f1': 0.41818181818181815, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.24489795918367346, 'eval_VIOLATED ON_recall': 0.22641509433962265, 'eval_VIOLATED ON_f1': 0.23529411764705882, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.4847870182555781, 'eval_VIOLATION_recall': 0.5901234567901235, 'eval_VIOLATION_f1': 0.532293986636971, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.4919236417033774, 'eval_overall_recall': 0.5456026058631922, 'eval_overall_f1': 0.5173745173745173, 'eval_overall_accuracy': 0.9336766616932185, 'eval_runtime': 5.5276, 'eval_samples_per_second': 52.645, 'eval_steps_per_second': 6.694, 'epoch': 5.0}


 59%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–‰    | 500/850 [04:26<02:41,  2.17it/s]

{'loss': 0.5834, 'grad_norm': 0.767694890499115, 'learning_rate': 2e-05, 'epoch': 5.88}


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [04:31<02:35,  2.18it/s]
 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [04:36<02:35,  2.18it/s]

{'eval_loss': 0.21254582703113556, 'eval_LAW_precision': 0.8068181818181818, 'eval_LAW_recall': 0.7171717171717171, 'eval_LAW_f1': 0.7593582887700534, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5434782608695652, 'eval_VIOLATED BY_recall': 0.43859649122807015, 'eval_VIOLATED BY_f1': 0.48543689320388345, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.45454545454545453, 'eval_VIOLATED ON_recall': 0.2830188679245283, 'eval_VIOLATED ON_f1': 0.3488372093023256, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5095541401273885, 'eval_VIOLATION_recall': 0.5925925925925926, 'eval_VIOLATION_f1': 0.547945205479452, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5501567398119123, 'eval_overall_recall': 0.5716612377850163, 'eval_overall_f1': 0.560702875399361, 'eval_overall_accuracy': 0.9354574770178563, 'eval_runtime': 5.6007, 'eval_samples_per_second': 51.958, 'eval_steps_per_second': 6.606, 'epoch': 6.0}


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [05:17<01:52,  2.26it/s]
 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [05:22<01:52,  2.26it/s]

{'eval_loss': 0.23013223707675934, 'eval_LAW_precision': 0.8089887640449438, 'eval_LAW_recall': 0.7272727272727273, 'eval_LAW_f1': 0.7659574468085106, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5777777777777777, 'eval_VIOLATED BY_recall': 0.45614035087719296, 'eval_VIOLATED BY_f1': 0.5098039215686275, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.3902439024390244, 'eval_VIOLATED ON_recall': 0.3018867924528302, 'eval_VIOLATED ON_f1': 0.3404255319148936, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5611814345991561, 'eval_VIOLATION_recall': 0.6567901234567901, 'eval_VIOLATION_f1': 0.6052332195676906, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5855161787365177, 'eval_overall_recall': 0.6188925081433225, 'eval_overall_f1': 0.601741884402217, 'eval_overall_accuracy': 0.9421475670212254, 'eval_runtime': 5.5382, 'eval_samples_per_second': 52.544, 'eval_steps_per_second': 6.681, 'epoch': 7.0}


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [06:04<01:24,  2.01it/s]
 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [06:09<01:24,  2.01it/s]

{'eval_loss': 0.23650074005126953, 'eval_LAW_precision': 0.797979797979798, 'eval_LAW_recall': 0.797979797979798, 'eval_LAW_f1': 0.7979797979797979, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5357142857142857, 'eval_VIOLATED BY_recall': 0.5263157894736842, 'eval_VIOLATED BY_f1': 0.5309734513274336, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.45, 'eval_VIOLATED ON_recall': 0.33962264150943394, 'eval_VIOLATED ON_f1': 0.3870967741935484, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5868131868131868, 'eval_VIOLATION_recall': 0.6592592592592592, 'eval_VIOLATION_f1': 0.6209302325581396, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.6061538461538462, 'eval_overall_recall': 0.6416938110749185, 'eval_overall_f1': 0.6234177215189873, 'eval_overall_accuracy': 0.9423400875968619, 'eval_runtime': 5.5941, 'eval_samples_per_second': 52.019, 'eval_steps_per_second': 6.614, 'epoch': 8.0}


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [06:50<00:37,  2.24it/s]
 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [06:56<00:37,  2.24it/s]

{'eval_loss': 0.24343740940093994, 'eval_LAW_precision': 0.7835051546391752, 'eval_LAW_recall': 0.7676767676767676, 'eval_LAW_f1': 0.7755102040816326, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5740740740740741, 'eval_VIOLATED BY_recall': 0.543859649122807, 'eval_VIOLATED BY_f1': 0.5585585585585585, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.35294117647058826, 'eval_VIOLATED ON_recall': 0.33962264150943394, 'eval_VIOLATED ON_f1': 0.34615384615384615, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5823045267489712, 'eval_VIOLATION_recall': 0.6987654320987654, 'eval_VIOLATION_f1': 0.6352413019079685, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5930232558139535, 'eval_overall_recall': 0.6644951140065146, 'eval_overall_f1': 0.6267281105990783, 'eval_overall_accuracy': 0.9441690330654089, 'eval_runtime': 5.6302, 'eval_samples_per_second': 51.685, 'eval_steps_per_second': 6.572, 'epoch': 9.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [07:36<00:00,  2.25it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [07:43<00:00,  2.25it/s]

{'eval_loss': 0.2476310133934021, 'eval_LAW_precision': 0.7894736842105263, 'eval_LAW_recall': 0.7575757575757576, 'eval_LAW_f1': 0.7731958762886598, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5614035087719298, 'eval_VIOLATED BY_recall': 0.5614035087719298, 'eval_VIOLATED BY_f1': 0.5614035087719298, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.37735849056603776, 'eval_VIOLATED ON_recall': 0.37735849056603776, 'eval_VIOLATED ON_f1': 0.3773584905660377, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5533199195171026, 'eval_VIOLATION_recall': 0.6790123456790124, 'eval_VIOLATION_f1': 0.6097560975609756, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5726495726495726, 'eval_overall_recall': 0.6547231270358306, 'eval_overall_f1': 0.6109422492401215, 'eval_overall_accuracy': 0.9436877316263176, 'eval_runtime': 5.5198, 'eval_samples_per_second': 52.72, 'eval_steps_per_second': 6.703, 'epoch': 10.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [07:44<00:00,  1.83it/s]

{'train_runtime': 464.6205, 'train_samples_per_second': 14.614, 'train_steps_per_second': 1.829, 'train_loss': 0.36275078044218173, 'epoch': 10.0}





In [202]:
evaluation_results = evaluate_trained_model(trained_model)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:05<00:00,  6.76it/s]


In [203]:
evaluation_results

{'eval_loss': 0.2476310133934021,
 'eval_LAW_precision': 0.7894736842105263,
 'eval_LAW_recall': 0.7575757575757576,
 'eval_LAW_f1': 0.7731958762886598,
 'eval_LAW_number': 99,
 'eval_VIOLATED BY_precision': 0.5614035087719298,
 'eval_VIOLATED BY_recall': 0.5614035087719298,
 'eval_VIOLATED BY_f1': 0.5614035087719298,
 'eval_VIOLATED BY_number': 57,
 'eval_VIOLATED ON_precision': 0.37735849056603776,
 'eval_VIOLATED ON_recall': 0.37735849056603776,
 'eval_VIOLATED ON_f1': 0.3773584905660377,
 'eval_VIOLATED ON_number': 53,
 'eval_VIOLATION_precision': 0.5533199195171026,
 'eval_VIOLATION_recall': 0.6790123456790124,
 'eval_VIOLATION_f1': 0.6097560975609756,
 'eval_VIOLATION_number': 405,
 'eval_overall_precision': 0.5726495726495726,
 'eval_overall_recall': 0.6547231270358306,
 'eval_overall_f1': 0.6109422492401215,
 'eval_overall_accuracy': 0.9436877316263176,
 'eval_runtime': 6.0683,
 'eval_samples_per_second': 47.954,
 'eval_steps_per_second': 6.097,
 'epoch': 10.0}

In [204]:
trained_model.save_model(output_dir="legal_lens_finetuned_models/distilbert_finetuned/")

In [205]:
tokenizer.save_pretrained(save_directory="legal_lens_finetuned_tokenizers/distilbert_finetuned/")

('legal_lens_finetuned_tokenizers/distilbert_finetuned/tokenizer_config.json',
 'legal_lens_finetuned_tokenizers/distilbert_finetuned/special_tokens_map.json',
 'legal_lens_finetuned_tokenizers/distilbert_finetuned/vocab.txt',
 'legal_lens_finetuned_tokenizers/distilbert_finetuned/added_tokens.json',
 'legal_lens_finetuned_tokenizers/distilbert_finetuned/tokenizer.json')

In [227]:
finetuned_tokenizer = AutoTokenizer.from_pretrained("legal_lens_finetuned_tokenizers/distilbert_finetuned/")
finetuned_model = AutoModelForTokenClassification.from_pretrained("legal_lens_finetuned_models/distilbert_finetuned/")

In [230]:
# Apply the function to the entire dataset
test_results = test_set.map(
    lambda x: predict_tags(x, finetuned_tokenizer, finetuned_model))

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 380/380 [00:10<00:00, 37.98 examples/s]


In [231]:
test_results

Dataset({
    features: ['id', 'tokens', 'predicted_tags'],
    num_rows: 380
})

## BERT Uncased

In [232]:
model_checkpoint = "google-bert/bert-base-uncased"

In [210]:
tokenizer, tokenized_dataset = tokenize_dataset(model_checkpoint, dataset)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 679/679 [00:00<00:00, 3851.41 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 291/291 [00:00<00:00, 3849.82 examples/s]


In [211]:
trained_model = train_model(model_checkpoint, tokenized_dataset, label_list)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))

 10%|â–ˆ         | 85/850 [01:21<10:29,  1.22it/s]

{'eval_loss': 0.8390417695045471, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.0, 'eval_VIOLATION_recall': 0.0, 'eval_VIOLATION_f1': 0.0, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.0, 'eval_overall_recall': 0.0, 'eval_overall_f1': 0.0, 'eval_overall_accuracy': 0.7714780767194493, 'eval_runtime': 9.306, 'eval_samples_per_second': 31.27, 'eval_steps_per_second': 3.976, 'epoch': 1.0}


  _warn_prf(average, modifier, msg_start, len(result))

 20%|â–ˆâ–ˆ        | 170/850 [02:44<09:16,  1.22it/s]

{'eval_loss': 0.38071727752685547, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.1673228346456693, 'eval_VIOLATION_recall': 0.20987654320987653, 'eval_VIOLATION_f1': 0.18619934282584885, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.1673228346456693, 'eval_overall_recall': 0.13843648208469056, 'eval_overall_f1': 0.15151515151515155, 'eval_overall_accuracy': 0.8921884776435481, 'eval_runtime': 9.2588, 'eval_samples_per_second': 31.43, 'eval_steps_per_second': 3.996, 'epoch': 2.0}


  _warn_prf(average, modifier, msg_start, len(result))

 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [04:09<08:07,  1.22it/s]

{'eval_loss': 0.26429009437561035, 'eval_LAW_precision': 0.2982456140350877, 'eval_LAW_recall': 0.1717171717171717, 'eval_LAW_f1': 0.21794871794871795, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.40461215932914046, 'eval_VIOLATION_recall': 0.4765432098765432, 'eval_VIOLATION_f1': 0.43764172335600904, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.39325842696629215, 'eval_overall_recall': 0.34201954397394135, 'eval_overall_f1': 0.3658536585365854, 'eval_overall_accuracy': 0.9188525773692063, 'eval_runtime': 9.2694, 'eval_samples_per_second': 31.394, 'eval_steps_per_second': 3.992, 'epoch': 3.0}


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:24<06:57,  1.22it/s]
 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:33<06:57,  1.22it/s]

{'eval_loss': 0.2230903059244156, 'eval_LAW_precision': 0.6966292134831461, 'eval_LAW_recall': 0.6262626262626263, 'eval_LAW_f1': 0.6595744680851064, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.4426229508196721, 'eval_VIOLATED BY_recall': 0.47368421052631576, 'eval_VIOLATED BY_f1': 0.4576271186440678, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.2916666666666667, 'eval_VIOLATED ON_recall': 0.39622641509433965, 'eval_VIOLATED ON_f1': 0.3360000000000001, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.4155844155844156, 'eval_VIOLATION_recall': 0.5530864197530864, 'eval_VIOLATION_f1': 0.4745762711864407, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.43889618922470436, 'eval_overall_recall': 0.5439739413680782, 'eval_overall_f1': 0.48581818181818187, 'eval_overall_accuracy': 0.9341579631323098, 'eval_runtime': 9.2441, 'eval_samples_per_second': 31.48, 'eval_steps_per_second': 4.003, 'epoch': 4.0}


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:47<05:49,  1.22it/s]
 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:56<05:49,  1.22it/s]

{'eval_loss': 0.21054990589618683, 'eval_LAW_precision': 0.7931034482758621, 'eval_LAW_recall': 0.696969696969697, 'eval_LAW_f1': 0.7419354838709677, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.6, 'eval_VIOLATED BY_recall': 0.5263157894736842, 'eval_VIOLATED BY_f1': 0.5607476635514018, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.4, 'eval_VIOLATED ON_recall': 0.37735849056603776, 'eval_VIOLATED ON_f1': 0.38834951456310685, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5576923076923077, 'eval_VIOLATION_recall': 0.6444444444444445, 'eval_VIOLATION_f1': 0.5979381443298969, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5801526717557252, 'eval_overall_recall': 0.6188925081433225, 'eval_overall_f1': 0.5988967691095352, 'eval_overall_accuracy': 0.9415700052943158, 'eval_runtime': 9.2463, 'eval_samples_per_second': 31.472, 'eval_steps_per_second': 4.002, 'epoch': 5.0}


 59%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–‰    | 500/850 [08:02<04:56,  1.18it/s]

{'loss': 0.5422, 'grad_norm': 0.7910398244857788, 'learning_rate': 2e-05, 'epoch': 5.88}


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:10<04:41,  1.21it/s]
 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:20<04:41,  1.21it/s]

{'eval_loss': 0.22337868809700012, 'eval_LAW_precision': 0.6605504587155964, 'eval_LAW_recall': 0.7272727272727273, 'eval_LAW_f1': 0.6923076923076923, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5166666666666667, 'eval_VIOLATED BY_recall': 0.543859649122807, 'eval_VIOLATED BY_f1': 0.52991452991453, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.39344262295081966, 'eval_VIOLATED ON_recall': 0.4528301886792453, 'eval_VIOLATED ON_f1': 0.42105263157894735, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5211267605633803, 'eval_VIOLATION_recall': 0.6395061728395062, 'eval_VIOLATION_f1': 0.5742793791574279, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.530949105914718, 'eval_overall_recall': 0.6286644951140065, 'eval_overall_f1': 0.575689783743475, 'eval_overall_accuracy': 0.940174231120951, 'eval_runtime': 9.3372, 'eval_samples_per_second': 31.166, 'eval_steps_per_second': 3.963, 'epoch': 6.0}


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:34<03:28,  1.22it/s]
 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:43<03:28,  1.22it/s]

{'eval_loss': 0.25568267703056335, 'eval_LAW_precision': 0.7368421052631579, 'eval_LAW_recall': 0.7070707070707071, 'eval_LAW_f1': 0.7216494845360824, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.6274509803921569, 'eval_VIOLATED BY_recall': 0.5614035087719298, 'eval_VIOLATED BY_f1': 0.5925925925925926, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.4126984126984127, 'eval_VIOLATED ON_recall': 0.49056603773584906, 'eval_VIOLATED ON_f1': 0.44827586206896547, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5583657587548638, 'eval_VIOLATION_recall': 0.7086419753086419, 'eval_VIOLATION_f1': 0.6245919477693144, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.573997233748271, 'eval_overall_recall': 0.6758957654723127, 'eval_overall_f1': 0.6207928197456993, 'eval_overall_accuracy': 0.940511142128315, 'eval_runtime': 9.2872, 'eval_samples_per_second': 31.333, 'eval_steps_per_second': 3.984, 'epoch': 7.0}


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [10:58<02:19,  1.22it/s]
 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [11:08<02:19,  1.22it/s]

{'eval_loss': 0.24820150434970856, 'eval_LAW_precision': 0.7333333333333333, 'eval_LAW_recall': 0.7777777777777778, 'eval_LAW_f1': 0.7549019607843137, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.4666666666666667, 'eval_VIOLATED BY_recall': 0.6140350877192983, 'eval_VIOLATED BY_f1': 0.5303030303030304, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.3333333333333333, 'eval_VIOLATED ON_recall': 0.41509433962264153, 'eval_VIOLATED ON_f1': 0.36974789915966383, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.536144578313253, 'eval_VIOLATION_recall': 0.6592592592592592, 'eval_VIOLATION_f1': 0.5913621262458472, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5389784946236559, 'eval_overall_recall': 0.6530944625407166, 'eval_overall_f1': 0.5905743740795287, 'eval_overall_accuracy': 0.9414256148625885, 'eval_runtime': 9.2948, 'eval_samples_per_second': 31.308, 'eval_steps_per_second': 3.981, 'epoch': 8.0}


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:22<01:09,  1.22it/s]
 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:31<01:09,  1.22it/s]

{'eval_loss': 0.25872042775154114, 'eval_LAW_precision': 0.7604166666666666, 'eval_LAW_recall': 0.7373737373737373, 'eval_LAW_f1': 0.7487179487179487, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.6071428571428571, 'eval_VIOLATED BY_recall': 0.5964912280701754, 'eval_VIOLATED BY_f1': 0.6017699115044247, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.4107142857142857, 'eval_VIOLATED ON_recall': 0.4339622641509434, 'eval_VIOLATED ON_f1': 0.42201834862385323, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5494949494949495, 'eval_VIOLATION_recall': 0.671604938271605, 'eval_VIOLATION_f1': 0.6044444444444445, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5718349928876245, 'eval_overall_recall': 0.6547231270358306, 'eval_overall_f1': 0.6104783599088839, 'eval_overall_accuracy': 0.9439765124897723, 'eval_runtime': 9.2728, 'eval_samples_per_second': 31.382, 'eval_steps_per_second': 3.99, 'epoch': 9.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [13:47<00:00,  1.22it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [13:59<00:00,  1.22it/s]

{'eval_loss': 0.2607121467590332, 'eval_LAW_precision': 0.7708333333333334, 'eval_LAW_recall': 0.7474747474747475, 'eval_LAW_f1': 0.758974358974359, 'eval_LAW_number': 99, 'eval_VIOLATED BY_precision': 0.5666666666666667, 'eval_VIOLATED BY_recall': 0.5964912280701754, 'eval_VIOLATED BY_f1': 0.5811965811965812, 'eval_VIOLATED BY_number': 57, 'eval_VIOLATED ON_precision': 0.38596491228070173, 'eval_VIOLATED ON_recall': 0.41509433962264153, 'eval_VIOLATED ON_f1': 0.4, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5496957403651116, 'eval_VIOLATION_recall': 0.6691358024691358, 'eval_VIOLATION_f1': 0.6035634743875278, 'eval_VIOLATION_number': 405, 'eval_overall_precision': 0.5679886685552408, 'eval_overall_recall': 0.6530944625407166, 'eval_overall_f1': 0.6075757575757575, 'eval_overall_accuracy': 0.9437358617702267, 'eval_runtime': 9.2966, 'eval_samples_per_second': 31.302, 'eval_steps_per_second': 3.98, 'epoch': 10.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [14:00<00:00,  1.01it/s]

{'train_runtime': 840.7453, 'train_samples_per_second': 8.076, 'train_steps_per_second': 1.011, 'train_loss': 0.3335383437661564, 'epoch': 10.0}





In [212]:
evaluation_results = evaluate_trained_model(trained_model)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:09<00:00,  4.00it/s]


In [213]:
evaluation_results

{'eval_loss': 0.2607121467590332,
 'eval_LAW_precision': 0.7708333333333334,
 'eval_LAW_recall': 0.7474747474747475,
 'eval_LAW_f1': 0.758974358974359,
 'eval_LAW_number': 99,
 'eval_VIOLATED BY_precision': 0.5666666666666667,
 'eval_VIOLATED BY_recall': 0.5964912280701754,
 'eval_VIOLATED BY_f1': 0.5811965811965812,
 'eval_VIOLATED BY_number': 57,
 'eval_VIOLATED ON_precision': 0.38596491228070173,
 'eval_VIOLATED ON_recall': 0.41509433962264153,
 'eval_VIOLATED ON_f1': 0.4,
 'eval_VIOLATED ON_number': 53,
 'eval_VIOLATION_precision': 0.5496957403651116,
 'eval_VIOLATION_recall': 0.6691358024691358,
 'eval_VIOLATION_f1': 0.6035634743875278,
 'eval_VIOLATION_number': 405,
 'eval_overall_precision': 0.5679886685552408,
 'eval_overall_recall': 0.6530944625407166,
 'eval_overall_f1': 0.6075757575757575,
 'eval_overall_accuracy': 0.9437358617702267,
 'eval_runtime': 10.6067,
 'eval_samples_per_second': 27.436,
 'eval_steps_per_second': 3.488,
 'epoch': 10.0}

In [233]:
tokenizer.save_pretrained(save_directory="legal_lens_finetuned_tokenizers/bert_uncased_finetuned/")
trained_model.save_model(output_dir="legal_lens_finetuned_models/bert_uncased_finetuned/")

In [234]:
finetuned_tokenizer = AutoTokenizer.from_pretrained("legal_lens_finetuned_tokenizers/bert_uncased_finetuned/")
finetuned_model = AutoModelForTokenClassification.from_pretrained("legal_lens_finetuned_models/bert_uncased_finetuned/")

In [235]:
# Apply the function to the entire dataset
test_results = test_set.map(
    lambda x: predict_tags(x, finetuned_tokenizer, finetuned_model))

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 380/380 [00:16<00:00, 22.80 examples/s]


## Bert Cased

In [236]:
model_checkpoint = "google-bert/bert-base-cased"

In [237]:
tokenizer, tokenized_dataset = tokenize_dataset(model_checkpoint, dataset)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 679/679 [00:00<00:00, 3720.00 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 291/291 [00:00<00:00, 3991.70 examples/s]


In [238]:
trained_model = train_model(model_checkpoint, tokenized_dataset, label_list)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))

 10%|â–ˆ         | 85/850 [01:22<10:17,  1.24it/s]

{'eval_loss': 0.8149062991142273, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.0, 'eval_VIOLATION_recall': 0.0, 'eval_VIOLATION_f1': 0.0, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.0, 'eval_overall_recall': 0.0, 'eval_overall_f1': 0.0, 'eval_overall_accuracy': 0.7665760105036106, 'eval_runtime': 9.2376, 'eval_samples_per_second': 31.502, 'eval_steps_per_second': 4.005, 'epoch': 1.0}


  _warn_prf(average, modifier, msg_start, len(result))

 20%|â–ˆâ–ˆ        | 170/850 [02:45<09:14,  1.23it/s]

{'eval_loss': 0.3455027639865875, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.2629695885509839, 'eval_VIOLATION_recall': 0.30246913580246915, 'eval_VIOLATION_f1': 0.2813397129186602, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.2620320855614973, 'eval_overall_recall': 0.20704225352112676, 'eval_overall_f1': 0.2313139260424862, 'eval_overall_accuracy': 0.8946356560067523, 'eval_runtime': 9.2057, 'eval_samples_per_second': 31.611, 'eval_steps_per_second': 4.019, 'epoch': 2.0}


 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [04:00<08:04,  1.23it/s]
 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [04:09<08:04,  1.23it/s]

{'eval_loss': 0.23710928857326508, 'eval_LAW_precision': 0.5652173913043478, 'eval_LAW_recall': 0.3611111111111111, 'eval_LAW_f1': 0.44067796610169496, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.44976816074188564, 'eval_VIOLATION_recall': 0.5987654320987654, 'eval_VIOLATION_f1': 0.5136804942630184, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.4558011049723757, 'eval_overall_recall': 0.4647887323943662, 'eval_overall_f1': 0.46025104602510464, 'eval_overall_accuracy': 0.9224889805870768, 'eval_runtime': 9.3588, 'eval_samples_per_second': 31.094, 'eval_steps_per_second': 3.954, 'epoch': 3.0}


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:24<06:57,  1.22it/s]
 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:33<06:57,  1.22it/s]

{'eval_loss': 0.19310948252677917, 'eval_LAW_precision': 0.7804878048780488, 'eval_LAW_recall': 0.5925925925925926, 'eval_LAW_f1': 0.6736842105263158, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.5714285714285714, 'eval_VIOLATED BY_recall': 0.38095238095238093, 'eval_VIOLATED BY_f1': 0.4571428571428571, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.35, 'eval_VIOLATED ON_recall': 0.2641509433962264, 'eval_VIOLATED ON_f1': 0.30107526881720426, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5524475524475524, 'eval_VIOLATION_recall': 0.6502057613168725, 'eval_VIOLATION_f1': 0.5973534971644612, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.5679347826086957, 'eval_overall_recall': 0.5887323943661972, 'eval_overall_f1': 0.5781466113416321, 'eval_overall_accuracy': 0.9381037231548345, 'eval_runtime': 9.3577, 'eval_samples_per_second': 31.097, 'eval_steps_per_second': 3.954, 'epoch': 4.0}


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:49<05:46,  1.22it/s]
 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:58<05:46,  1.22it/s]

{'eval_loss': 0.21959036588668823, 'eval_LAW_precision': 0.7954545454545454, 'eval_LAW_recall': 0.6481481481481481, 'eval_LAW_f1': 0.7142857142857143, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.5490196078431373, 'eval_VIOLATED BY_recall': 0.4444444444444444, 'eval_VIOLATED BY_f1': 0.49122807017543857, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.3484848484848485, 'eval_VIOLATED ON_recall': 0.4339622641509434, 'eval_VIOLATED ON_f1': 0.38655462184873957, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5989110707803993, 'eval_VIOLATION_recall': 0.6790123456790124, 'eval_VIOLATION_f1': 0.6364513018322082, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.5965608465608465, 'eval_overall_recall': 0.6352112676056338, 'eval_overall_f1': 0.6152796725784447, 'eval_overall_accuracy': 0.9436368751758417, 'eval_runtime': 9.2519, 'eval_samples_per_second': 31.453, 'eval_steps_per_second': 3.999, 'epoch': 5.0}


 59%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–‰    | 500/850 [08:04<04:56,  1.18it/s]

{'loss': 0.5112, 'grad_norm': 1.2425944805145264, 'learning_rate': 2e-05, 'epoch': 5.88}


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:12<04:39,  1.22it/s]
 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:21<04:39,  1.22it/s]

{'eval_loss': 0.20248673856258392, 'eval_LAW_precision': 0.7889908256880734, 'eval_LAW_recall': 0.7962962962962963, 'eval_LAW_f1': 0.7926267281105991, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.47540983606557374, 'eval_VIOLATED BY_recall': 0.4603174603174603, 'eval_VIOLATED BY_f1': 0.4677419354838709, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.36363636363636365, 'eval_VIOLATED ON_recall': 0.37735849056603776, 'eval_VIOLATED ON_f1': 0.37037037037037035, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.5429497568881686, 'eval_VIOLATION_recall': 0.6893004115226338, 'eval_VIOLATION_f1': 0.6074342701722576, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.5581947743467933, 'eval_overall_recall': 0.6619718309859155, 'eval_overall_f1': 0.6056701030927836, 'eval_overall_accuracy': 0.942792835037044, 'eval_runtime': 9.2137, 'eval_samples_per_second': 31.583, 'eval_steps_per_second': 4.016, 'epoch': 6.0}


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:34<03:22,  1.26it/s]
 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:44<03:22,  1.26it/s]

{'eval_loss': 0.268388956785202, 'eval_LAW_precision': 0.7522123893805309, 'eval_LAW_recall': 0.7870370370370371, 'eval_LAW_f1': 0.7692307692307693, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.5964912280701754, 'eval_VIOLATED BY_recall': 0.5396825396825397, 'eval_VIOLATED BY_f1': 0.5666666666666667, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.4032258064516129, 'eval_VIOLATED ON_recall': 0.4716981132075472, 'eval_VIOLATED ON_f1': 0.43478260869565216, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.6086956521739131, 'eval_VIOLATION_recall': 0.691358024691358, 'eval_VIOLATION_f1': 0.6473988439306358, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.6122448979591837, 'eval_overall_recall': 0.676056338028169, 'eval_overall_f1': 0.642570281124498, 'eval_overall_accuracy': 0.9439651130075963, 'eval_runtime': 9.5708, 'eval_samples_per_second': 30.405, 'eval_steps_per_second': 3.866, 'epoch': 7.0}


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [10:57<02:21,  1.20it/s]
 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [11:07<02:21,  1.20it/s]

{'eval_loss': 0.26423540711402893, 'eval_LAW_precision': 0.8269230769230769, 'eval_LAW_recall': 0.7962962962962963, 'eval_LAW_f1': 0.8113207547169811, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.6111111111111112, 'eval_VIOLATED BY_recall': 0.5238095238095238, 'eval_VIOLATED BY_f1': 0.5641025641025642, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.43636363636363634, 'eval_VIOLATED ON_recall': 0.4528301886792453, 'eval_VIOLATED ON_f1': 0.4444444444444444, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.6635160680529301, 'eval_VIOLATION_recall': 0.7222222222222222, 'eval_VIOLATION_f1': 0.6916256157635469, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.6657681940700808, 'eval_overall_recall': 0.6957746478873239, 'eval_overall_f1': 0.6804407713498623, 'eval_overall_accuracy': 0.9454656288099034, 'eval_runtime': 9.6527, 'eval_samples_per_second': 30.147, 'eval_steps_per_second': 3.833, 'epoch': 8.0}


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:22<01:09,  1.23it/s]
 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:31<01:09,  1.23it/s]

{'eval_loss': 0.2539927065372467, 'eval_LAW_precision': 0.8446601941747572, 'eval_LAW_recall': 0.8055555555555556, 'eval_LAW_f1': 0.8246445497630333, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.6052631578947368, 'eval_VIOLATED BY_recall': 0.7301587301587301, 'eval_VIOLATED BY_f1': 0.6618705035971223, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.4098360655737705, 'eval_VIOLATED ON_recall': 0.4716981132075472, 'eval_VIOLATED ON_f1': 0.43859649122807015, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.6440366972477064, 'eval_VIOLATION_recall': 0.7222222222222222, 'eval_VIOLATION_f1': 0.6808923375363725, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.6484076433121019, 'eval_overall_recall': 0.7169014084507043, 'eval_overall_f1': 0.6809364548494984, 'eval_overall_accuracy': 0.9483259870580512, 'eval_runtime': 9.5379, 'eval_samples_per_second': 30.51, 'eval_steps_per_second': 3.879, 'epoch': 9.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [13:48<00:00,  1.20it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [14:02<00:00,  1.20it/s]

{'eval_loss': 0.26457345485687256, 'eval_LAW_precision': 0.8165137614678899, 'eval_LAW_recall': 0.8240740740740741, 'eval_LAW_f1': 0.8202764976958524, 'eval_LAW_number': 108, 'eval_VIOLATED BY_precision': 0.6176470588235294, 'eval_VIOLATED BY_recall': 0.6666666666666666, 'eval_VIOLATED BY_f1': 0.6412213740458016, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.391304347826087, 'eval_VIOLATED ON_recall': 0.5094339622641509, 'eval_VIOLATED ON_f1': 0.44262295081967207, 'eval_VIOLATED ON_number': 53, 'eval_VIOLATION_precision': 0.6423487544483986, 'eval_VIOLATION_recall': 0.742798353909465, 'eval_VIOLATION_f1': 0.6889312977099238, 'eval_VIOLATION_number': 486, 'eval_overall_precision': 0.6423267326732673, 'eval_overall_recall': 0.7309859154929578, 'eval_overall_f1': 0.6837944664031621, 'eval_overall_accuracy': 0.9487948982462722, 'eval_runtime': 9.6788, 'eval_samples_per_second': 30.066, 'eval_steps_per_second': 3.823, 'epoch': 10.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [14:07<00:00,  1.00it/s]

{'train_runtime': 847.7884, 'train_samples_per_second': 8.009, 'train_steps_per_second': 1.003, 'train_loss': 0.3135357576258042, 'epoch': 10.0}





In [239]:
evaluation_results = evaluate_trained_model(trained_model)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:09<00:00,  3.93it/s]


In [240]:
evaluation_results

{'eval_loss': 0.26457345485687256,
 'eval_LAW_precision': 0.8165137614678899,
 'eval_LAW_recall': 0.8240740740740741,
 'eval_LAW_f1': 0.8202764976958524,
 'eval_LAW_number': 108,
 'eval_VIOLATED BY_precision': 0.6176470588235294,
 'eval_VIOLATED BY_recall': 0.6666666666666666,
 'eval_VIOLATED BY_f1': 0.6412213740458016,
 'eval_VIOLATED BY_number': 63,
 'eval_VIOLATED ON_precision': 0.391304347826087,
 'eval_VIOLATED ON_recall': 0.5094339622641509,
 'eval_VIOLATED ON_f1': 0.44262295081967207,
 'eval_VIOLATED ON_number': 53,
 'eval_VIOLATION_precision': 0.6423487544483986,
 'eval_VIOLATION_recall': 0.742798353909465,
 'eval_VIOLATION_f1': 0.6889312977099238,
 'eval_VIOLATION_number': 486,
 'eval_overall_precision': 0.6423267326732673,
 'eval_overall_recall': 0.7309859154929578,
 'eval_overall_f1': 0.6837944664031621,
 'eval_overall_accuracy': 0.9487948982462722,
 'eval_runtime': 9.8114,
 'eval_samples_per_second': 29.659,
 'eval_steps_per_second': 3.771,
 'epoch': 10.0}

In [241]:
trained_model.save_model(output_dir="legal_lens_finetuned_models/bert_cased_finetuned/")

In [242]:
tokenizer.save_pretrained(save_directory="legal_lens_finetuned_tokenizers/bert_cased_finetuned/")

('legal_lens_finetuned_tokenizers/bert_cased_finetuned/tokenizer_config.json',
 'legal_lens_finetuned_tokenizers/bert_cased_finetuned/special_tokens_map.json',
 'legal_lens_finetuned_tokenizers/bert_cased_finetuned/vocab.txt',
 'legal_lens_finetuned_tokenizers/bert_cased_finetuned/added_tokens.json',
 'legal_lens_finetuned_tokenizers/bert_cased_finetuned/tokenizer.json')

In [243]:
finetuned_tokenizer = AutoTokenizer.from_pretrained("legal_lens_finetuned_tokenizers/bert_cased_finetuned/")
finetuned_model = AutoModelForTokenClassification.from_pretrained("legal_lens_finetuned_models/bert_cased_finetuned/")

In [244]:
# Apply the function to the entire dataset
test_results = test_set.map(
    lambda x: predict_tags(x, finetuned_tokenizer, finetuned_model))

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 380/380 [00:15<00:00, 24.49 examples/s]


## Roberta

In [324]:
model_checkpoint = "FacebookAI/roberta-base"

In [326]:
tokenizer, tokenized_dataset = tokenize_dataset(model_checkpoint, dataset)

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 679/679 [00:00<00:00, 3419.45 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 291/291 [00:00<00:00, 4202.71 examples/s]


In [327]:
trained_model = train_model(model_checkpoint, tokenized_dataset, label_list)

Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))

 10%|â–ˆ         | 85/850 [01:23<10:33,  1.21it/s]

{'eval_loss': 0.6782763600349426, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.0, 'eval_VIOLATION_recall': 0.0, 'eval_VIOLATION_f1': 0.0, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.0, 'eval_overall_recall': 0.0, 'eval_overall_f1': 0.0, 'eval_overall_accuracy': 0.776010339513844, 'eval_runtime': 9.4571, 'eval_samples_per_second': 30.771, 'eval_steps_per_second': 3.912, 'epoch': 1.0}


  _warn_prf(average, modifier, msg_start, len(result))

 20%|â–ˆâ–ˆ        | 170/850 [02:48<09:16,  1.22it/s]

{'eval_loss': 0.3421342670917511, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.2835820895522388, 'eval_VIOLATION_recall': 0.37047353760445684, 'eval_VIOLATION_f1': 0.32125603864734303, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.2835820895522388, 'eval_overall_recall': 0.23498233215547704, 'eval_overall_f1': 0.25700483091787435, 'eval_overall_accuracy': 0.8944176567082567, 'eval_runtime': 9.3778, 'eval_samples_per_second': 31.031, 'eval_steps_per_second': 3.946, 'epoch': 2.0}


 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [04:02<08:08,  1.22it/s]
 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [04:11<08:08,  1.22it/s]

{'eval_loss': 0.22220703959465027, 'eval_LAW_precision': 0.44642857142857145, 'eval_LAW_recall': 0.25510204081632654, 'eval_LAW_f1': 0.3246753246753247, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.4174107142857143, 'eval_VIOLATION_recall': 0.520891364902507, 'eval_VIOLATION_f1': 0.4634448574969021, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.4140625, 'eval_overall_recall': 0.3745583038869258, 'eval_overall_f1': 0.39332096474953615, 'eval_overall_accuracy': 0.9274245662872198, 'eval_runtime': 9.3005, 'eval_samples_per_second': 31.289, 'eval_steps_per_second': 3.978, 'epoch': 3.0}


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:27<06:54,  1.23it/s]
 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [05:36<06:54,  1.23it/s]

{'eval_loss': 0.18026337027549744, 'eval_LAW_precision': 0.6744186046511628, 'eval_LAW_recall': 0.5918367346938775, 'eval_LAW_f1': 0.6304347826086958, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.6111111111111112, 'eval_VIOLATED BY_recall': 0.5689655172413793, 'eval_VIOLATED BY_f1': 0.5892857142857143, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.3333333333333333, 'eval_VIOLATED ON_recall': 0.39215686274509803, 'eval_VIOLATED ON_f1': 0.36036036036036034, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.547979797979798, 'eval_VIOLATION_recall': 0.6044568245125348, 'eval_VIOLATION_f1': 0.5748344370860927, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.5503355704697986, 'eval_overall_recall': 0.5795053003533569, 'eval_overall_f1': 0.5645438898450946, 'eval_overall_accuracy': 0.9419396530297758, 'eval_runtime': 9.1431, 'eval_samples_per_second': 31.827, 'eval_steps_per_second': 4.047, 'epoch': 4.0}


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:50<05:56,  1.19it/s]
 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [06:59<05:56,  1.19it/s]

{'eval_loss': 0.18210674822330475, 'eval_LAW_precision': 0.8045977011494253, 'eval_LAW_recall': 0.7142857142857143, 'eval_LAW_f1': 0.7567567567567568, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.5, 'eval_VIOLATED BY_recall': 0.5344827586206896, 'eval_VIOLATED BY_f1': 0.5166666666666667, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.47619047619047616, 'eval_VIOLATED ON_recall': 0.39215686274509803, 'eval_VIOLATED ON_f1': 0.4301075268817204, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.5965346534653465, 'eval_VIOLATION_recall': 0.6713091922005571, 'eval_VIOLATION_f1': 0.6317169069462647, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6084033613445378, 'eval_overall_recall': 0.6395759717314488, 'eval_overall_f1': 0.6236003445305771, 'eval_overall_accuracy': 0.9461152259283193, 'eval_runtime': 9.1697, 'eval_samples_per_second': 31.735, 'eval_steps_per_second': 4.035, 'epoch': 5.0}


 59%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–‰    | 500/850 [08:07<04:56,  1.18it/s]

{'loss': 0.4588, 'grad_norm': 3.458345413208008, 'learning_rate': 2e-05, 'epoch': 5.88}


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:15<04:37,  1.23it/s]
 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [08:25<04:37,  1.23it/s]

{'eval_loss': 0.17915713787078857, 'eval_LAW_precision': 0.8085106382978723, 'eval_LAW_recall': 0.7755102040816326, 'eval_LAW_f1': 0.7916666666666665, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.5862068965517241, 'eval_VIOLATED BY_recall': 0.5862068965517241, 'eval_VIOLATED BY_f1': 0.5862068965517241, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.34375, 'eval_VIOLATED ON_recall': 0.43137254901960786, 'eval_VIOLATED ON_f1': 0.3826086956521739, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.6269430051813472, 'eval_VIOLATION_recall': 0.6740947075208914, 'eval_VIOLATION_f1': 0.6496644295302013, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6212624584717608, 'eval_overall_recall': 0.6607773851590106, 'eval_overall_f1': 0.6404109589041096, 'eval_overall_accuracy': 0.9509867276432867, 'eval_runtime': 9.4189, 'eval_samples_per_second': 30.895, 'eval_steps_per_second': 3.928, 'epoch': 6.0}


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:39<03:31,  1.21it/s]
 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [09:49<03:31,  1.21it/s]

{'eval_loss': 0.22653014957904816, 'eval_LAW_precision': 0.7872340425531915, 'eval_LAW_recall': 0.7551020408163265, 'eval_LAW_f1': 0.7708333333333333, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.6792452830188679, 'eval_VIOLATED BY_recall': 0.6206896551724138, 'eval_VIOLATED BY_f1': 0.6486486486486486, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.46153846153846156, 'eval_VIOLATED ON_recall': 0.47058823529411764, 'eval_VIOLATED ON_f1': 0.46601941747572817, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.6130653266331658, 'eval_VIOLATION_recall': 0.6796657381615598, 'eval_VIOLATION_f1': 0.6446499339498019, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6331658291457286, 'eval_overall_recall': 0.6678445229681979, 'eval_overall_f1': 0.6500429922613928, 'eval_overall_accuracy': 0.9455187155142417, 'eval_runtime': 9.2003, 'eval_samples_per_second': 31.629, 'eval_steps_per_second': 4.022, 'epoch': 7.0}


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [11:03<02:17,  1.24it/s]
 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [11:12<02:17,  1.24it/s]

{'eval_loss': 0.18123936653137207, 'eval_LAW_precision': 0.8541666666666666, 'eval_LAW_recall': 0.8367346938775511, 'eval_LAW_f1': 0.845360824742268, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.7083333333333334, 'eval_VIOLATED BY_recall': 0.5862068965517241, 'eval_VIOLATED BY_f1': 0.6415094339622641, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.5192307692307693, 'eval_VIOLATED ON_recall': 0.5294117647058824, 'eval_VIOLATED ON_f1': 0.5242718446601942, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.6657894736842105, 'eval_VIOLATION_recall': 0.7047353760445683, 'eval_VIOLATION_f1': 0.6847090663058186, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6875, 'eval_overall_recall': 0.6996466431095406, 'eval_overall_f1': 0.6935201401050788, 'eval_overall_accuracy': 0.9544663717254064, 'eval_runtime': 9.1535, 'eval_samples_per_second': 31.791, 'eval_steps_per_second': 4.042, 'epoch': 8.0}


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:26<01:10,  1.21it/s]
 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [12:36<01:10,  1.21it/s]

{'eval_loss': 0.21662026643753052, 'eval_LAW_precision': 0.8173076923076923, 'eval_LAW_recall': 0.8673469387755102, 'eval_LAW_f1': 0.8415841584158416, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.6142857142857143, 'eval_VIOLATED BY_recall': 0.7413793103448276, 'eval_VIOLATED BY_f1': 0.671875, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.40298507462686567, 'eval_VIOLATED ON_recall': 0.5294117647058824, 'eval_VIOLATED ON_f1': 0.45762711864406785, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.5905882352941176, 'eval_VIOLATION_recall': 0.6991643454038997, 'eval_VIOLATION_f1': 0.6403061224489797, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6096096096096096, 'eval_overall_recall': 0.7173144876325088, 'eval_overall_f1': 0.6590909090909092, 'eval_overall_accuracy': 0.9495451608092658, 'eval_runtime': 9.0622, 'eval_samples_per_second': 32.111, 'eval_steps_per_second': 4.083, 'epoch': 9.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [13:50<00:00,  1.22it/s]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [14:01<00:00,  1.22it/s]

{'eval_loss': 0.2228051722049713, 'eval_LAW_precision': 0.8817204301075269, 'eval_LAW_recall': 0.8367346938775511, 'eval_LAW_f1': 0.8586387434554974, 'eval_LAW_number': 98, 'eval_VIOLATED BY_precision': 0.7096774193548387, 'eval_VIOLATED BY_recall': 0.7586206896551724, 'eval_VIOLATED BY_f1': 0.7333333333333333, 'eval_VIOLATED BY_number': 58, 'eval_VIOLATED ON_precision': 0.6, 'eval_VIOLATED ON_recall': 0.5882352941176471, 'eval_VIOLATED ON_f1': 0.5940594059405941, 'eval_VIOLATED ON_number': 51, 'eval_VIOLATION_precision': 0.6190476190476191, 'eval_VIOLATION_recall': 0.724233983286908, 'eval_VIOLATION_f1': 0.6675224646983312, 'eval_VIOLATION_number': 359, 'eval_overall_precision': 0.6656, 'eval_overall_recall': 0.734982332155477, 'eval_overall_f1': 0.6985726280436608, 'eval_overall_accuracy': 0.9533230600984243, 'eval_runtime': 9.1178, 'eval_samples_per_second': 31.916, 'eval_steps_per_second': 4.058, 'epoch': 10.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [14:02<00:00,  1.01it/s]

{'train_runtime': 842.4538, 'train_samples_per_second': 8.06, 'train_steps_per_second': 1.009, 'train_loss': 0.2861541938781738, 'epoch': 10.0}





In [328]:
evaluation_results = evaluate_trained_model(trained_model)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:09<00:00,  4.01it/s]


In [329]:
evaluation_results

{'eval_loss': 0.2228051722049713,
 'eval_LAW_precision': 0.8817204301075269,
 'eval_LAW_recall': 0.8367346938775511,
 'eval_LAW_f1': 0.8586387434554974,
 'eval_LAW_number': 98,
 'eval_VIOLATED BY_precision': 0.7096774193548387,
 'eval_VIOLATED BY_recall': 0.7586206896551724,
 'eval_VIOLATED BY_f1': 0.7333333333333333,
 'eval_VIOLATED BY_number': 58,
 'eval_VIOLATED ON_precision': 0.6,
 'eval_VIOLATED ON_recall': 0.5882352941176471,
 'eval_VIOLATED ON_f1': 0.5940594059405941,
 'eval_VIOLATED ON_number': 51,
 'eval_VIOLATION_precision': 0.6190476190476191,
 'eval_VIOLATION_recall': 0.724233983286908,
 'eval_VIOLATION_f1': 0.6675224646983312,
 'eval_VIOLATION_number': 359,
 'eval_overall_precision': 0.6656,
 'eval_overall_recall': 0.734982332155477,
 'eval_overall_f1': 0.6985726280436608,
 'eval_overall_accuracy': 0.9533230600984243,
 'eval_runtime': 9.6556,
 'eval_samples_per_second': 30.138,
 'eval_steps_per_second': 3.832,
 'epoch': 10.0}

In [330]:
trained_model.save_model(output_dir="legal_lens_finetuned_models/roberta_finetuned/")

In [331]:
tokenizer.save_pretrained(save_directory="legal_lens_finetuned_tokenizers/roberta_finetuned/")

('legal_lens_finetuned_tokenizers/roberta_finetuned/tokenizer_config.json',
 'legal_lens_finetuned_tokenizers/roberta_finetuned/special_tokens_map.json',
 'legal_lens_finetuned_tokenizers/roberta_finetuned/vocab.json',
 'legal_lens_finetuned_tokenizers/roberta_finetuned/merges.txt',
 'legal_lens_finetuned_tokenizers/roberta_finetuned/added_tokens.json',
 'legal_lens_finetuned_tokenizers/roberta_finetuned/tokenizer.json')

In [332]:
finetuned_tokenizer = AutoTokenizer.from_pretrained("legal_lens_finetuned_tokenizers/roberta_finetuned/")
finetuned_model = AutoModelForTokenClassification.from_pretrained("legal_lens_finetuned_models/roberta_finetuned")

In [333]:
# Apply the function to the entire dataset
test_results = test_set.map(
    lambda x: predict_tags(x, finetuned_tokenizer, finetuned_model))

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 380/380 [00:15<00:00, 24.96 examples/s]


In [335]:
test_results.to_csv("NER_test_set_results_roberta.csv")

Creating CSV from Arrow format: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1/1 [00:00<00:00, 26.09ba/s]


492164

## FLAN T5

In [259]:
model_checkpoint = "google/flan-t5-base"

In [260]:
tokenizer, tokenized_dataset = tokenize_dataset(model_checkpoint, dataset)

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 679/679 [00:00<00:00, 3969.65 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 291/291 [00:00<00:00, 3928.39 examples/s]


In [261]:
trained_model = train_model(model_checkpoint, tokenized_dataset, label_list)

Some weights of T5ForTokenClassification were not initialized from the model checkpoint at google/flan-t5-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
 10%|â–ˆ         | 85/850 [01:34<13:31,  1.06s/it]
 10%|â–ˆ         | 85/850 [01:47<13:31,  1.06s/it]

{'eval_loss': 3.8418359756469727, 'eval_LAW_precision': 0.00019428793471925395, 'eval_LAW_recall': 0.008928571428571428, 'eval_LAW_f1': 0.00038030043734550294, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.000686655985351339, 'eval_VIOLATED BY_recall': 0.047619047619047616, 'eval_VIOLATED BY_f1': 0.0013537906137184117, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0005629045876723895, 'eval_VIOLATED ON_recall': 0.038461538461538464, 'eval_VIOLATED ON_f1': 0.0011095700416088763, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.008405438813349814, 'eval_VIOLATION_recall': 0.07538802660753881, 'eval_VIOLATION_f1': 0.015124555160142347, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.002337267734018932, 'eval_overall_recall': 0.058997050147492625, 'eval_overall_f1': 0.004496402877697842, 'eval_overall_accuracy': 0.1591245680434425, 'eval_runtime': 13.3459, 'eval_samples_per_second': 21.804, 'eval_steps_per_second': 2.772, 'epoch': 1.0}


 20%|â–ˆâ–ˆ        | 170/850 [03:21<11:48,  1.04s/it] 
 20%|â–ˆâ–ˆ        | 170/850 [03:34<11:48,  1.04s/it]

{'eval_loss': 1.3803106546401978, 'eval_LAW_precision': 0.0006993006993006993, 'eval_LAW_recall': 0.008928571428571428, 'eval_LAW_f1': 0.001297016861219196, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.00338409475465313, 'eval_VIOLATED BY_recall': 0.031746031746031744, 'eval_VIOLATED BY_f1': 0.006116207951070336, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.009124087591240875, 'eval_VIOLATION_recall': 0.04434589800443459, 'eval_VIOLATION_f1': 0.01513431706394249, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.004878048780487805, 'eval_overall_recall': 0.03392330383480826, 'eval_overall_f1': 0.008529575375486742, 'eval_overall_accuracy': 0.6776369919368109, 'eval_runtime': 13.3517, 'eval_samples_per_second': 21.795, 'eval_steps_per_second': 2.771, 'epoch': 2.0}


 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [05:09<10:19,  1.04s/it]  
 30%|â–ˆâ–ˆâ–ˆ       | 255/850 [05:22<10:19,  1.04s/it]

{'eval_loss': 0.8169323205947876, 'eval_LAW_precision': 0.0, 'eval_LAW_recall': 0.0, 'eval_LAW_f1': 0.0, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.025, 'eval_VIOLATED BY_recall': 0.015873015873015872, 'eval_VIOLATED BY_f1': 0.019417475728155338, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.01293103448275862, 'eval_VIOLATION_recall': 0.06651884700665188, 'eval_VIOLATION_f1': 0.021652832912306026, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.009645301804604853, 'eval_overall_recall': 0.045722713864306784, 'eval_overall_f1': 0.015930113052415207, 'eval_overall_accuracy': 0.7738193187428007, 'eval_runtime': 13.4073, 'eval_samples_per_second': 21.705, 'eval_steps_per_second': 2.76, 'epoch': 3.0}


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [06:56<08:51,  1.04s/it]
 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 340/850 [07:09<08:51,  1.04s/it]

{'eval_loss': 0.5588050484657288, 'eval_LAW_precision': 0.017857142857142856, 'eval_LAW_recall': 0.044642857142857144, 'eval_LAW_f1': 0.025510204081632654, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.020441988950276244, 'eval_VIOLATION_recall': 0.082039911308204, 'eval_VIOLATION_f1': 0.03272888102609465, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.019345923537540305, 'eval_overall_recall': 0.061946902654867256, 'eval_overall_f1': 0.029484029484029485, 'eval_overall_accuracy': 0.8324831331248972, 'eval_runtime': 13.3731, 'eval_samples_per_second': 21.76, 'eval_steps_per_second': 2.767, 'epoch': 4.0}


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [08:43<07:23,  1.04s/it]
 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 425/850 [08:57<07:23,  1.04s/it]

{'eval_loss': 0.41232287883758545, 'eval_LAW_precision': 0.3181818181818182, 'eval_LAW_recall': 0.1875, 'eval_LAW_f1': 0.23595505617977533, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.054838709677419356, 'eval_VIOLATION_recall': 0.188470066518847, 'eval_VIOLATION_f1': 0.08495752123938032, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.06287069988137604, 'eval_overall_recall': 0.15634218289085547, 'eval_overall_f1': 0.08967851099830797, 'eval_overall_accuracy': 0.8720174428171795, 'eval_runtime': 13.3403, 'eval_samples_per_second': 21.814, 'eval_steps_per_second': 2.774, 'epoch': 5.0}


 59%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–‰    | 500/850 [10:20<06:17,  1.08s/it]

{'loss': 2.085, 'grad_norm': 1.5019505023956299, 'learning_rate': 2e-05, 'epoch': 5.88}


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [10:31<05:57,  1.05s/it]
 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 510/850 [10:44<05:57,  1.05s/it]

{'eval_loss': 0.3462766706943512, 'eval_LAW_precision': 0.21794871794871795, 'eval_LAW_recall': 0.15178571428571427, 'eval_LAW_f1': 0.17894736842105263, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.0, 'eval_VIOLATED BY_recall': 0.0, 'eval_VIOLATED BY_f1': 0.0, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.0, 'eval_VIOLATED ON_recall': 0.0, 'eval_VIOLATED ON_f1': 0.0, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.15134529147982062, 'eval_VIOLATION_recall': 0.29933481152993346, 'eval_VIOLATION_f1': 0.20104244229337304, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.15338042381432895, 'eval_overall_recall': 0.22418879056047197, 'eval_overall_f1': 0.1821449970041941, 'eval_overall_accuracy': 0.9101941747572816, 'eval_runtime': 13.4168, 'eval_samples_per_second': 21.689, 'eval_steps_per_second': 2.758, 'epoch': 6.0}


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [12:19<04:25,  1.04s/it]
 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 595/850 [12:32<04:25,  1.04s/it]

{'eval_loss': 0.30479761958122253, 'eval_LAW_precision': 0.2871287128712871, 'eval_LAW_recall': 0.25892857142857145, 'eval_LAW_f1': 0.2723004694835681, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.06060606060606061, 'eval_VIOLATED BY_recall': 0.031746031746031744, 'eval_VIOLATED BY_f1': 0.041666666666666664, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.03333333333333333, 'eval_VIOLATED ON_recall': 0.019230769230769232, 'eval_VIOLATED ON_f1': 0.024390243902439025, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.18911917098445596, 'eval_VIOLATION_recall': 0.3237250554323725, 'eval_VIOLATION_f1': 0.23875715453802127, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.19017094017094016, 'eval_overall_recall': 0.26253687315634217, 'eval_overall_f1': 0.2205700123915737, 'eval_overall_accuracy': 0.9220421260490373, 'eval_runtime': 13.5778, 'eval_samples_per_second': 21.432, 'eval_steps_per_second': 2.725, 'epoch': 7.0}


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [14:05<02:51,  1.01s/it]
 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 680/850 [14:18<02:51,  1.01s/it]

{'eval_loss': 0.2844420075416565, 'eval_LAW_precision': 0.4326923076923077, 'eval_LAW_recall': 0.4017857142857143, 'eval_LAW_f1': 0.4166666666666667, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.06060606060606061, 'eval_VIOLATED BY_recall': 0.031746031746031744, 'eval_VIOLATED BY_f1': 0.041666666666666664, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.027777777777777776, 'eval_VIOLATED ON_recall': 0.019230769230769232, 'eval_VIOLATED ON_f1': 0.02272727272727273, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.22117962466487937, 'eval_VIOLATION_recall': 0.36585365853658536, 'eval_VIOLATION_f1': 0.2756892230576441, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.23177366702937977, 'eval_overall_recall': 0.3141592920353982, 'eval_overall_f1': 0.2667501565435191, 'eval_overall_accuracy': 0.9285831824913608, 'eval_runtime': 13.0941, 'eval_samples_per_second': 22.224, 'eval_steps_per_second': 2.826, 'epoch': 8.0}


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [15:49<01:25,  1.01s/it]
 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 765/850 [16:02<01:25,  1.01s/it]

{'eval_loss': 0.273849755525589, 'eval_LAW_precision': 0.4854368932038835, 'eval_LAW_recall': 0.44642857142857145, 'eval_LAW_f1': 0.4651162790697675, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.08108108108108109, 'eval_VIOLATED BY_recall': 0.047619047619047616, 'eval_VIOLATED BY_f1': 0.06, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.05263157894736842, 'eval_VIOLATED ON_recall': 0.038461538461538464, 'eval_VIOLATED ON_f1': 0.044444444444444446, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.22163588390501318, 'eval_VIOLATION_recall': 0.37250554323725055, 'eval_VIOLATION_f1': 0.27791563275434245, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.23824786324786323, 'eval_overall_recall': 0.32890855457227136, 'eval_overall_f1': 0.27633209417596033, 'eval_overall_accuracy': 0.9319565575119302, 'eval_runtime': 12.9778, 'eval_samples_per_second': 22.423, 'eval_steps_per_second': 2.851, 'epoch': 9.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [17:39<00:00,  1.18s/it]
100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [17:56<00:00,  1.18s/it]

{'eval_loss': 0.2719593346118927, 'eval_LAW_precision': 0.49514563106796117, 'eval_LAW_recall': 0.45535714285714285, 'eval_LAW_f1': 0.47441860465116287, 'eval_LAW_number': 112, 'eval_VIOLATED BY_precision': 0.05555555555555555, 'eval_VIOLATED BY_recall': 0.031746031746031744, 'eval_VIOLATED BY_f1': 0.04040404040404041, 'eval_VIOLATED BY_number': 63, 'eval_VIOLATED ON_precision': 0.02564102564102564, 'eval_VIOLATED ON_recall': 0.019230769230769232, 'eval_VIOLATED ON_f1': 0.02197802197802198, 'eval_VIOLATED ON_number': 52, 'eval_VIOLATION_precision': 0.22576361221779548, 'eval_VIOLATION_recall': 0.376940133037694, 'eval_VIOLATION_f1': 0.28239202657807305, 'eval_VIOLATION_number': 451, 'eval_overall_precision': 0.24060150375939848, 'eval_overall_recall': 0.3303834808259587, 'eval_overall_f1': 0.27843380981976384, 'eval_overall_accuracy': 0.9323679447095606, 'eval_runtime': 15.0725, 'eval_samples_per_second': 19.307, 'eval_steps_per_second': 2.455, 'epoch': 10.0}


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 850/850 [17:57<00:00,  1.27s/it]

{'train_runtime': 1077.5198, 'train_samples_per_second': 6.302, 'train_steps_per_second': 0.789, 'train_loss': 1.3680954248764936, 'epoch': 10.0}





In [262]:
evaluation_results = evaluate_trained_model(trained_model)

100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 37/37 [00:13<00:00,  2.74it/s]


In [263]:
evaluation_results

{'eval_loss': 0.2719593346118927,
 'eval_LAW_precision': 0.49514563106796117,
 'eval_LAW_recall': 0.45535714285714285,
 'eval_LAW_f1': 0.47441860465116287,
 'eval_LAW_number': 112,
 'eval_VIOLATED BY_precision': 0.05555555555555555,
 'eval_VIOLATED BY_recall': 0.031746031746031744,
 'eval_VIOLATED BY_f1': 0.04040404040404041,
 'eval_VIOLATED BY_number': 63,
 'eval_VIOLATED ON_precision': 0.02564102564102564,
 'eval_VIOLATED ON_recall': 0.019230769230769232,
 'eval_VIOLATED ON_f1': 0.02197802197802198,
 'eval_VIOLATED ON_number': 52,
 'eval_VIOLATION_precision': 0.22576361221779548,
 'eval_VIOLATION_recall': 0.376940133037694,
 'eval_VIOLATION_f1': 0.28239202657807305,
 'eval_VIOLATION_number': 451,
 'eval_overall_precision': 0.24060150375939848,
 'eval_overall_recall': 0.3303834808259587,
 'eval_overall_f1': 0.27843380981976384,
 'eval_overall_accuracy': 0.9323679447095606,
 'eval_runtime': 15.355,
 'eval_samples_per_second': 18.952,
 'eval_steps_per_second': 2.41,
 'epoch': 10.0}

In [270]:
trained_model.save_model(output_dir="legal_lens_finetuned_models/flan_t5_base_finetuned/")

In [271]:
tokenizer.save_pretrained(save_directory="legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/")

('legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/tokenizer_config.json',
 'legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/special_tokens_map.json',
 'legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/spiece.model',
 'legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/added_tokens.json',
 'legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/tokenizer.json')

In [272]:
finetuned_tokenizer = AutoTokenizer.from_pretrained("legal_lens_finetuned_tokenizers/flan_t5_base_finetuned/")
finetuned_model = AutoModelForTokenClassification.from_pretrained("legal_lens_finetuned_models/flan_t5_base_finetuned/")

In [273]:
# Apply the function to the entire dataset
test_results = test_set.map(
    lambda x: predict_tags(x, finetuned_tokenizer, finetuned_model))

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 380/380 [00:21<00:00, 17.70 examples/s]
