In [None]:
import os
import sys
import torch
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from loguru import logger
from typing import List, Dict
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForTokenClassification, AutoTokenizer
from torch.nn.utils.rnn import pad_sequence
from datasets import Dataset

# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Single coloured stdout sink at INFO level with a timestamp | level | location layout.
log_format = "<level>{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}</level>"

# Called without a handler id, logger.remove() drops every existing handler first,
# so only the sink added below stays active.
logger.remove()
logger.add(
    sys.stdout,
    level='INFO',
    format=log_format,
    colorize=True,
)

In [None]:
# Label vocabulary: "O" plus B-/I- tags for each PII entity type. The ids must match
# the ones the checkpoint was trained with, so they are spelled out explicitly here.
label2id = {
    'O': 0,
    'B-NAME_STUDENT': 1,
    'B-EMAIL': 2,
    'B-USERNAME': 3,
    'B-ID_NUM': 4,
    'B-PHONE_NUM': 5,
    'B-URL_PERSONAL': 6,
    'B-STREET_ADDRESS': 7,
    'I-NAME_STUDENT': 8,
    'I-EMAIL': 9,
    'I-USERNAME': 10,
    'I-ID_NUM': 11,
    'I-PHONE_NUM': 12,
    'I-URL_PERSONAL': 13,
    'I-STREET_ADDRESS': 14
}

# Inverse mapping derived from label2id so the two tables can never drift apart.
id2label = {label_id: label for label, label_id in label2id.items()}

In [None]:
def get_data(path):
    """Load a JSON file and wrap its contents in a ``datasets.Dataset``.

    Args:
        path: path to a JSON file whose top-level value is passed unchanged to
            ``Dataset.from_dict`` (presumably a mapping of column name -> list of
            values; TODO confirm against the preprocessing step).

    Returns:
        The resulting ``Dataset``. Its length is logged at INFO level.
    """
    # Decode as UTF-8 explicitly: the default encoding of open() is
    # platform-dependent, and the JSON may contain non-ASCII text.
    with open(path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)

    dataset = Dataset.from_dict(data)
    logger.info(f'Size of dataset: {len(dataset)}')

    return dataset


def get_model(model_path, tokenizer_path):
    """Load a token-classification checkpoint together with its tokenizer.

    Args:
        model_path: path passed to ``AutoModelForTokenClassification.from_pretrained``.
        tokenizer_path: path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(model, tokenizer)``; the model has been moved to the module-level ``device``.
    """
    net, tok = (
        AutoModelForTokenClassification.from_pretrained(model_path),
        AutoTokenizer.from_pretrained(tokenizer_path),
    )
    net.to(device)
    return net, tok


