In [1]:
import os
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from transformers import (
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    pipeline
)

# Evaluation
from seqeval.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    classification_report
)

# Configuration
warnings.filterwarnings('ignore')
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [2]:
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


GPU: NVIDIA GeForce RTX 3060 Ti
Memory: 8.6 GB


In [3]:
# Path configuration
JSONL_DIR = Path(r"data\jsonls")
LABELS_TXT = r"data\labels.txt"

# Model configuration
MODEL_CHECKPOINT = "dccuchile/bert-base-spanish-wwm-cased"
OUTPUT_DIR = r"models\beto-ner"

# Data split ratios
VAL_SPLIT  = 0.10  
TEST_SPLIT = 0.15       

# Tokenization parameters
MAX_LENGTH = 512
DOC_STRIDE = 128

SEED = 42

In [4]:
import json
import random
from pathlib import Path
from collections import defaultdict
from datasets import Dataset, DatasetDict
import shutil

# Load labels and create BIO tags
with open(LABELS_TXT, "r", encoding="utf-8") as f:
    base_labels = [line.strip() for line in f if line.strip()]

base_labels = [l for l in base_labels if l != "O"]

bio_labels = ["O"]
for lab in base_labels:
    bio_labels.append(f"B-{lab}")
    bio_labels.append(f"I-{lab}")

label2id = {l: i for i, l in enumerate(bio_labels)}
id2label = {i: l for l, i in label2id.items()}

# Read JSONL files organized by company (subfolders)
records = []
by_company_by_doc = defaultdict(lambda: defaultdict(list))

company_dirs = [p for p in sorted(JSONL_DIR.iterdir()) if p.is_dir()]

for company_dir in company_dirs:
    company = company_dir.name
    for jsonl_file in sorted(company_dir.glob("*.jsonl")):
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                r = json.loads(line)
                r["entities"] = r.get("entities") or []
                r["doc_id"] = r.get("doc_id") or jsonl_file.stem
                r["company"] = r.get("company") or company

                records.append(r)
                by_company_by_doc[company][r["doc_id"]].append(r)

print(f"Companies found: {len(company_dirs)}")
print("Total records:", len(records))

# Document-level split by company (prevents data leakage)
train_recs, val_recs, test_recs = [], [], []
train_docs, val_docs, test_docs = set(), set(), set()

rng = random.Random(SEED)

for company, by_doc in by_company_by_doc.items():
    doc_ids = list(by_doc.keys())
    rng.shuffle(doc_ids)

    n = len(doc_ids)
    n_val  = int(n * VAL_SPLIT)
    n_test = int(n * TEST_SPLIT)

    # Ensure minimum samples if enough documents
    if n >= 3:
        n_val  = max(n_val, 1)
        n_test = max(n_test, 1)

    # Ensure at least 1 document in train
    if n_val + n_test > n - 1:
        overflow = (n_val + n_test) - (n - 1)
        cut_test = min(overflow, n_test)
        n_test -= cut_test
        overflow -= cut_test
        if overflow > 0:
            n_val = max(0, n_val - overflow)

    val_doc_ids   = set(doc_ids[:n_val])
    test_doc_ids  = set(doc_ids[n_val:n_val + n_test])
    train_doc_ids = set(doc_ids[n_val + n_test:])

    val_docs   |= val_doc_ids
    test_docs  |= test_doc_ids
    train_docs |= train_doc_ids

    val_recs   += [r for d in val_doc_ids   for r in by_doc[d]]
    test_recs  += [r for d in test_doc_ids  for r in by_doc[d]]
    train_recs += [r for d in train_doc_ids for r in by_doc[d]]

print(f"\nTOTAL DOCS: Train: {len(train_docs)} | Val: {len(val_docs)} | Test: {len(test_docs)}")
print(f"TOTAL RECS: Train: {len(train_recs)} | Val: {len(val_recs)} | Test: {len(test_recs)}")
print("Train/val overlap:", len(train_docs & val_docs))
print("Train/test overlap:", len(train_docs & test_docs))
print("Val/test overlap:", len(val_docs & test_docs))

# Build HuggingFace DatasetDict
raw_datasets = DatasetDict({
    "train":      Dataset.from_list(train_recs),
    "validation": Dataset.from_list(val_recs),
    "test":       Dataset.from_list(test_recs),
})

