# Env

In [None]:
import os
import argparse
import collections
from datetime import datetime
import re
import json
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.auto import tqdm
from transformers import (
    T5TokenizerFast,
    AutoTokenizer
)

In [None]:
# Autograd is switched ON globally by the call below.
# NOTE(review): this cell used to be labelled "Gradient False", which
# contradicts the argument actually passed — confirm which mode is intended.
torch.set_grad_enabled(True)
# Project root; later cells `%cd` into paths built from this directory.
work_dir = '/Users/cchyun/Workspace/nlp_ws/nlp-practice'

In [None]:
# Move the notebook's working directory to the project root and print it
# as a sanity check.
%cd {work_dir}
!pwd

# 4.1 BERT Text Classification (TC)

In [None]:
# Enter the bert-tc source directory so the shell scripts in the following
# cells can be invoked by bare file name, then print the location.
%cd {work_dir}/src/bert-tc
!pwd

## train bert tc

In [None]:
# Launch fine-tuning through the project's shell script, passing
# "cchyun-bert-tc" as its single argument (presumably the run/model name
# that prefixes the checkpoint directory — verify in finetune_bert_tc.sh).
!sh finetune_bert_tc.sh "cchyun-bert-tc"

## bert classify

In [None]:
# Run classify_bert.sh (from the current src/bert-tc directory) against the
# fine-tuned checkpoint directory given as its argument.
!sh classify_bert.sh "../../checkpoints/cchyun-bert-tc-20240323-103852/checkpoint-1052"

## bert infer

In [None]:
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

In [None]:
# Inference in this notebook runs on the CPU.
device = torch.device("cpu")

# Checkpoint directory to load the fine-tuned model and tokenizer from.
model_fn = "../../checkpoints/cchyun-bert-tc-20240323-103852/checkpoint-1052"

# config.json is read from the parent of the checkpoint directory. Decode it
# explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
with open(os.path.join(model_fn, "..", "config.json"), encoding="utf-8") as f:
    data = json.load(f)

# Expose the stored settings as attributes (e.g. train_config.max_length).
train_config = argparse.Namespace(**data["config"])
label2idx = data["label2idx"]
# JSON object keys are always strings; convert them back to integer indices
# so the mapping can be looked up with an argmax result.
idx2label = {int(k): v for k, v in data["idx2label"].items()}

In [None]:
# Restore the tokenizer and the fine-tuned classifier from the checkpoint;
# the classifier is put in evaluation mode right away.
tokenizer = AutoTokenizer.from_pretrained(model_fn)
model = AutoModelForSequenceClassification.from_pretrained(model_fn).eval()

# Move the weights to the target device. Left as the final expression so the
# cell still displays the module summary.
model.to(device)

In [None]:
# Interactive classification loop: read one sentence per prompt, print the
# predicted label, its probability and the input text (tab-separated), and
# stop as soon as an empty line is entered.
while True:
    line = input("input> ")
    if not line:
        break

    x = tokenizer(
        line,
        truncation=True,
        max_length=train_config.max_length,
        return_tensors="pt",
    ).to(device)

    # Pure inference: skip autograd bookkeeping regardless of the global
    # grad mode, so no graph is built for each query.
    with torch.no_grad():
        # [0] selects the single sentence in the batch.
        logit = model(**x).logits[0]
    # |logit| = (output_dim,)

    prob = F.softmax(logit, dim=-1)
    # |prob| = (output_dim,)

    y = prob.argmax(dim=-1)
    # |y| = () — scalar class index

    print(f"{idx2label[y.item()]}\t{prob[y].item():.4f}\t{line}")