In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer

import torch
from torch.nn.utils.rnn import pad_sequence

from datasets import Dataset
from model import get_model
from data_utils import id2label, chunk_examples_infer, tokenizer_and_align_infer

from loguru import logger

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without a usable CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def get_data(path, tokenizer, max_len=256, chunk_batch_size=10):
    """Load a JSON file and turn it into a tokenized, chunked dataset.

    Args:
        path: Path to a JSON file whose content ``pandas.DataFrame`` accepts
            (presumably a list of document records -- confirm against the
            actual data file).
        tokenizer: Tokenizer forwarded to ``tokenizer_and_align_infer``.
        max_len: Value forwarded to ``chunk_examples_infer`` as ``max_len``.
            Defaults to 256, the value that used to be hard-coded.
        chunk_batch_size: Number of rows handed to ``chunk_examples_infer``
            per batched ``map`` call. Defaults to 10, the previous
            hard-coded value.

    Returns:
        A ``datasets.Dataset`` holding only the columns produced by
        ``chunk_examples_infer`` (all earlier columns are removed).
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(path, 'r', encoding='utf-8') as fp:
        records = json.load(fp)

    dataset = Dataset.from_pandas(pd.DataFrame(records))

    # Row-wise tokenization / alignment.
    dataset = dataset.map(
        tokenizer_and_align_infer,
        num_proc=1,
        fn_kwargs={'tokenizer': tokenizer},
    )

    # Batched chunking: the mapped function may return a different number of
    # rows than it receives, so every pre-existing column has to be dropped.
    dataset = dataset.map(
        chunk_examples_infer,
        num_proc=1,
        batched=True,
        batch_size=chunk_batch_size,
        remove_columns=dataset.column_names,
        fn_kwargs={'max_len': max_len},
    )

    logger.info(f'Size of dataset: {len(dataset)}')

    return dataset


def get_model(path):
    """Load a token-classification checkpoint and its tokenizer.

    NOTE(review): this definition shadows the ``get_model`` imported from
    ``model`` at the top of the notebook -- confirm that is intended.

    Args:
        path: Directory containing a ``model`` sub-folder (weights/config)
            and a ``tokenizer`` sub-folder.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    weights_dir = os.path.join(path, 'model')
    tokenizer_dir = os.path.join(path, 'tokenizer')

    # ignore_mismatched_sizes=True lets loading proceed even when a weight
    # shape in the checkpoint differs from the configured architecture.
    model = AutoModelForTokenClassification.from_pretrained(
        weights_dir,
        ignore_mismatched_sizes=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    model.to(device)
    return model, tokenizer


def eval_model(trained_model, eval_dataset):
    """Run the model over every row of ``eval_dataset`` and collect the
    tokens that received a non-zero label id.

    Each row is treated as one chunk and is fed to the model as a batch of
    size one.

    Args:
        trained_model: Token-classification model; it is called as
            ``trained_model(input_ids, attention_mask=...)`` and must return
            an object exposing ``.logits``.
        eval_dataset: Indexable collection whose rows provide
            ``'document_id'``, ``'input_ids'`` and ``'attention_mask'``
            (assumes the latter two are 1-D sequences of ints for a single
            chunk -- TODO confirm against ``chunk_examples_infer``).

    Returns:
        dict with three parallel lists:
            ``'document'``: the row's ``document_id``,
            ``'token'``: position of the token inside its chunk (not offset
                by the chunk's position in the document),
            ``'label'``: ``id2label`` entry of the predicted class id.
        Tokens that are masked out (attention mask 0) or predicted as class
        id 0 are skipped (presumably id 0 is the "outside" label -- verify
        against ``id2label``).
    """
    metrics = {
        'document': [],
        'token': [],
        'label': []
    }

    trained_model.eval()
    # Feed inputs on whatever device the model actually lives on, instead of
    # relying on a module-level global.
    model_device = next(trained_model.parameters()).device

    with torch.no_grad():
        for row_idx in tqdm(range(len(eval_dataset))):
            row = eval_dataset[row_idx]

            document_id = row['document_id']
            # Add the batch dimension the model expects: (seq,) -> (1, seq).
            input_ids = torch.as_tensor(row['input_ids']).unsqueeze(0).to(model_device)
            attention_mask = torch.as_tensor(row['attention_mask']).unsqueeze(0).to(model_device)

            outputs = trained_model(input_ids, attention_mask=attention_mask)

            # Highest-scoring class per token, converted to plain Python ints
            # so they can be compared and used as dictionary keys.
            predicted_ids = outputs.logits.argmax(dim=-1).squeeze(0).tolist()
            mask_values = attention_mask.squeeze(0).tolist()

            for token_idx, (mask, pred) in enumerate(zip(mask_values, predicted_ids)):
                if pred == 0 or mask == 0:
                    continue

                metrics['document'].append(document_id)
                metrics['token'].append(token_idx)
                metrics['label'].append(id2label[pred])

    return metrics

In [None]:
# JSON file that get_data() reads.
dataset_path = './data/test.json'
# Run directory passed to get_model(), which reads its 'model' and
# 'tokenizer' sub-folders.
model_path = './model/20240402_0853/'

In [None]:
# Load the checkpoint and its tokenizer, then build the tokenized and chunked
# dataset from the JSON file using that tokenizer.
model, tokenizer = get_model(model_path)
test_ds = get_data(path=dataset_path, tokenizer=tokenizer)

In [None]:
# Convert the Dataset into a plain Python dict so it can be passed to json.dump.
test_dict = test_ds.to_dict()

In [None]:
# Persist the processed dataset. Create the target folder first so the write
# does not fail with FileNotFoundError when the directory does not exist yet.
output_path = './data/processed/test/test_processed.json'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w') as fp:
    json.dump(test_dict, fp)