# Copy test documents for later inference
rawtxts_dir = Path("../data/rawtexts")
dst_dir = Path("data/testinvoicesjson")
dst_dir.mkdir(parents=True, exist_ok=True)

test_doc_ids = {r["doc_id"] for r in test_recs}
moved = 0

for txt_file in rawtxts_dir.rglob("*.txt"):
    if txt_file.stem in test_doc_ids:
        shutil.copy2(txt_file, dst_dir / txt_file.name)
        moved += 1

Companies found: 6
Total records: 166

TOTAL DOCS: Train: 128 | Val: 16 | Test: 22
TOTAL RECS: Train: 128 | Val: 16 | Test: 22
Train/val overlap: 0
Train/test overlap: 0
Val/test overlap: 0


In [5]:
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    use_fast=True,
)
print(f"Tokenizer: {MODEL_CHECKPOINT} (vocab: {tokenizer.vocab_size:,})")

def char_spans_to_marks(text, entities):
    """
    Map character positions to entity indices.
    Handles overlaps by prioritizing longer spans.
    """
    marks = [None] * len(text)
    sorted_ents = sorted(entities, key=lambda e: (-(e["end"]-e["start"]), e["start"]))
    for idx, ent in enumerate(sorted_ents):
        s, e = ent["start"], ent["end"]
        for i in range(max(0, s), min(len(text), e)):
            if marks[i] is None:
                marks[i] = idx
    return marks, sorted_ents

def tokenize_and_align_v3(batch):
    """
    Tokenize texts and align character-level entity spans to subword tokens.
    Uses BIO tagging scheme with sliding windows for long documents.
    """
    texts = batch["text"]
    ents_list = batch.get("entities", [[]] * len(texts))

    out = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "overflow_to_sample_mapping": [],
    }

    for text, entities in zip(texts, ents_list):
        marks, sorted_ents = char_spans_to_marks(text, entities)

        enc = tokenizer(
            text,
            return_offsets_mapping=True,
            max_length=MAX_LENGTH,
            truncation=True,
            stride=DOC_STRIDE,
            return_overflowing_tokens=True,
        )

        otm = enc.get("overflow_to_sample_mapping", [0] * len(enc["input_ids"]))

        for i in range(len(enc["input_ids"])):
            offsets = enc["offset_mapping"][i]
            lbl_ids = [-100] * len(offsets)
            prev_ent_idx = None

            for j, (start, end) in enumerate(offsets):
                if start == end:
                    continue

                ent_idx = None
                for k in range(start, end):
                    if 0 <= k < len(marks) and marks[k] is not None:
                        ent_idx = marks[k]
                        break

                if ent_idx is None:
                    lbl_ids[j] = label2id["O"]
                else:
                    lab = sorted_ents[ent_idx]["label"]  # ie: "EAN"
                    tag = ("B-" if prev_ent_idx != ent_idx else "I-") + lab  # "B-EAN"/"I-EAN"

                    lbl_ids[j] = label2id[tag]

                prev_ent_idx = ent_idx

            out["input_ids"].append(enc["input_ids"][i])
            out["attention_mask"].append(enc["attention_mask"][i])
            out["labels"].append(lbl_ids)
            out["overflow_to_sample_mapping"].append(otm[i])

    return out

tokenized_datasets = raw_datasets.map(
    tokenize_and_align_v3,
    batched=True,
    batch_size=10,
    remove_columns=raw_datasets["train"].column_names,
)

tokenized_datasets

Tokenizer: dccuchile/bert-base-spanish-wwm-cased (vocab: 31,002)


Map:   0%|          | 0/128 [00:00<?, ? examples/s]

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

Map:   0%|          | 0/22 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'overflow_to_sample_mapping'],
        num_rows: 882
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'overflow_to_sample_mapping'],
        num_rows: 88
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'overflow_to_sample_mapping'],
        num_rows: 130
    })
})

In [6]:
collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

