In [1]:
import warnings
warnings.filterwarnings("ignore")

from transformers import BartForConditionalGeneration, AutoTokenizer
import pandas as pd
from pathlib import Path
from tqdm import tqdm

from src.metrics import preds_time_tps_lacc
from src.inference_utils import GenerativeModel, GenerativeModelWithCaching

In [2]:
# All paths are resolved relative to the current working directory.
CURDIR = Path.cwd()

# Directory holding the original dataset; stop right away if it is missing.
DATADIR = CURDIR.joinpath("data", "original")
assert DATADIR.exists()

# Directory holding the model checkpoints.
MODELS_DIR = CURDIR.joinpath("models")
assert MODELS_DIR.exists()

# Checkpoint directory handed to `from_pretrained` when loading the model.
MODEL_ID = MODELS_DIR.joinpath('distill_4-4')
assert MODEL_ID.exists()

# NOTE(review): not referenced by any cell visible here -- presumably a
# maximum sequence length consumed elsewhere; confirm before removing.
MAX_LENGTH = 512

In [3]:
# Read the tab-separated test set, using its first column as the index.
df = pd.read_csv(DATADIR / "test.csv", sep="\t", index_col=0)

# Partition the rows by the value stored in the "split" column.
is_unknown = df["split"] == "unknown"
is_holdout = df["split"] == "holdout"
df_unknown = df.loc[is_unknown]
df_holdout = df.loc[is_holdout]

# Row counts: whole test set, holdout split, unknown split.
len(df), len(df_holdout), len(df_unknown)

(153991, 138882, 15109)

In [4]:
# Load the BART checkpoint, move it to the GPU, then switch it to eval mode.
model = BartForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to("cuda")
model = model.eval()

# The tokenizer is loaded from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [5]:
# Bundle the loaded model and tokenizer into the project's GenerativeModel
# helper, with its `verbose` option switched off.
gen_model = GenerativeModel(model, tokenizer, verbose=False)

In [6]:
# Wrap the generative model in GenerativeModelWithCaching.
# NOTE(review): this rebinds `model`, which until this point referred to the
# BART module loaded earlier. From here on `model` is the caching wrapper, so
# re-running the cell that builds `gen_model` after this one would pass the
# wrapper (not the BART module) to GenerativeModel.
model = GenerativeModelWithCaching(gen_model)

In [7]:
# Benchmark the cached predictor on three evaluation sets, in this order:
#   1. the first HEAD_ROWS rows of the full test set,
#   2. the first HEAD_ROWS rows of the holdout split,
#   3. the entire unknown split.
# One loop runs the same reset / measure / report sequence for each set, so
# the three runs cannot drift apart.
HEAD_ROWS = 30000
benchmark_frames = (
    df.head(HEAD_ROWS),
    df_holdout.head(HEAD_ROWS),
    df_unknown,
)

for frame in benchmark_frames:
    # Rebind the cache to a fresh dict before every run -- presumably so a run
    # cannot reuse predictions cached by an earlier one (confirm against
    # GenerativeModelWithCaching).
    model._cache = dict()
    # Four values come back; the first is discarded and the rest are reported.
    _, elapsed_time, tps, lacc = preds_time_tps_lacc(model.predict, frame)
    print(f"Elapsed time: {elapsed_time:.3f} seconds")
    print(f"TPS: {tps:.3f}")
    print(f"lAcc: {lacc:.3f}")
    print()

469it [00:26, 17.64it/s]                         


Elapsed time: 26.597 seconds
TPS: 1127.954
lAcc: 0.960



469it [00:21, 21.34it/s]                         


Elapsed time: 21.981 seconds
TPS: 1364.794
lAcc: 0.983



237it [00:27,  8.49it/s]                         


Elapsed time: 27.913 seconds
TPS: 541.295
lAcc: 0.898

