## Iteractively refine dataset labels

In [2]:
import os
import pathlib
import typing as t
import random
import requests
import json as json_package
import collections
import glob

import IPython.display
import ipywidgets
import datasets
import tokenizers
import ipywidgets
import colorama
import tqdm

import segmentador
from config import *
import interactive_labeling


%load_ext autoreload
%autoreload 2


CN = colorama.Fore.RED
CT = colorama.Fore.YELLOW
CLN = colorama.Fore.CYAN
CSR = colorama.Style.RESET_ALL


VOCAB_SIZE = 6000
AUTOSAVE_IN_N_INSTANCES = 20


DATASET_SPLIT = "test"
CACHED_INDEX_FILENAME = f"emendas_variadas_{DATASET_SPLIT}_refined_indices.csv"
SKIPPED_INDEX_FILENAME = f"emendas_variadas_{DATASET_SPLIT}_skipped_indices.csv"
BRUTE_DATASET_DIR = "../data"

# TARGET_DATASET_NAME = f"df_tokenized_split_0_120000_{VOCAB_SIZE}_resplit"
# TARGET_DATASET_NAME = "ccjs_segmentados_train_test_splits"
TARGET_DATASET_NAME = "emendas_variadas_segmentadas_train_test_splits"

REFINED_DATASET_DIR = os.path.join(BRUTE_DATASET_DIR, "refined_datasets", TARGET_DATASET_NAME)

BRUTE_DATASET_URI = os.path.join(BRUTE_DATASET_DIR, TARGET_DATASET_NAME)
CACHED_INDEX_URI = os.path.join(REFINED_DATASET_DIR, CACHED_INDEX_FILENAME)
SKIPPED_INDEX_URI = os.path.join(REFINED_DATASET_DIR, SKIPPED_INDEX_FILENAME)


assert BRUTE_DATASET_URI != REFINED_DATASET_DIR


logit_model = segmentador.BERTSegmenter(
    uri_model=f"../segmenter_checkpoint_v2/4_{VOCAB_SIZE}_layer_model_finetuned_emendas",
    device="cpu",
)

pbar_dump = tqdm.auto.tqdm(desc="Instances until next dump", total=AUTOSAVE_IN_N_INSTANCES)
it_counter = 0

lock_save = False
random.seed(17)

Instances until next dump:   0%|          | 0/20 [00:00<?, ?it/s]

In [3]:
tokenizer = tokenizers.Tokenizer.from_file(f"../tokenizers/{VOCAB_SIZE}_subwords/tokenizer.json")
tokenizer.get_vocab_size()

6000

## Load dataset

In [4]:
try:
    with open(CACHED_INDEX_URI, "r") as f_index:
        cached_indices = set(map(int, f_index.read().split(",")))

    print(f"Loaded {len(cached_indices)} indices from disk.")

except FileNotFoundError:
    cached_indices = set()


try:
    with open(SKIPPED_INDEX_URI, "r") as f_index:
        skipped_indices = set(map(int, f_index.read().split(",")))

    print(f"Loaded {len(skipped_indices)} skipped indices from disk.")

except (FileNotFoundError, ValueError):
    skipped_indices = set()


new_refined_instances = collections.defaultdict(list)
df_brute = datasets.load_from_disk(BRUTE_DATASET_URI)
df_brute

Loaded 200 indices from disk.


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 11871
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

## Iteractive refinement

In [4]:
df_split = df_brute[DATASET_SPLIT]

fn_rand = lambda: random.randint(0, df_split.num_rows)

df_view_labels = df_split["labels"]
df_view_input_ids = df_split["input_ids"]

cls2id = {
    "no-op": 0,
    "seg": 1,
    "n-start": 2,
    "n-end": 3,
}


def print_labels(input_ids, labels, input_is_tokens: bool = False):
    seg_counter = 1

    print(end=CSR)
    print(end=f"{CLN}{seg_counter}.{CSR} ")

    if not input_is_tokens:
        tokens = list(map(tokenizer.id_to_token, input_ids))

    else:
        tokens = input_ids

    for i, (tok, lab) in enumerate(zip(tokens, labels)):
        if lab == cls2id["seg"]:
            seg_counter += 1
            print("\n\n", end=f"{CLN}{seg_counter}.{CSR} ")

        if lab == cls2id["n-start"]:
            print(end=CN)

        if lab == cls2id["n-end"]:
            print(end=CSR)

        print(tok, end=" ")

    return tokens