def compute_metrics(p):
    """
    Compute entity-level metrics using seqeval.
    Returns micro/macro averaged precision, recall, F1, and per-entity metrics.
    """
    preds = np.argmax(p.predictions, axis=-1)
    labels = p.label_ids

    # Filter out special tokens (-100)
    true_predictions, true_labels = [], []
    for pred, lab in zip(preds, labels):
        cur_pred, cur_lab = [], []
        for p_i, l_i in zip(pred, lab):
            if l_i == -100:
                continue
            cur_pred.append(id2label[int(p_i)])
            cur_lab.append(id2label[int(l_i)])
        true_predictions.append(cur_pred)
        true_labels.append(cur_lab)

    # Global metrics
    metrics = {
        "precision": precision_score(true_labels, true_predictions),
        "recall":    recall_score(true_labels, true_predictions),
        "f1":        f1_score(true_labels, true_predictions),
        "precision_macro": precision_score(true_labels, true_predictions, average="macro"),
        "recall_macro":    recall_score(true_labels, true_predictions, average="macro"),
        "f1_macro":        f1_score(true_labels, true_predictions, average="macro"),
        "accuracy":  accuracy_score(true_labels, true_predictions),
    }

    # Per-entity metrics
    rep = classification_report(true_labels, true_predictions, output_dict=True, zero_division=0)
    
    # Print detailed report only for final evaluation
    if len(true_labels) > 100:  # Heuristic: val/test sets are larger
        print("\n" + classification_report(true_labels, true_predictions, zero_division=0))

    for ent, vals in rep.items():
        if ent in {"micro avg", "macro avg", "weighted avg"}:
            continue
        if isinstance(vals, dict) and "f1-score" in vals:
            metrics[f"precision_{ent}"] = float(vals["precision"])
            metrics[f"recall_{ent}"]    = float(vals["recall"])
            metrics[f"f1_{ent}"]        = float(vals["f1-score"])
            metrics[f"support_{ent}"]   = float(vals["support"])

    return metrics

In [7]:
num_labels = len(label2id)

model = AutoModelForTokenClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

print(f"✓ Model: {MODEL_CHECKPOINT}")
print(f"  Parameters: {model.num_parameters():,}")
print(f"  Labels: {num_labels}")


Some weights of BertForTokenClassification were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Model: dccuchile/bert-base-spanish-wwm-cased
  Parameters: 109,271,823
  Labels: 15


In [8]:
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=1.5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_ratio=0.15,
    logging_steps=25,
    logging_dir=os.path.join(OUTPUT_DIR, "logs"),
    fp16=torch.cuda.is_available(),
    seed=SEED,
    data_seed=SEED,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    report_to="none",
    label_smoothing_factor=0.1,
)

print(f"Learning rate:      {args.learning_rate}")
print(f"Epochs:             {args.num_train_epochs}")
print(f"Batch size:         {args.per_device_train_batch_size * args.gradient_accumulation_steps} (effective)")
print(f"Label smoothing:    {args.label_smoothing_factor}")
print(f"Device:             {'GPU (FP16)' if torch.cuda.is_available() else 'CPU'}")
print("="*60)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    data_collator=collator,
    compute_metrics=compute_metrics,
)

train_result = trainer.train()

print(f"Best F1 score: {trainer.state.best_metric:.4f}")
print(f"Training time: {train_result.metrics.get('train_runtime', 0):.2f}s")
print(f"Samples/second: {train_result.metrics.get('train_samples_per_second', 0):.2f}")


