In [1]:
import math
import time
from dataclasses import dataclass
from typing import List, Dict, Any

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
sys.path.append('..')
from language.transformer import Transformer


In [None]:
# Models to compare
MODEL_1 = "Julianvn/facts_ts-final"      # bilinear (custom Transformer)
MODEL_2 = "tdooms/ts-medium"             # bilinear (custom Transformer)
#MODEL_3 = "tdooms/ts-medium-reglu"       # ReGLU (standard HF model)

MODEL_4 = "Julianvn/facts-fw-small"
MODEL_5 = "Julianvn/facts-fw-med-new"
MODEL_6 = "tdooms/fw-small"
MODEL_7 = "tdooms/fw-medium"

SEQ_LEN = 256
BATCH_SIZE = 16
MAX_BATCHES = 200 

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

if DEVICE.type == "cuda":
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
elif DEVICE.type == "mps":
    DTYPE = torch.float16
else:
    DTYPE = torch.float32

print("Device:", DEVICE)
print("DType :", DTYPE)


Device: mps
DType : torch.float16


In [3]:
@dataclass
class EvalResult:
    model: str
    params_m: float
    tokens: int
    loss: float
    ppl: float
    toks_per_s: float
    device: str
    dtype: str


def count_params(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


In [None]:
dataset = load_from_disk("ts-tokenized-final").with_format("torch")
dataset = dataset['validation']
print(dataset)

tokenizer = Transformer.get_tokenizer("ts-4096")


Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 16384
})


In [None]:
def collate_fn(batch: List[Dict[str, Any]]):
    input_ids = torch.stack([ex["input_ids"][:SEQ_LEN] for ex in batch])
    attention_mask = torch.stack([ex["attention_mask"][:SEQ_LEN] for ex in batch])
    
    return {
        "input_ids": input_ids.to(DEVICE),
        "attention_mask": attention_mask.to(DEVICE),
    }

In [None]:
@torch.no_grad()
def evaluate_model(model_name: str) -> EvalResult:
    print(f"\nEvaluating {model_name}")
    
    model = Transformer.from_pretrained(
        model_name,
        device=str(DEVICE),
    ).to(DTYPE)
    
    print(f"  Architecture: {'Bilinear' if model.config.bilinear else 'ReGLU'}")
    if model.config.gate:
        print(f"  Gate: {model.config.gate}")
    
    model.eval()

    params_m = count_params(model) / 1e6

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
    )

    warmup = next(iter(loader))
    _ = model(**warmup)
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()

    total_loss = 0.0
    total_tokens = 0
    batches_done = 0

    t0 = time.perf_counter()

    for batch in loader:
        if batches_done >= MAX_BATCHES:
            break

        outputs = model(**batch)
        logits = outputs.logits

        input_ids = batch["input_ids"]
        attn = batch["attention_mask"]

        shift_logits = logits[:, :-1]
        shift_labels = input_ids[:, 1:]
        shift_attn = attn[:, 1:]

        ce = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        ).view_as(shift_labels)

        ce = ce * shift_attn
        tokens = int(shift_attn.sum().item())

        total_loss += float(ce.sum().item())
        total_tokens += tokens
        batches_done += 1

    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    mean_loss = total_loss / total_tokens
    ppl = math.exp(mean_loss)
    toks_per_s = total_tokens / (t1 - t0)

    return EvalResult(
        model=model_name,
        params_m=params_m,
        tokens=total_tokens,
        loss=mean_loss,
        ppl=ppl,
        toks_per_s=toks_per_s,
        device=str(DEVICE),
        dtype=str(DTYPE).replace("torch.", ""),
    )


In [11]:
model2_res = evaluate_model(MODEL_2)


Evaluating tdooms/ts-medium
  Architecture: Bilinear


In [None]:
res_a = evaluate_model(MODEL_1)
res_b = evaluate_model(MODEL_2)



Evaluating Julianvn/facts_ts-final
  Architecture: Bilinear

Evaluating tdooms/ts-medium
  Architecture: Bilinear

Evaluating tdooms/ts-medium-reglu


KeyError: None

In [18]:
res_4 = evaluate_model(MODEL_4)
res_5 = evaluate_model(MODEL_5)
res_6 = evaluate_model(MODEL_6)
res_7 = evaluate_model(MODEL_7)


Evaluating Julianvn/facts-fw-small
  Architecture: Bilinear

Evaluating Julianvn/facts-fw-med-new
  Architecture: Bilinear

Evaluating tdooms/fw-small
  Architecture: Bilinear

Evaluating tdooms/fw-medium
  Architecture: Bilinear


In [20]:
def print_table(results: List[EvalResult]):
    headers = ["model", "params(M)", "tokens", "loss", "ppl", "tok/s", "device", "dtype"]
    rows = [
        [
            r.model,
            f"{r.params_m:.2f}",
            f"{r.tokens}",
            f"{r.loss:.4f}",
            f"{r.ppl:.2f}",
            f"{r.toks_per_s:.0f}",
            r.device,
            r.dtype,
        ]
        for r in results
    ]

    widths = [max(len(h), max(len(row[i]) for row in rows)) for i, h in enumerate(headers)]

    def fmt(row):
        return " | ".join(row[i].ljust(widths[i]) for i in range(len(row)))

    sep = "-+-".join("-" * w for w in widths)

    print(fmt(headers))
    print(sep)
    for row in rows:
        print(fmt(row))


print_table([res_4, res_6, res_5, res_7])


model                     | params(M) | tokens | loss    | ppl       | tok/s | device | dtype  
--------------------------+-----------+--------+---------+-----------+-------+--------+--------
Julianvn/facts-fw-small   | 162.48    | 617846 | 11.6579 | 115606.46 | 5887  | mps    | float16
tdooms/fw-small           | 162.40    | 617846 | 11.7947 | 132548.80 | 6339  | mps    | float16
Julianvn/facts-fw-med-new | 334.12    | 617846 | 11.9243 | 150895.62 | 2962  | mps    | float16
tdooms/fw-medium          | 333.97    | 617846 | 12.2431 | 207534.08 | 2890  | mps    | float16


In [22]:
print_table([res_a, res_b])


model                   | params(M) | tokens | loss   | ppl   | tok/s | device | dtype  
------------------------+-----------+--------+--------+-------+-------+--------+--------
Julianvn/facts_ts-final | 29.39     | 617846 | 1.2770 | 3.59  | 31526 | mps    | float16
tdooms/ts-medium        | 29.36     | 617846 | 3.1509 | 23.36 | 31977 | mps    | float16