def dump_refined_dataset():
    new_subset = datasets.Dataset.from_dict(new_refined_instances, split=DATASET_SPLIT)
    shard_id = len(glob.glob(os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_*")))
    REFINED_DATASET_SHARD_URI = os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_{shard_id}")
    new_subset.save_to_disk(REFINED_DATASET_SHARD_URI)
    new_refined_instances.clear()

    with open(CACHED_INDEX_URI, "w") as f_index:
        f_index.write(",".join(map(str, sorted(cached_indices))))

    with open(SKIPPED_INDEX_URI, "w") as f_index:
        f_index.write(",".join(map(str, sorted(skipped_indices))))

    it_counter = 0

    print(f"Saved progress in '{REFINED_DATASET_SHARD_URI}'.")


def fn_run_cell(index_shift: int):
    js = IPython.display.Javascript(
        f"Jupyter.notebook.execute_cells([IPython.notebook.get_selected_index()+{index_shift}])"
    )
    IPython.display.display(js)


def fn_run_last_cell(_):
    js = IPython.display.Javascript(
        "Jupyter.notebook.execute_cells([IPython.notebook.ncells() - 16])"
    )
    IPython.display.display(js)


def fn_skip_instance(_):
    if id_ not in cached_indices:
        skipped_indices.add(id_)
        fn_run_cell(1)


button_run_cell = ipywidgets.Button(
    description="Fetch new instance",
    tooltip="Run cell below to fetch a random instance",
    layout=ipywidgets.Layout(width="20%", height="48px", margin="0 0 0 5%"),
)
button_run_cell.on_click(lambda _: fn_run_cell(1))

button_run_cell_b = ipywidgets.Button(
    description="Fetch refined instance",
    tooltip="Fetch refined instance from front-end (run cell below)",
    style=dict(button_color="lightgreen"),
)
button_run_cell_b.layout = button_run_cell.layout
button_run_cell_b.on_click(lambda _: fn_run_cell(1))

button_edit_instance = ipywidgets.Button(
    description="Edit instance",
    tooltip="Send instance to interactive front-end for refinement",
    style=dict(button_color="lightblue"),
)
button_edit_instance.layout = button_run_cell.layout
button_edit_instance.on_click(lambda _: fn_run_cell(2))

button_save = ipywidgets.Button(
    description="Save test instance",
    style=dict(button_color="salmon"),
    tooltip="Run notebook last cell (triggers code to save instance in curated dataset)",
)
button_save.layout = button_run_cell.layout
button_save.on_click(fn_run_last_cell)

button_skip = ipywidgets.Button(
    description="Skip instance",
    style=dict(button_color="black"),
    tooltip="Skip instance",
)
button_skip.layout = button_run_cell.layout
button_skip.on_click(fn_skip_instance)

In [5]:
import scipy.special
import numpy as np
import tqdm
import torch
import transformers
import os


m = 140000 if DATASET_SPLIT == "train" else 2100
k = 0
cached_margins_filename = f"cached_margins_emendas_variadas_{DATASET_SPLIT}_{m}_{k}.txt"
APPLY_ACTIVE_LEARNING = False


if APPLY_ACTIVE_LEARNING and DATASET_SPLIT != "test":
    logit_model._model = torch.load("bert_v2_4_layers_logit_model_finetuned_emendas7.pt")
    logit_model.model.to("cuda")
    logit_model.device = "cuda"
    
    if os.path.exists(cached_margins_filename):
        with open(cached_margins_filename, "r") as f_in:
            margins = np.asfarray(f_in.readlines())

    else:
        torch.cuda.empty_cache()
        
        margins = np.full(len(df_view_input_ids), fill_value=np.inf)

        compute_diff_tokens = False
        diff_80 = 0
        total_tokens_80 = 1e-8

        pbar = tqdm.auto.tqdm(
            enumerate(df_view_input_ids[-1 - k : -m - 1 - k : -1], 1 + k),
            total=min(m, len(df_view_input_ids) - k),
        )

        for i, text in pbar:
            logits = logit_model(
                tokenizer.decode(text),
                batch_size=128,
                return_logits=True,
                regex_justificativa="@@@@@@",
            ).logits
            
            probs = scipy.special.softmax(logits, axis=-1)
            margin = np.diff(np.sort(probs, axis=-1)[:, [-2, -1]]).ravel()

            try:
                true_labels = np.asarray(df_view_labels[-i], dtype=int)
                not_middle_word = true_labels != -100

                if compute_diff_tokens:
                    try:
                        margin_80_inds = np.flatnonzero(
                            np.logical_and(not_middle_word, margin >= 0.90)
                        )
                        high_conf_preds = np.argmax(probs[margin_80_inds], axis=-1)
                        diff_inds = np.flatnonzero(true_labels[margin_80_inds] != high_conf_preds)
                        diff_80 += diff_inds.size
                        total_tokens_80 += margin_80_inds.size

                        if diff_inds.size:
                            new_labels.append(
                                (
                                    i,
                                    list(
                                        zip(margin_80_inds[diff_inds], high_conf_preds[diff_inds])
                                    ),
                                )
                            )

                    except ValueError:
                        pass

                margin = margin[not_middle_word]

            except IndexError as err:
                margin = [np.inf, np.inf]

            try:
                margins[-i] = float(np.quantile(margin, 0.01))
            except IndexError:
                margins[-i] = 0.0

            if compute_diff_tokens:
                pbar.set_description(
                    f"50: {100. * diff_80 / total_tokens_80:.2f}% ({diff_80} of {int(total_tokens_80)})"
                )

    ids_to_fetch = np.argsort(margins)

In [7]:
if APPLY_ACTIVE_LEARNING and not os.path.exists(cached_margins_filename):
    with open(cached_margins_filename, "w") as f_out:
        f_out.write("\n".join(map(lambda x: f"{x:.6f}", margins)))

## Interative refinery with low margin (active learning)

In [6]:
IPython.display.display(
    ipywidgets.HBox((button_run_cell, button_edit_instance, button_save, button_skip))
);

HBox(children=(Button(description='Fetch new instance', layout=Layout(height='48px', margin='0 0 0 5%', width=…

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [276]:
if not APPLY_ACTIVE_LEARNING or DATASET_SPLIT == "test":
    j = 0
    while j in cached_indices or j in skipped_indices:
        j += 1

    id_ = j

else:
    j = 0
    while ids_to_fetch[j] in cached_indices or ids_to_fetch[j] in skipped_indices:
        j += 1

    id_ = int(ids_to_fetch[j])


# id_ = np.random.randint(0, len(ids_to_fetch))
    
input_ids = df_view_input_ids[id_]
labels = df_view_labels[id_]
print(id_)
tokens = print_labels(input_ids, labels)

if APPLY_ACTIVE_LEARNING and DATASET_SPLIT != "test":
    print("\n\nmargin:", margins[ids_to_fetch[j]])

IndexError: list index out of range

In [273]:
logits = logit_model(df_split[id_], return_logits=True, batch_size=32, regex_justificativa="@@@@").logits
interactive_labeling.open_example(
    tokens, labels, logits=logits if DATASET_SPLIT != "test" else None
)
lock_save = True

In [7]:
IPython.display.display(ipywidgets.HBox((button_run_cell_b, button_save)));

HBox(children=(Button(description='Fetch refined instance', layout=Layout(height='48px', margin='0 0 0 5%', wi…

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [274]:
ret = interactive_labeling.retrieve_refined_example()
assert len(ret["labels"]) == len(tokens)
labels = ret["labels"]
print_labels(tokens, labels, input_is_tokens=True)
lock_save = False

[0m[36m1.[0m nº 6 , de 20 de março de 2020 , e dá outras providências . 

[36m2.[0m Ass ##inar ##am eletr ##on ##icamente o documento CD ##20 ##69 ##55 ##10 ##76 ##00 , nesta ordem : 

[36m3.[0m 1 Dep . João Ro ##ma ( RE ##P ##U ##BL ##IC / BA ) 

[36m4.[0m 2 Dep . J ##ho ##na ##ta ##n de J ##esus ( RE ##P ##U ##BL ##IC / R ##R ) - L ##Í ##DE ##R do RE ##P ##U ##BL ##IC [31m* - ( P _ 50 ##2 ##7 ) 

[36m5.[0m 3 Dep . E ##ri ##k ##a K ##o ##k ##a ##y ( PT / DF ) 

[36m6.[0m 4 Dep . Ped ##ro Lu ##p ##ion ( DEM / PR ) - VI ##CE - L ##Í ##DE ##R do B ##lo ##co PL , PP , PSD , M ##DB , DEM , SO ##L ##ID ##AR ##I ##ED ##AD ##E , PT ##B , PRO ##S , A ##VA ##N ##TE [31m* Ch ##ance ##la eletr ##ônica do ( a ) deput ##ado ( a ) , nos termos de deleg ##ação regulament ##ada no Ato da mes ##a n . 25 de 2015 . CÂMARA DOS DEPUTADOS Inf ##ol ##eg - Aut ##ent ##ica ##dor Do ##cu ##mento eletrônico assin ##ado por João Ro ##ma ( RE ##P ##U ##BL ##IC / BA ) , através do ponto SD ##R _ 56 #

In [275]:
redo = False

if (redo or (not lock_save and id_ not in cached_indices)) and id_ < len(df_split):
    if not redo:
        it_counter += 1
    
        cached_indices.add(id_)

        for key, val in df_split[id_].items():
            if key != "labels":
                new_refined_instances[key].append(val.copy())

            else:
                new_refined_instances[key].append(labels)

        pbar_dump.update()
        
    else:
        for key, val in df_split[id_].items():
            if key != "labels":
                new_refined_instances[key][-1] = val.copy()

            else:
                new_refined_instances[key][-1] = labels


if it_counter % AUTOSAVE_IN_N_INSTANCES == 0 or j >= len(df_view_input_ids):
    dump_refined_dataset()
    pbar_dump.reset()


print(pbar_dump)

Saving the dataset (0/1 shards):   0%|          | 0/20 [00:00<?, ? examples/s]

Saved progress in '../data/refined_datasets/emendas_variadas_segmentadas_train_test_splits/test_9'.
Instances until next dump:   0%|          | 0/20 [00:00<?, ?it/s]


## Merge curated data shards

In [277]:
shard_uris = glob.glob(os.path.join(REFINED_DATASET_DIR, f"{DATASET_SPLIT}_*"))

all_dsets = []

for shard_uri in shard_uris:
    try:
        shard = datasets.Dataset.load_from_disk(shard_uri)
    except Exception:
        shard = None

    if shard and shard.num_rows:
        all_dsets.append(shard)

merged_dset = datasets.concatenate_datasets(all_dsets)
print(merged_dset.num_rows)


# def split_instances_into_1024_tokens_max(x):
#     num_tokens = len(x["labels"][0])
#     if num_tokens <= 1024: return x
#     new_insts = collections.defaultdict(list)
#     for key, val in x.items():
#         val = val[0]
#         i = 0
#         noise_cliffhanger = False
        
#         while i < len(val):
#             split = val[i:i + 1024]
#             new_insts[key].append(split)
#             i += 1024
            
#             if key == "labels":
#                 try: first_noise_end = split.index(3)
#                 except ValueError: first_noise_end = float("+inf")
                    
#                 try: j = split.index(1)
#                 except ValueError: j = float("+inf")
#                 try: k = split.index(2)
#                 except ValueError: k = float("+inf")
                
#                 if noise_cliffhanger or first_noise_end < min(j, k):
#                     split[split.index(0)] = 2
                    
#                 noise_cliffhanger = False
                
#                 split_r = split[::-1]
#                 try: last_noise_start = split_r.index(2)
#                 except ValueError: last_noise_start = float("+inf")
                
#                 try: j = split_r.index(1)
#                 except ValueError: j = float("+inf")
#                 try: k = split_r.index(3)
#                 except ValueError: k = float("+inf")

#                 noise_cliffhanger = last_noise_start < min(j, k)
    
#     return new_insts


# merged_dset = merged_dset.map(split_instances_into_1024_tokens_max, batch_size=1, batched=True)
# print(len(merged_dset))

output_uri_merged = os.path.join(
    REFINED_DATASET_DIR,
    f"combined_{DATASET_SPLIT}_{len(shard_uris)}_parts_{merged_dset.num_rows}_instances",
)

merged_dset.save_to_disk(output_uri_merged)

200


Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]

## Fine-tune for curated samples

In [9]:
import torch
import torch.nn
import copy
import torchmetrics
import transformers
import pandas as pd
import datasets
import os
import optimum.onnxruntime
import onnxruntime


# df_curated_train = datasets.Dataset.load_from_disk(
#     os.path.join(REFINED_DATASET_DIR, "combined_train_102_parts_4221_instances")
# )

# df_curated_train = datasets.concatenate_datasets([
#     df_curated_train,
#     datasets.Dataset.load_from_disk("./final_curated_dataset_for_hyperparameter_tuning/eval"),
# ])

# df_curated_test = datasets.Dataset.load_from_disk(
#     os.path.join(".", "final_curated_dataset_for_hyperparameter_tuning/test")
# )

V2_DATADIR = "../data/refined_datasets/df_tokenized_split_0_120000_6000_resplit"
EMENDAS_PREV_DATADIR = "../data/refined_datasets/ccjs_segmentados_train_test_splits"
EMENDAS_VAR_DATADIR = "../data/refined_datasets/emendas_variadas_segmentadas_train_test_splits"

df_curated_train = datasets.concatenate_datasets([
    datasets.Dataset.load_from_disk(os.path.join(V2_DATADIR, "combined_train_102_parts_4221_instances")),
    datasets.Dataset.load_from_disk(os.path.join(EMENDAS_PREV_DATADIR, "combined_train_9_parts_185_instances")),
    datasets.Dataset.load_from_disk(os.path.join(EMENDAS_VAR_DATADIR, "combined_train_13_parts_261_instances")),
    datasets.Dataset.load_from_disk("./final_curated_dataset_for_hyperparameter_tuning/eval"),
])

df_curated_test_orig = datasets.Dataset.load_from_disk(
    os.path.join(".", "final_curated_dataset_for_hyperparameter_tuning/test")
)

df_curated_test_emendas = datasets.concatenate_datasets([
    datasets.Dataset.load_from_disk(os.path.join(EMENDAS_VAR_DATADIR, "combined_test_10_parts_200_instances")),
    datasets.Dataset.load_from_disk(os.path.join(EMENDAS_PREV_DATADIR, "combined_test_5_parts_103_instances")),
])


def fn_pad(X):
    rem = 1024 - len(X["labels"])
    X["labels"] = X["labels"] + rem * [-100]
    X["input_ids"] = X["input_ids"] + rem * [0]
    X["token_type_ids"] = X["token_type_ids"] + rem * [0]
    X["attention_mask"] = X["attention_mask"] + rem * [0]
    return X


#### delete repeated instances
def rem_repeated_instances(df):
    df_aux = pd.DataFrame(df)
    print(df_aux.shape)
    df_aux_2 = df_aux.applymap(lambda x: " ".join(map(str, x)))
    df_aux_2 = df_aux_2.drop_duplicates(
        subset=["input_ids"], inplace=False, ignore_index=True, keep="last"
    )
    df_aux_2 = df_aux_2.applymap(lambda x: list(map(int, x.split(" "))))
    print("before:", df_aux.shape, "after:", df_aux_2.shape)
    return datasets.Dataset.from_pandas(df_aux_2)


print("train:")
df_curated_train = rem_repeated_instances(df_curated_train)
print("test:")
df_curated_test_orig = rem_repeated_instances(df_curated_test_orig)
df_curated_test_emendas = rem_repeated_instances(df_curated_test_emendas)

df_curated_train = df_curated_train.map(fn_pad, batch_size=32)
df_curated_test_emendas = df_curated_test_emendas.map(fn_pad, batch_size=32)
df_curated_test_orig = df_curated_test_orig.map(fn_pad, batch_size=32)

df_curated_test_emendas = df_curated_test_emendas.cast(df_curated_test_orig.features)

df_curated_train.set_format("torch")
df_curated_test_emendas.set_format("torch")
df_curated_test_orig.set_format("torch")

train:
(6682, 4)
before: (6682, 4) after: (6602, 4)
test:
(2299, 4)
before: (2299, 4) after: (2299, 4)
(303, 4)
before: (303, 4) after: (303, 4)


  0%|          | 0/6602 [00:00<?, ?ex/s]

  0%|          | 0/303 [00:00<?, ?ex/s]

  0%|          | 0/2299 [00:00<?, ?ex/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

In [10]:
batch_size = 6

dl_train = torch.utils.data.DataLoader(df_curated_train, shuffle=True, batch_size=batch_size)
dl_test_emendas = torch.utils.data.DataLoader(df_curated_test_emendas, batch_size=batch_size)
dl_test_orig = torch.utils.data.DataLoader(df_curated_test_orig, batch_size=batch_size)
dl_test_both = torch.utils.data.DataLoader(
    datasets.concatenate_datasets([df_curated_test_emendas, df_curated_test_orig]),
    batch_size=batch_size,
)

# dl_train = torch.utils.data.DataLoader(
#     datasets.concatenate_datasets([
#         df_curated_train,
#         df_curated_test_emendas,
#         df_curated_test_orig,
#     ]),
#     shuffle=True,
#     batch_size=batch_size,
# )
# dl_test_both = None

In [5]:
logit_model = segmentador.BERTSegmenter(
    uri_model=f"../segmenter_checkpoint_v2/2_{VOCAB_SIZE}_layer_model",
    device="cpu",
)

# logit_model = segmentador.BERTSegmenter(
#     uri_model=f"../segmenter_checkpoint_v2/4_{VOCAB_SIZE}_layer_model_finetuned",
#     device="cpu",
# )

# logit_model = segmentador.LSTMSegmenter(
#     uri_model=(
#         os.path.join(
#             f"../segmenter_checkpoint_v2/512_{VOCAB_SIZE}_1_lstm/checkpoints",
#             "512_hidden_dim_6000_vocab_size_1_layer_lstm.pt",
#         )
#     ),
#     device="cpu",
# )

In [20]:
import transformers

layers_num = 2
# aux = torch.load(f"bert_v2_{layers_num}_layers_logit_model_finetuned_emendas.pt", map_location="cpu")
aux = torch.load(f"./checkpoints_2_layers/bert_checkpoint_epoch_5.pt", map_location="cpu")
torch.save(
    aux.state_dict(),
    f"../segmenter_checkpoint_v2/{layers_num}_6000_layer_model_finetuned_emendas/pytorch_model.bin",
)

# aux = torch.load(
#     f"../segmenter_checkpoint_v1/{layers_num}_{VOCAB_SIZE}_layer_model/pytorch_model.bin",
#     map_location="cpu",
# )
# aux

In [192]:
# import segmentador.optimize

# num_layers = 4

# logit_model_finetuned = segmentador.BERTSegmenter(
#     uri_model=f"../segmenter_checkpoint_v2/{num_layers}_{VOCAB_SIZE}_layer_model_finetuned",
#     device="cpu",
# )

# quantized_model_paths = segmentador.optimize.quantize_model(
#     logit_model_finetuned,
#     quantized_model_filename=f"bert_v2_{num_layers}_layers_logit_model_finetuned",
#     model_output_format="onnx",
#     verbose=True,
#     check_cached=False,
# )

# print(quantized_model_paths)

# logit_model_copy = segmentador.optimize.ONNXBERTSegmenter(
#     uri_model=quantized_model_paths.output_uri,
#     uri_tokenizer=f"../tokenizers/{VOCAB_SIZE}_subwords/",
# )._model

In [193]:
# import segmentador.optimize
# import onnxruntime

# hidden_dim = 512

# logit_model_finetuned = segmentador.LSTMSegmenter(
    
#     uri_model=os.path.join(
#         "../segmenter_checkpoint_v2",
#         f"{hidden_dim}_{VOCAB_SIZE}_1_lstm/checkpoints",
#         f"{hidden_dim}_hidden_dim_{VOCAB_SIZE}_vocab_size_1_layer_lstm.pt",
#     ),
#     device="cpu",
# )

# quantized_model_paths = segmentador.optimize.quantize_model(
#     logit_model_finetuned,
#     quantized_model_filename=f"lstm_v2_{hidden_dim}_hidden_dim_logit_model_finetuned",
#     model_output_format="onnx",
#     verbose=True,
#     check_cached=False,
# )

# print(quantized_model_paths)

# logit_model_copy = segmentador.optimize.ONNXLSTMSegmenter(
#     uri_model=quantized_model_paths.output_uri,
#     uri_tokenizer=f"../tokenizers/{VOCAB_SIZE}_subwords/",
# )._model

In [14]:
def eval_model(model, dl_test, device="cuda:0"):
    try: model.eval()
    except AttributeError: pass
    
    if device and (not hasattr(model, "device") or torch.device(device) != model.device):
        try: model.to(device)
        except Exception: device = "cpu"

    fn_precision = torchmetrics.classification.Precision(num_classes=4, average=None).to(device)
    fn_recall = torchmetrics.classification.Recall(num_classes=4, average=None).to(device)
    fn_f1 = torchmetrics.classification.f_beta.F1Score(num_classes=4, average="macro").to(device)

    preds = []
    targets = []

    for batch in tqdm.auto.tqdm(dl_test):
        with torch.no_grad():
            if isinstance(model, (transformers.BertForTokenClassification, optimum.onnxruntime.ORTModelForTokenClassification)):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                logits = outputs.logits

            elif isinstance(model, onnxruntime.InferenceSession):
                logits = model.run(
                    output_names=["logits"],
                    input_feed=dict(input_ids=batch["input_ids"].cpu().numpy()),
                    run_options=None,
                )
                logits = logits[0]
                logits = torch.Tensor(logits)
                
            else:
                outputs = model(batch["input_ids"].to(device))
                logits = outputs["logits"].view(-1, 4)

        predictions = torch.argmax(logits, dim=-1)

        preds.append(predictions.to("cpu"))
        targets.append(batch["labels"].to("cpu"))

    try:
        preds = torch.vstack(preds).view(-1)

    except RuntimeError:
        preds = torch.concat(preds)

    targets = torch.vstack(targets).view(-1)

    preds = torch.tensor(
        [p for i, p in enumerate(preds) if targets[i] != -100], dtype=torch.long
    ).to(device)
    targets = torch.tensor([tg for tg in targets if tg != -100], dtype=torch.long).to(device)

    recall = fn_recall(preds, targets)
    precision = fn_precision(preds, targets)
    f1 = fn_f1(preds, targets)

    return recall.to("cpu"), precision.to("cpu"), f1.to("cpu")

In [7]:
logit_model_copy = copy.deepcopy(logit_model.model)
torch.cuda.empty_cache()

# Setup for final version:
optim = torch.optim.Adam(logit_model_copy.parameters(), lr=5e-5)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.8)
num_epochs = 10
num_training_steps = (num_epochs * len(df_curated_train)) // batch_size
grad_acc_steps = 1 # max(1, 3 // batch_size)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)


# optim = torch.optim.Adam(logit_model_copy.parameters(), lr=5e-5)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.8)
# num_epochs = 2
# num_training_steps = (num_epochs * len(df_curated_train)) // batch_size
# grad_acc_steps = 1
# loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

In [8]:
pbar = tqdm.auto.tqdm(range(num_training_steps), total=num_training_steps)

logit_model_copy.to("cuda:0")
checkpoint_dir = "./checkpoints_2_layers"


if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
    

for epoch in range(num_epochs):
    logit_model_copy.train()
    
    for i, batch in enumerate(dl_train, 1):
        if isinstance(logit_model_copy, transformers.BertForTokenClassification):
            batch = {k: v.to("cuda:0") for k, v in batch.items()}
            outputs = logit_model_copy(**batch)
            loss = outputs.loss

        else:
            outputs = logit_model_copy(batch["input_ids"].to("cuda:0"))
            logits = outputs["logits"].view(-1, 4)
            true = batch["labels"].to("cuda:0").view(-1)
            loss = loss_fn(logits, true)

        loss.backward()
        pbar.set_description(f"{float(loss.detach().item()):.6f}")

        if i % grad_acc_steps == 0 or i == len(dl_train):
            optim.step()
            optim.zero_grad()

        pbar.update(1)
    
    if dl_test_both is not None:
        recall_orig, precision_orig, f1_orig = eval_model(logit_model_copy, dl_test_both, device="cuda:0")
        print(f"{recall_orig = }, {precision_orig = }, {f1_orig = }")
    
    if checkpoint_dir:
        if isinstance(logit_model_copy, transformers.BertForTokenClassification):
            torch.save(logit_model_copy, os.path.join(checkpoint_dir, f"bert_checkpoint_epoch_{epoch}.pt"))
        else:
            torch.save(logit_model_copy, os.path.join(checkpoint_dir, f"lstm_checkpoint_epoch_{epoch}.pt"))
    
    lr_scheduler.step()

  0%|          | 0/11003 [00:00<?, ?it/s]

  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9542, 0.8674, 0.6220]), precision_orig = tensor([0.9983, 0.9839, 0.9157, 0.7376]), f1_orig = tensor(0.8834)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9994, 0.9582, 0.8633, 0.6324]), precision_orig = tensor([0.9984, 0.9848, 0.9276, 0.7304]), f1_orig = tensor(0.8856)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9992, 0.9629, 0.8742, 0.6672]), precision_orig = tensor([0.9986, 0.9813, 0.9188, 0.7268]), f1_orig = tensor(0.8906)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9623, 0.8738, 0.6307]), precision_orig = tensor([0.9985, 0.9825, 0.9221, 0.7403]), f1_orig = tensor(0.8874)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9623, 0.8734, 0.6638]), precision_orig = tensor([0.9986, 0.9828, 0.9235, 0.7355]), f1_orig = tensor(0.8917)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9630, 0.8780, 0.6864]), precision_orig = tensor([0.9986, 0.9819, 0.9195, 0.7406]), f1_orig = tensor(0.8955)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9627, 0.8731, 0.6655]), precision_orig = tensor([0.9986, 0.9823, 0.9242, 0.7375]), f1_orig = tensor(0.8922)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9628, 0.8765, 0.6376]), precision_orig = tensor([0.9986, 0.9828, 0.9190, 0.7454]), f1_orig = tensor(0.8890)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9635, 0.8783, 0.6429]), precision_orig = tensor([0.9986, 0.9827, 0.9199, 0.7425]), f1_orig = tensor(0.8899)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_orig = tensor([0.9993, 0.9638, 0.8821, 0.6463]), precision_orig = tensor([0.9986, 0.9824, 0.9166, 0.7361]), f1_orig = tensor(0.8898)


In [303]:
if isinstance(logit_model_copy, transformers.BertForTokenClassification):
    torch.save(logit_model_copy, "bert_v2_4_layers_logit_model_finetuned_emendas_variadas_6_epochs.pt")
else:
    torch.save(logit_model_copy, "lstm_logit_model_finetuned_emendas.pt")

In [296]:
recall_orig, precision_orig, f1_orig = eval_model(logit_model.model, dl_test=dl_test_emendas, device="cuda")
print(f"{recall_orig = }, {precision_orig = }, {f1_orig = }")

recall_orig, precision_orig, f1_orig = eval_model(logit_model.model, dl_test=dl_test_orig, device="cuda")
print(f"{recall_orig = }, {precision_orig = }, {f1_orig = }")

recall_orig, precision_orig, f1_orig = eval_model(logit_model.model, dl_test=dl_test_both, device="cuda")
print(f"{recall_orig = }, {precision_orig = }, {f1_orig = }")

# V2 (V1 + finetuning):
# recall_orig=([0.9970, 0.7688, 0.4450, 0.2556]), precision_orig=([0.9931, 0.8806, 0.6528, 0.3579]), f1_orig=0.6609
# recall_orig=([0.9994, 0.9768, 0.9275, 0.7755]), precision_orig=([0.9991, 0.9861, 0.9279, 0.7533]), f1_orig=0.9182
# recall_orig=([0.9990, 0.9529, 0.8185, 0.6551]), precision_orig=([0.9982, 0.9752, 0.8823, 0.6849]), f1_orig=0.8703

# V2 (from scratch):
# recall_orig=([0.9971, 0.6345, 0.4600, 0.2406]),precision_orig=([0.9908, 0.9314, 0.6345, 0.1475]), f1_orig=0.6162
# recall_orig=([0.9994, 0.9461, 0.7976, 0.5397]),precision_orig=([0.9979, 0.9837, 0.9260, 0.7062]), f1_orig=0.8580
# recall_orig=([0.9990, 0.9102, 0.7213, 0.4704]),precision_orig=([0.9967, 0.9793, 0.8685, 0.4874]), f1_orig=0.8020

  0%|          | 0/101 [00:00<?, ?it/s]

recall_orig = tensor([0.9970, 0.7688, 0.4450, 0.2556]), precision_orig = tensor([0.9931, 0.8806, 0.6528, 0.3579]), f1_orig = tensor(0.6609)


  0%|          | 0/767 [00:00<?, ?it/s]

recall_orig = tensor([0.9994, 0.9768, 0.9275, 0.7755]), precision_orig = tensor([0.9991, 0.9861, 0.9279, 0.7533]), f1_orig = tensor(0.9182)


  0%|          | 0/868 [00:00<?, ?it/s]

recall_orig = tensor([0.9990, 0.9529, 0.8185, 0.6551]), precision_orig = tensor([0.9982, 0.9752, 0.8823, 0.6849]), f1_orig = tensor(0.8703)


In [9]:
for epoch in [5]:
    print(f"{epoch = }")
    
    torch.cuda.empty_cache()
    logit_model_copy = torch.load(os.path.join(checkpoint_dir, f"bert_checkpoint_epoch_{epoch}.pt"))
    
    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_emendas, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")

    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_orig, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")

    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_both, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")

# V1:
# 2-BERT v1: recall_orig = ([0.9994, 0.9389, 0.7762, 0.4331]), precision_orig = ([0.9976, 0.9858, 0.8991, 0.6920])
# f1 = 0.8369, f1_c1= 0.9618
# 4-BERT v1: recall_orig = ([0.9995, 0.9394, 0.7859, 0.4830]), precision_orig = ([0.9976, 0.9867, 0.9104, 0.6741])
# f1 = 0.8447, f1_c1= 0.9625
# 6-BERT v1: recall_orig = ([0.9994, 0.9390, 0.7835, 0.4444]), precision_orig = ([0.9976, 0.9864, 0.9076, 0.6405])
# f1 = 0.8316, f1_c1= 0.9621
# 128-LSTMv1:recall_orig = ([0.9994, 0.9386, 0.7800, 0.5102]), precision_orig = ([0.9976, 0.9861, 0.9092, 0.6839])
# f1 = 0.8484, f1_c1= 0.9618
# 256-LSTMv1:recall_orig = ([0.9994, 0.9376, 0.7849, 0.5193]), precision_orig = ([0.9976, 0.9859, 0.9016, 0.6716])
# f1 = 0.8479, f1_c1= 0.9611
# 512-LSTMv1:recall_orig = ([0.9994, 0.9375, 0.7849, 0.5238]), precision_orig = ([0.9976, 0.9863, 0.9087, 0.6453])
# f1 = 0.8464, f1_c1= 0.9613

# V2 from scratch:
# 2-BERT v2: recall_orig = ([0.9994, 0.9468, 0.7976, 0.5079]), precision_orig = ([0.9979, 0.9834, 0.9255, 0.6978])
# f1 = 0.8548, f1_c1= 0.9648
# 4-BERT v2: recall_orig = ([0.9994, 0.9461, 0.7976, 0.5397]), precision_orig = ([0.9979, 0.9837, 0.9260, 0.7062])
# f1 = 0.8601, f1_c1= 0.9645
# 128-LSTMv2:recall_orig = ([0.9993, 0.9603, 0.8224, 0.5833]), precision_orig = ([0.9983, 0.9842, 0.9069, 0.6960])
# f1 = 0.8680, f1_c1= 0.9721
# 128QLSTMv2:recall_new  = ([0.9994, 0.9599, 0.8214, 0.5351]), precision_new  = ([0.9983, 0.9862, 0.9219, 0.7195])
# f1 = 0.8660, f1_c1= 0.9729
# 256-LSTMv2:recall_orig = ([0.9994, 0.9620, 0.8202, 0.5646]), precision_orig = ([0.9983, 0.9855, 0.9171, 0.7137])
# f1 = 0.8688, f1_c1= 0.9736
# 256QLSTMv2:recall_new  = ([0.9994, 0.9612, 0.8219, 0.5646]), precision_new  = ([0.9983, 0.9856, 0.9169, 0.7074])
# f1 = 0.8682, f1_c1= 0.9732
# 512-LSTMv2:recall_orig = ([0.9994, 0.9583, 0.8215, 0.5352]), precision_orig = ([0.9982, 0.9866, 0.9133, 0.7250])
# f1 = 0.8655, f1_c1= 0.9722
# 512QLSTMv2:recall_new  = ([0.9994, 0.9577, 0.8248, 0.5442]), precision_new  = ([0.9982, 0.9864, 0.9133, 0.7229])
# f1 = 0.8668, f1_c1= 0.9718

# V2 with post-training fine-tuning:
# 2-BERT v2: recall_new = ([0.9993, 0.9659, 0.8491, 0.5601]), precision_new = ([0.9985, 0.9797, 0.9228, 0.7508])
# f1 = 0.8769, f1_c1= 0.9728
# 4-BERT v2: recall_new = ([0.9994, 0.9669, 0.8871, 0.5283]), precision_new = ([0.9986, 0.9863, 0.9230, 0.7793])
# f1 = 0.8820, f1_c1= 0.9765
# 4-BERTe3v2:recall_new = ([0.9995, 0.9694, 0.8939, 0.6871]), precision_new = ([0.9988, 0.9895, 0.9320, 0.7537])
# f1 = 0.9027, f1_c1= 0.9793
# 4-BERTe5v2:recall_new = ([0.9994, 0.9745, 0.9061, 0.7506]), precision_new = ([0.9990, 0.9862, 0.9366, 0.7609])
# f1 = 0.9141, f1_c1= 0.9803
#2-BERTe10v2:recall_new = ([0.9993, 0.9729, 0.9051, 0.6916]), precision_new = ([0.9989, 0.9856, 0.9172, 0.7531])
# f1 = 0.9028, f1_c1= 0.9792
#4-BERTe10v2:recall_new = ([0.9994, 0.9768, 0.9275, 0.7755]), precision_new = ([0.9991, 0.9861, 0.9279, 0.7533])
# f1 = 0.9182, f1_c1= 0.9814
#6-BERTe10v2:recall_new = ([0.9994, 0.9753, 0.9221, 0.7188]), precision_new = ([0.9990, 0.9875, 0.9280, 0.7441])
# f1 = 0.9092, f1_c1= 0.9814
#2QBERTe10v2:recall_new = ([0.9993, 0.9729, 0.9051, 0.6916]), precision_new = ([0.9989, 0.9856, 0.9176, 0.7512])
# f1 = 0.9027, f1_c1= 0.9792
#4QBERTe10v2:recall_new = ([0.9993, 0.9769, 0.9275, 0.7755]), precision_new = ([0.9991, 0.9859, 0.9284, 0.7533])
# f1 = 0.9182, f1_c1= 0.9814
#6QBERTe10v2:recall_new = ([0.9994, 0.9754, 0.9217, 0.7234]), precision_new = ([0.9990, 0.9876, 0.9271, 0.7419])
# f1 = 0.9094, f1_c1= 0.9815
# 256-LSTMv2:recall_new = ([0.9992, 0.9632, 0.8297, 0.4807]), precision_new = ([0.9984, 0.9748, 0.9337, 0.7518])
# f1 = 0.8638, f1_c1= 0.9690
# 512-LSTMv2:recall_new = ([0.9993, 0.9644, 0.8477, 0.4898]), precision_new = ([0.9985, 0.9762, 0.9376, 0.8030])
# f1 = 0.8740, f1_c1= 0.9703


## EMENDAS
# v2 reference:
# recall_orig = tensor([0.9989, 0.8602, 0.0957, 0.0566]), precision_orig = tensor([0.9898, 0.9418, 0.6207, 1.0000]), f1_orig = tensor(0.5416)
# recall_orig = tensor([0.9994, 0.9768, 0.9275, 0.7755]), precision_orig = tensor([0.9991, 0.9861, 0.9279, 0.7533]), f1_orig = tensor(0.9182)

# From v1 (10 epochs, bsize= 3 * 2 grad_acc)
# recall_new = tensor([0.9998, 0.9862, 0.9362, 0.8302]), precision_new = tensor([0.9990, 0.9960, 0.9724, 0.9778]), f1_new = tensor(0.9606)
# recall_new = tensor([0.9994, 0.9752, 0.9217, 0.7755]), precision_new = tensor([0.9991, 0.9872, 0.9262, 0.7720]), f1_new = tensor(0.9195)

# (post-training with all active learning data)
# 3 epochs
# recall_new = tensor([0.9994, 0.9902, 0.9521, 0.7170]), precision_new = tensor([0.9991, 0.9824, 0.9471, 0.9500]), f1_new = tensor(0.9381)
# recall_new = tensor([0.9993, 0.9769, 0.9226, 0.7188]), precision_new = tensor([0.9991, 0.9844, 0.9186, 0.7512]), f1_new = tensor(0.9088)

# (post-training with only emendas data)
# 1 epoch
# recall_new = tensor([0.9996, 0.9862, 0.9149, 0.6792]), precision_new = tensor([0.9986, 0.9921, 0.9609, 0.9231]), f1_new = tensor(0.9270)
# recall_new = tensor([0.9991, 0.9729, 0.9041, 0.8458]), precision_new = tensor([0.9990, 0.9889, 0.9068, 0.5173]), f1_new = tensor(0.8818)

# 3 epochs
# recall_new = tensor([0.9997, 0.9902, 0.9096, 0.8302]), precision_new = tensor([0.9989, 0.9941, 0.9716, 0.9565]), f1_new = tensor(0.9550)
# recall_new = tensor([0.9991, 0.9726, 0.8998, 0.8390]), precision_new = tensor([0.9990, 0.9885, 0.9086, 0.5293]), f1_new = tensor(0.8832)


## EMENDAS VARIADAS
# From V2 (V1 + finetuning 10 epochs):
# recall_new = tensor([0.9986, 0.9076, 0.8083, 0.6165]), precision_new = tensor([0.9973, 0.9455, 0.8883, 0.7664]), f1_new = tensor(0.8635)
# recall_new = tensor([0.9993, 0.9771, 0.9270, 0.7551]), precision_new = tensor([0.9991, 0.9858, 0.9248, 0.7603]), f1_new = tensor(0.9160)
# recall_new = tensor([0.9992, 0.9691, 0.9002, 0.7230]), precision_new = tensor([0.9988, 0.9813, 0.9171, 0.7615]), f1_new = tensor(0.9061)


# epoch = 0
# recall_new = tensor([0.9982, 0.8074, 0.6650, 0.3083]), precision_new = tensor([0.9947, 0.8995, 0.8750, 0.9111]), f1_new = tensor(0.7660)
# recall_new = tensor([0.9992, 0.9705, 0.8725, 0.6168]), precision_new = tensor([0.9987, 0.9761, 0.9358, 0.7514]), f1_new = tensor(0.8882)
# recall_new = tensor([0.9990, 0.9517, 0.8256, 0.5453]), precision_new = tensor([0.9981, 0.9680, 0.9241, 0.7690]), f1_new = tensor(0.8671)

# epoch = 1
# recall_new = tensor([0.9986, 0.8304, 0.7600, 0.3910]), precision_new = tensor([0.9955, 0.9419, 0.8555, 0.7647]), f1_new = tensor(0.8005)
# recall_new = tensor([0.9993, 0.9683, 0.9163, 0.6984]), precision_new = tensor([0.9988, 0.9881, 0.9066, 0.7130]), f1_new = tensor(0.8986)
# recall_new = tensor([0.9992, 0.9524, 0.8810, 0.6272]), precision_new = tensor([0.9983, 0.9833, 0.8962, 0.7200]), f1_new = tensor(0.8813)

# epoch = 2
# recall_new = tensor([0.9986, 0.8946, 0.7250, 0.4962]), precision_new = tensor([0.9967, 0.9320, 0.9197, 0.8571]), f1_new = tensor(0.8375)
# recall_new = tensor([0.9994, 0.9741, 0.8852, 0.7075]), precision_new = tensor([0.9989, 0.9839, 0.9396, 0.7685]), f1_new = tensor(0.9066)
# recall_new = tensor([0.9993, 0.9649, 0.8490, 0.6585]), precision_new = tensor([0.9985, 0.9781, 0.9357, 0.7826]), f1_new = tensor(0.8940)

# epoch = 3
# recall_new = tensor([0.9978, 0.9058, 0.7783, 0.5865]), precision_new = tensor([0.9972, 0.9027, 0.8680, 0.7358]), f1_new = tensor(0.8438)
# recall_new = tensor([0.9992, 0.9768, 0.9182, 0.7823]), precision_new = tensor([0.9991, 0.9817, 0.9241, 0.7566]), f1_new = tensor(0.9172)
# recall_new = tensor([0.9990, 0.9686, 0.8866, 0.7369]), precision_new = tensor([0.9988, 0.9726, 0.9124, 0.7527]), f1_new = tensor(0.9034)

# epoch = 4
# recall_new = tensor([0.9984, 0.8913, 0.8000, 0.6842]), precision_new = tensor([0.9971, 0.9420, 0.8791, 0.6691]), f1_new = tensor(0.8570)
# recall_new = tensor([0.9994, 0.9747, 0.9187, 0.7914]), precision_new = tensor([0.9990, 0.9867, 0.9232, 0.7505]), f1_new = tensor(0.9178)
# recall_new = tensor([0.9992, 0.9651, 0.8919, 0.7666]), precision_new = tensor([0.9987, 0.9817, 0.9139, 0.7321]), f1_new = tensor(0.9060)

# epoch = 5
# recall_new = tensor([0.9982, 0.9165, 0.7800, 0.5714]), precision_new = tensor([0.9974, 0.9186, 0.8948, 0.7917]), f1_new = tensor(0.8531)
# recall_new = tensor([0.9993, 0.9772, 0.9153, 0.7732]), precision_new = tensor([0.9991, 0.9829, 0.9271, 0.7629]), f1_new = tensor(0.9171)
# recall_new = tensor([0.9991, 0.9702, 0.8847, 0.7265]), precision_new = tensor([0.9988, 0.9755, 0.9205, 0.7680]), f1_new = tensor(0.9052)


# From V2 (from scratch):
# recall_new = tensor([0.9990, 0.9039, 0.8150, 0.6692]), precision_new = tensor([0.9974, 0.9606, 0.8940, 0.8241]), f1_new = tensor(0.8802)
# recall_new = tensor([0.9995, 0.9728, 0.9163, 0.7778]), precision_new = tensor([0.9990, 0.9890, 0.9280, 0.7725]), f1_new = tensor(0.9193)
# recall_new = tensor([0.9994, 0.9649, 0.8934, 0.7526]), precision_new = tensor([0.9987, 0.9858, 0.9208, 0.7826]), f1_new = tensor(0.9121)


epoch = 5


  0%|          | 0/51 [00:00<?, ?it/s]

recall_new = tensor([0.9986, 0.8950, 0.7900, 0.5639]), precision_new = tensor([0.9971, 0.9404, 0.8893, 0.7009]), f1_new = tensor(0.8442)


  0%|          | 0/384 [00:00<?, ?it/s]

recall_new = tensor([0.9994, 0.9718, 0.9036, 0.7234]), precision_new = tensor([0.9989, 0.9871, 0.9276, 0.7506]), f1_new = tensor(0.9077)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_new = tensor([0.9993, 0.9630, 0.8780, 0.6864]), precision_new = tensor([0.9986, 0.9819, 0.9195, 0.7406]), f1_new = tensor(0.8955)


In [16]:
import torch


def fn(logit_model_copy):    
    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_emendas, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")

    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_orig, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")

    recall_new, precision_new, f1_new = eval_model(logit_model_copy, dl_test=dl_test_both, device="cuda")
    print(f"{recall_new = }, {precision_new = }, {f1_new = }")
    

print("4-BERT-emendas-finetuning")
logit_model = segmentador.BERTSegmenter("../segmenter_checkpoint_v2/4_6000_layer_model_finetuned_emendas_old").model
fn(logit_model)

print("4-BERT-emendas-from-sratch")
logit_model = segmentador.BERTSegmenter("../segmenter_checkpoint_v2/4_6000_layer_model_finetuned_emendas").model
fn(logit_model)
    
print("256-LSTM-emendas")
logit_model = segmentador.LSTMSegmenter(
    uri_model="../segmenter_checkpoint_v2/256_6000_1_lstm_emendas/checkpoints/epoch=9-step=12489.ckpt",
    device="cpu",
).model
fn(logit_model)

print("512-LSTM")
logit_model = segmentador.LSTMSegmenter(
    uri_model=(
        "../segmenter_checkpoint_v2/512_6000_1_lstm/checkpoints/512_hidden_dim_6000_vocab_size_1_layer_lstm.pt"
    ),
    device="cpu",
).model
fn(logit_model)

4-BERT-emendas-from-scratch


  0%|          | 0/51 [00:00<?, ?it/s]

recall_new = tensor([0.9976, 0.8037, 0.6917, 0.5865]), precision_new = tensor([0.9950, 0.9029, 0.8012, 0.6783]), f1_new = tensor(0.8045)


  0%|          | 0/384 [00:00<?, ?it/s]

recall_new = tensor([0.9994, 0.9752, 0.9217, 0.7755]), precision_new = tensor([0.9991, 0.9872, 0.9262, 0.7720]), f1_new = tensor(0.9195)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_new = tensor([0.9991, 0.9554, 0.8697, 0.7317]), precision_new = tensor([0.9984, 0.9784, 0.9009, 0.7527]), f1_new = tensor(0.8981)
4-BERT-emendas-finetuning


  0%|          | 0/51 [00:00<?, ?it/s]

recall_new = tensor([0.9989, 0.9065, 0.8133, 0.6617]), precision_new = tensor([0.9975, 0.9506, 0.9071, 0.8224]), f1_new = tensor(0.8793)


  0%|          | 0/384 [00:00<?, ?it/s]

recall_new = tensor([0.9994, 0.9746, 0.9187, 0.8027]), precision_new = tensor([0.9991, 0.9886, 0.9232, 0.7746]), f1_new = tensor(0.9226)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_new = tensor([0.9994, 0.9667, 0.8949, 0.7700]), precision_new = tensor([0.9988, 0.9844, 0.9199, 0.7837]), f1_new = tensor(0.9146)
256-LSTM-emendas


  0%|          | 0/51 [00:00<?, ?it/s]

recall_new = tensor([0.9982, 0.8568, 0.7967, 0.6241]), precision_new = tensor([0.9965, 0.9367, 0.8835, 0.4716]), f1_new = tensor(0.8168)


  0%|          | 0/384 [00:00<?, ?it/s]

recall_new = tensor([0.9993, 0.9595, 0.8253, 0.5828]), precision_new = tensor([0.9983, 0.9841, 0.9108, 0.6946]), f1_new = tensor(0.8675)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_new = tensor([0.9991, 0.9476, 0.8188, 0.5923]), precision_new = tensor([0.9980, 0.9789, 0.9047, 0.6227]), f1_new = tensor(0.8571)
512-LSTM


  0%|          | 0/51 [00:00<?, ?it/s]

recall_new = tensor([0.9980, 0.6694, 0.5017, 0.1203]), precision_new = tensor([0.9914, 0.9401, 0.7359, 0.1311]), f1_new = tensor(0.6247)


  0%|          | 0/384 [00:00<?, ?it/s]

recall_new = tensor([0.9994, 0.9575, 0.8243, 0.5442]), precision_new = tensor([0.9982, 0.9864, 0.9137, 0.7229]), f1_new = tensor(0.8646)


  0%|          | 0/434 [00:00<?, ?it/s]

recall_new = tensor([0.9992, 0.9243, 0.7514, 0.4460]), precision_new = tensor([0.9971, 0.9823, 0.8816, 0.5639]), f1_new = tensor(0.8150)