Learning rate:      1.5e-05
Epochs:             10
Batch size:         16 (effective)
Label smoothing:    0.1
Device:             GPU (FP16)


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Precision Macro,Recall Macro,F1 Macro,Accuracy,Precision Cantidad,Recall Cantidad,F1 Cantidad,Support Cantidad,Precision Ean,Recall Ean,F1 Ean,Support Ean,Precision Fecha,Recall Fecha,F1 Fecha,Support Fecha,Precision Nombre Producto,Recall Nombre Producto,F1 Nombre Producto,Support Nombre Producto,Precision Numero Factura,Recall Numero Factura,F1 Numero Factura,Support Numero Factura,Precision Precio Coste Unidad,Recall Precio Coste Unidad,F1 Precio Coste Unidad,Support Precio Coste Unidad,Precision Sku,Recall Sku,F1 Sku,Support Sku
1,1.492,0.877032,0.172516,0.147382,0.158962,0.206441,0.104544,0.113889,0.884911,1.0,0.185547,0.313015,512.0,0.078846,0.136667,0.1,300.0,0.0,0.0,0.0,16.0,0.312925,0.349146,0.330045,527.0,0.0,0.0,0.0,16.0,0.010929,0.003846,0.00569,520.0,0.042386,0.056604,0.048474,477.0
2,0.643,0.620605,0.932872,0.968328,0.950269,0.66848,0.701154,0.68433,0.984129,0.977055,0.998047,0.98744,512.0,0.960912,0.983333,0.971993,300.0,0.0,0.0,0.0,16.0,0.898955,0.979127,0.93733,527.0,0.0,0.0,0.0,16.0,0.92196,0.976923,0.948646,520.0,0.920477,0.97065,0.944898,477.0
3,0.5846,0.602716,0.944943,0.978463,0.961411,0.744376,0.786081,0.762013,0.989142,0.978927,0.998047,0.988395,512.0,0.967427,0.99,0.978583,300.0,0.333333,0.5,0.4,16.0,0.948624,0.981025,0.964552,527.0,0.1,0.0625,0.076923,16.0,0.950459,0.996154,0.97277,520.0,0.931864,0.974843,0.952869,477.0
4,0.5747,0.598787,0.961097,0.991132,0.975884,0.905174,0.932964,0.918781,0.990235,0.982692,0.998047,0.99031,512.0,0.967427,0.99,0.978583,300.0,0.764706,0.8125,0.787879,16.0,0.96488,0.990512,0.977528,527.0,0.75,0.75,0.75,16.0,0.94708,0.998077,0.97191,520.0,0.959432,0.991614,0.975258,477.0
5,0.5688,0.592341,0.972303,0.993243,0.982661,0.94289,0.943042,0.942627,0.991565,0.982726,1.0,0.991288,512.0,0.970588,0.99,0.980198,300.0,0.866667,0.8125,0.83871,16.0,0.972119,0.99241,0.98216,527.0,0.866667,0.8125,0.83871,16.0,0.970093,0.998077,0.983886,520.0,0.97137,0.995807,0.983437,477.0
6,0.5683,0.59659,0.969174,0.995777,0.982295,0.960017,0.970645,0.965046,0.991162,0.982726,1.0,0.991288,512.0,0.970588,0.99,0.980198,300.0,0.9375,0.9375,0.9375,16.0,0.972222,0.996205,0.984067,527.0,0.933333,0.875,0.903226,16.0,0.952381,1.0,0.97561,520.0,0.97137,0.995807,0.983437,477.0
7,0.5667,0.595048,0.973999,0.996622,0.985181,0.963211,0.971445,0.967086,0.991803,0.982726,1.0,0.991288,512.0,0.970684,0.993333,0.981878,300.0,0.9375,0.9375,0.9375,16.0,0.977654,0.996205,0.986842,527.0,0.933333,0.875,0.903226,16.0,0.961111,0.998077,0.979245,520.0,0.979466,1.0,0.989627,477.0
8,0.5646,0.593994,0.978838,0.996199,0.987442,0.975598,0.979626,0.977408,0.991946,0.982726,1.0,0.991288,512.0,0.970588,0.99,0.980198,300.0,0.9375,0.9375,0.9375,16.0,0.975791,0.994307,0.984962,527.0,1.0,0.9375,0.967742,16.0,0.981096,0.998077,0.989514,520.0,0.981481,1.0,0.990654,477.0
9,0.5631,0.593813,0.97684,0.997466,0.987046,0.965033,0.972192,0.968391,0.992278,0.982726,1.0,0.991288,512.0,0.970779,0.996667,0.983553,300.0,0.9375,0.9375,0.9375,16.0,0.979516,0.998102,0.988722,527.0,0.933333,0.875,0.903226,16.0,0.97191,0.998077,0.98482,520.0,0.979466,1.0,0.989627,477.0
10,0.5632,0.594163,0.974412,0.997044,0.985598,0.963466,0.971921,0.967453,0.991946,0.980843,1.0,0.990329,512.0,0.970779,0.996667,0.983553,300.0,0.9375,0.9375,0.9375,16.0,0.977654,0.996205,0.986842,527.0,0.933333,0.875,0.903226,16.0,0.964684,0.998077,0.981096,520.0,0.979466,1.0,0.989627,477.0


Best F1 score: 0.9874
Training time: 1325.76s
Samples/second: 6.65