def eval_model(trained_model, eval_dataset, device=None, label_map=None) -> Dict[str, List]:
    """Run token-classification inference and collect every non-"O" prediction.

    Each element of ``eval_dataset`` is treated as ONE sequence and must provide
    ``'document_id'``, ``'input_ids'`` and ``'attention_mask'``.

    Args:
        trained_model: model called as ``trained_model(input_ids, attention_mask=...)``;
            its return value must expose a ``logits`` attribute.
        eval_dataset: iterable of examples that also supports ``len()``.
        device: device to run on. Defaults to the device of the model's first
            parameter (``'cpu'`` if the model has no parameters).
        label_map: mapping from predicted class id to label string. Defaults to the
            module-level ``id2label``.

    Returns:
        Dict of parallel lists ``'document'``, ``'token'`` and ``'label'``, with one
        entry per token whose predicted id is non-zero and whose attention mask is
        non-zero.
    """
    if device is None:
        first_param = next(trained_model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    labels = id2label if label_map is None else label_map

    metrics = {
        'document': [],
        'token': [],
        'label': []
    }

    trained_model.eval()

    with torch.no_grad():
        for example in tqdm(eval_dataset, total=len(eval_dataset)):
            if len(example['input_ids']) == 0:
                # torch.tensor([]) is float32 and would break the embedding lookup;
                # an empty sequence has no tokens to label anyway.
                continue

            document_id = example['document_id']
            # Shape (1, seq_len): a single sequence of seq_len tokens. reshape(-1, 1)
            # would turn every token into its own length-1 sequence, so the model
            # would never see any surrounding context.
            input_ids = torch.tensor(example['input_ids']).reshape(1, -1).to(device)
            attention_mask = torch.tensor(example['attention_mask']).reshape(1, -1).to(device)

            # Use the model that was passed in, not a notebook-global `model`.
            outputs = trained_model(input_ids, attention_mask=attention_mask)
            predicted_labels = outputs.logits.argmax(-1)

            token_pairs = zip(attention_mask[0].tolist(), predicted_labels[0].tolist())
            for token_idx, (mask, pred) in enumerate(token_pairs):
                if pred == 0 or mask == 0:
                    continue
                metrics['document'].append(document_id)
                # NOTE(review): this is the position in the tokenized sequence (special
                # tokens included), not a word index of the original text; confirm
                # that this is what the consumer of `token` expects.
                metrics['token'].append(token_idx)
                metrics['label'].append(labels[pred])
    return metrics

In [None]:
dataset_path = './data/processed/test/test_processed.json'
# Training run whose checkpoint and tokenizer are evaluated below.
model_run_dir = './model/20240402_2108'
tokenizer_path = f'{model_run_dir}/tokenizer/'
model_path = f'{model_run_dir}/model/'

In [None]:
# Load the fine-tuned checkpoint (moved to `device`) and the preprocessed test split.
model, tokenizer = get_model(model_path=model_path, tokenizer_path=tokenizer_path)
test_ds = get_data(path=dataset_path)

In [None]:
# Collect every non-"O" token prediction on the test split.
test_metrics = eval_model(model, test_ds)

In [None]:
# One row per predicted entity token. Naming the index and resetting it adds the
# 0-based `row_id` column by name, with no in-place mutation and no positional
# overwrite of the column labels (which would silently mislabel if the metrics
# dict ever changes shape).
df = (
    pd.DataFrame(test_metrics)
    .rename_axis('row_id')
    .reset_index()
)

In [None]:
# Sanity check: turn the first test example's token ids back into text.
first_test_ids = test_ds[0]['input_ids']
tokenizer.decode(first_test_ids)

In [None]:
# Preview the prediction table: columns row_id, document, token, label
# (only tokens predicted as something other than id 0 / "O").
df

### Rough work (exploratory scratch cells, not part of the evaluation pipeline)

In [None]:
# Decode the last test example back into text for manual inspection.
last_test_ids = test_ds[-1]['input_ids']
tokenizer.decode(last_test_ids)

In [None]:
# Run the model on the last test example for manual inspection.
# - torch.no_grad() avoids building an autograd graph for a pure inference call.
# - `device` (rather than a hard-coded 'cuda') matches where get_model put the model.
# - attention_mask is passed by keyword instead of relying on argument position.
z = test_ds[-1]
with torch.no_grad():
    o = model(
        torch.tensor(z['input_ids']).reshape(1, -1).to(device),
        attention_mask=torch.tensor(z['attention_mask']).reshape(1, -1).to(device),
    )

In [None]:
# Print each token of `z` whose predicted label id is non-zero (id 0 is "O").
predicted_label_ids = o.logits.argmax(-1)[0].tolist()
for token_id, label_id in zip(z['input_ids'], predicted_label_ids):
    if label_id != 0:
        print(f'{tokenizer.decode(token_id)} ==> {id2label[label_id]}')

In [None]:
# `ds` is otherwise only created further down the notebook, so on a fresh kernel this
# cell raised NameError. Load it here so the cell is self-contained.
from datasets import load_from_disk

ds = load_from_disk('data/processed/dataset_3/')
tokenizer.decode(ds['train'][0]['input_ids'])

In [None]:
# Dimensions of the logits tensor from the most recent forward pass.
o.logits.size()

In [None]:
# Predicted class id per position. argmax returns the same indices as
# torch.max(..., -1), but does not assign to `_`, which in IPython holds the
# previous cell's output.
p = o.logits.argmax(-1)

In [None]:
# Raw logit vector at index [0, 0] (presumably first sequence, first token).
o.logits[0, 0]

In [None]:
from datasets import load_from_disk

In [None]:
# Saved dataset with a 'train' split (indexed in the cells below).
ds = load_from_disk(dataset_path='data/processed/dataset_3/')

In [None]:
# First training example as a (1, seq_len) batch on the same device as the model
# (`device`, not a hard-coded 'cuda').
batch = torch.tensor(ds['train'][0]['input_ids']).reshape(1, -1).to(device)

In [None]:
# Confirm the batch dimensions.
batch.size()

In [None]:
# Inference only: no_grad avoids keeping activations for backprop.
# NOTE(review): no attention_mask is passed; this is fine only if `batch` holds an
# unpadded sequence, which is assumed for a single example.
with torch.no_grad():
    o = model(batch)