In [9]:
# Save trained model
# trainer.save_model(OUTPUT_DIR)
# tokenizer.save_pretrained(OUTPUT_DIR)

# Validation evaluation
val_metrics = trainer.evaluate()

# Save metrics
with open(os.path.join(OUTPUT_DIR, "val_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(val_metrics, f, ensure_ascii=False, indent=2)

# print(f"\nModel saved: {OUTPUT_DIR}")
print(f"Validation F1: {val_metrics.get('eval_f1', 0):.4f}")


Validation F1: 0.9874


In [10]:
# Load inference pipeline
pipe = pipeline(
    "token-classification",
    model=OUTPUT_DIR,
    tokenizer=OUTPUT_DIR,
    aggregation_strategy="simple",
    device=0 if torch.cuda.is_available() else -1,
)

def ner_long_text(pipe, text, chunk_chars=1400, overlap=300, score_thresh=0.30, batch_size=8):
    """
    Process long texts using sliding windows with entity fusion.
    Handles documents longer than model's maximum sequence length.
    """
    n = len(text)
    step = max(1, chunk_chars - overlap)
    windows = []
    offsets = []
    s = 0
    
    # Generate sliding windows
    while s < n:
        e = min(n, s + chunk_chars)
        windows.append(text[s:e])
        offsets.append(s)
        s += step

    # Batch processing
    results = []
    for i in range(0, len(windows), batch_size):
        batch = windows[i:i+batch_size]
        batch_off = offsets[i:i+batch_size]
        preds_list = pipe(batch)
        for off, preds in zip(batch_off, preds_list):
            for p in preds:
                if p["score"] >= score_thresh:
                    results.append({
                        "entity_group": p["entity_group"],
                        "score": float(p["score"]),
                        "word": p["word"],
                        "start": p["start"] + off,
                        "end":   p["end"]   + off,
                    })

    # Merge contiguous/overlapping entities
    results.sort(key=lambda r: r["start"])
    
    fused = []
    for r in results:
        if not fused:
            fused.append(r)
            continue
        last = fused[-1]
        same = (last["entity_group"] == r["entity_group"])
        overlap_or_touch = (r["start"] <= last["end"] + 1)
        if same and overlap_or_touch:
            last["end"] = max(last["end"], r["end"])
            last["score"] = max(last["score"], r["score"])
            last["word"] = text[last["start"]:last["end"]]
        else:
            fused.append(r)
    
    return fused


Device set to use cuda:0


In [11]:
# Test set evaluation
print("\nEvaluating test set...")
test_metrics = trainer.evaluate(eval_dataset=tokenized_datasets["test"])

# Save test metrics
with open(os.path.join(OUTPUT_DIR, "test_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(test_metrics, f, ensure_ascii=False, indent=2)

# Performance summary table
print("\n" + "="*60)
print("PERFORMANCE SUMMARY")
print("="*60)
print(f"{'Metric':<20} {'Validation':<20} {'Test':<20}")
print("-"*60)
print(f"{'F1':<20} {val_metrics.get('eval_f1', 0):>19.4f} {test_metrics.get('eval_f1', 0):>19.4f}")
print(f"{'Precision':<20} {val_metrics.get('eval_precision', 0):>19.4f} {test_metrics.get('eval_precision', 0):>19.4f}")
print(f"{'Recall':<20} {val_metrics.get('eval_recall', 0):>19.4f} {test_metrics.get('eval_recall', 0):>19.4f}")
print("-"*60)


entity_metrics = []
for label in base_labels:
    f1_key = f"eval_f1_B-{label}"
    if f1_key in test_metrics:
        entity_metrics.append({
            'Entity': label,
            'F1': test_metrics[f1_key],
            'Precision': test_metrics.get(f"eval_precision_B-{label}", 0),
            'Recall': test_metrics.get(f"eval_recall_B-{label}", 0),
            'Support': int(test_metrics.get(f"eval_support_B-{label}", 0))
        })

if entity_metrics:
    df = pd.DataFrame(entity_metrics).sort_values('F1', ascending=False)
    print(df.to_string(index=False))



Evaluating test set...

                     precision    recall  f1-score   support

           CANTIDAD       1.00      1.00      1.00       974
                EAN       1.00      1.00      1.00       851
              FECHA       0.91      0.95      0.93        22
    NOMBRE_PRODUCTO       0.98      0.99      0.99       988
     NUMERO_FACTURA       0.91      0.91      0.91        22
PRECIO_COSTE_UNIDAD       1.00      1.00      1.00       978
                SKU       0.99      0.99      0.99       943

          micro avg       0.99      1.00      0.99      4778
          macro avg       0.97      0.98      0.97      4778
       weighted avg       0.99      1.00      0.99      4778


PERFORMANCE SUMMARY
Metric               Validation           Test                
------------------------------------------------------------
F1                                0.9874              0.9940
Precision                         0.9788              0.9923
Recall                            

In [12]:
# Guardar métricas y tiempos finales en un JSON
import time

# Recoger métricas globales y tiempos
train_runtime_sec = train_result.metrics.get("train_runtime", 0)
train_runtime_min = int(train_runtime_sec // 60)
train_runtime_rem_sec = int(train_runtime_sec % 60)

# Tiempo de inferencia en test
num_test_samples = len(tokenized_datasets["test"])
start_infer = time.time()
_ = trainer.predict(tokenized_datasets["test"])
end_infer = time.time()
infer_runtime_sec = end_infer - start_infer
infer_runtime_min = int(infer_runtime_sec // 60)
infer_runtime_rem_sec = int(infer_runtime_sec % 60)
avg_infer_per_sample = infer_runtime_sec / num_test_samples if num_test_samples else 0

final_metrics = {
    "train_runtime_sec": train_runtime_sec,
    "train_runtime_min": train_runtime_min,
    "train_runtime_rem_sec": train_runtime_rem_sec,
    "train_runtime_str": f"{train_runtime_min}m {train_runtime_rem_sec}s",
    "train_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
    "best_f1": trainer.state.best_metric,
    "val_f1": val_metrics.get("eval_f1", 0),
    "val_precision": val_metrics.get("eval_precision", 0),
    "val_recall": val_metrics.get("eval_recall", 0),
    "test_f1": test_metrics.get("eval_f1", 0),
    "test_precision": test_metrics.get("eval_precision", 0),
    "test_recall": test_metrics.get("eval_recall", 0),
    "infer_runtime_sec": infer_runtime_sec,
    "infer_runtime_min": infer_runtime_min,
    "infer_runtime_rem_sec": infer_runtime_rem_sec,
    "infer_runtime_str": f"{infer_runtime_min}m {infer_runtime_rem_sec}s",
    "avg_infer_per_sample_sec": avg_infer_per_sample,
    "avg_infer_per_sample_str": f"{avg_infer_per_sample:.3f}s",
    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}

# Guardar en archivo
with open(os.path.join(OUTPUT_DIR, "final_metrics.json"), "w", encoding="utf-8") as f:
    json.dump(final_metrics, f, ensure_ascii=False, indent=2)

print("\nMétricas y tiempos finales guardados en final_metrics.json:")
print(json.dumps(final_metrics, indent=2, ensure_ascii=False))
print(f"\nTiempo de entrenamiento: {final_metrics['train_runtime_str']}")
print(f"Tiempo de inferencia en test: {final_metrics['infer_runtime_str']}")
print(f"Tiempo medio por factura en test: {final_metrics['avg_infer_per_sample_str']}")


                     precision    recall  f1-score   support

           CANTIDAD       1.00      1.00      1.00       974
                EAN       1.00      1.00      1.00       851
              FECHA       0.91      0.95      0.93        22
    NOMBRE_PRODUCTO       0.98      0.99      0.99       988
     NUMERO_FACTURA       0.91      0.91      0.91        22
PRECIO_COSTE_UNIDAD       1.00      1.00      1.00       978
                SKU       0.99      0.99      0.99       943

          micro avg       0.99      1.00      0.99      4778
          macro avg       0.97      0.98      0.97      4778
       weighted avg       0.99      1.00      0.99      4778


Métricas y tiempos finales guardados en final_metrics.json:
{
  "train_runtime_sec": 1325.7585,
  "train_runtime_min": 22,
  "train_runtime_rem_sec": 5,
  "train_runtime_str": "22m 5s",
  "train_samples_per_second": 6.653,
  "best_f1": 0.9874424445374634,
  "val_f1": 0.9874424445374634,
  "val_precision": 0.978838174273858