In [1]:
import numpy as np
import torch
import time
from torch.utils.flop_counter import FlopCounterMode
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader, random_split
from tokenizers import Tokenizer
from src.utils.data_utils import SeqSet, Seq, random_masking

In [3]:
# Initial tokenizer, handed to the Seq dataset below; later cells rebind `tk`
# to the tokenizer named in each model's YAML config.
tk = Tokenizer.from_file("./dataset/instacart/data/tk.json")

In [4]:
from src.models import FMConfig
from src.utils.model_utils import build_model
from src.utils.train_utils import (
    load_cfg,
    build_trainer,
)
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Benchmark shape shared by every model below: each batch holds BATCH_SIZE
# samples of CONTEXT_TOKENS token slots.
BATCH_SIZE = 16
CONTEXT_TOKENS = 2 ** 11  # 2048

### BERT

In [None]:
# Flat-sequence view of Instacart: sequences keyed by `user_id`, sets keyed by
# `order_number`, tokens taken from `product_id`, with column `t` carried along
# and used as the time feature. At most CONTEXT_TOKENS tokens per sequence.
instacart = Seq(
    tokenizer=tk,
    data_root="./dataset/instacart/data",
    data_folder="./dataset/instacart/data/instacart.parquet",
    max_seq=CONTEXT_TOKENS,
    downstream_task_cohort=None,
    outcome_vars=None,
    time_operation=lambda row: row["t"],
    seq_id_col="user_id",
    set_id_col="order_number",
    token_col="product_id",
    additional_cols=["t"],
)

# 90/10 train/validation split; the seeded generator makes it reproducible.
train, valid = random_split(
    instacart,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Shuffled loader over the training split, plus one batch for FLOP profiling.
dataloader = DataLoader(
    dataset=train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
batch = next(iter(dataloader))

In [12]:
# Build the FMBert model (and its trainer) from the YAML config; the vocabulary
# size is read off the tokenizer the config points at.
cfg_dict = load_cfg("./config/instacart_bert.yaml")
tokenizer_path = f"./{cfg_dict['model']['tokenizer']}"
tk = Tokenizer.from_file(tokenizer_path)
vocab_size = tk.get_vocab_size()

cfg = FMConfig(
    vocab_size=vocab_size,
    dataset=cfg_dict["dataset"],
    trainer=cfg_dict["trainer"],
    **cfg_dict["model"],
)
model = build_model(cfg, "FMBert", device)
trainer = build_trainer(cfg, model, tk, device)

In [None]:
# Count FLOPs of one forward pass on the profiling batch. Autograd is disabled:
# only forward FLOPs are measured here. Without no_grad, `logits`/`h` would keep
# the whole autograd graph (and its saved activations) alive for as long as
# they stay in scope, eating GPU memory the timing loop below needs.
with torch.no_grad(), FlopCounterMode(display=False) as flop_counter:
    logits, h = model(
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["t"].to(device),
    )

total_flops = flop_counter.get_total_flops()
print(f"Total FLOPs: {total_flops / 1e9:.2f} GFLOPs")

In [None]:
# Normalise the forward FLOP count by the token slots in the profiled batch
# (BATCH_SIZE samples x CONTEXT_TOKENS slots; padding positions, if any, are
# counted). Reported in MFLOPs: per-token values are small enough that a
# GFLOPs figure at .2f precision keeps almost no significant digits.
flops_per_token = total_flops / (CONTEXT_TOKENS * BATCH_SIZE)
print(f"FLOPs per token: {flops_per_token / 1e6:.2f} MFLOPs")

In [13]:
# Total parameter count of the model (trainable and frozen alike).
num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {num_params / 1e6:.2f} M")

Total number of parameters: 77.12 M


In [None]:
# Benchmark one masked-LM training step (forward + loss + backward) per
# iteration. Data loading, host->device copies and masking stay outside the
# timed window. The first `n_warmup_steps` iterations run but are not recorded,
# so one-off costs (CUDA initialisation, allocator growth) don't skew the stats.
n_warmup_steps = 3
n_timed_steps = 30
use_cuda = device.type == "cuda"  # torch.cuda.synchronize() raises without a GPU
loss_fn = CrossEntropyLoss(ignore_index=-100)  # loop-invariant; build once

step_time_list = []
for step in tqdm(range(n_warmup_steps + n_timed_steps)):
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    t = batch["t"].to(device)

    # Mask 20% of tokens. NOTE(review): assumes `labels` is -100 at unmasked
    # positions so they match ignore_index — confirm in random_masking.
    masked_input_ids, labels = random_masking(
        input_ids.clone(), tk, 0.2
    )

    # Without this every backward would accumulate into the previous steps'
    # .grad tensors; clear them (outside the timer) as a training loop would.
    model.zero_grad(set_to_none=True)

    # CUDA work is asynchronous: synchronize on both ends so the interval
    # measures completed GPU work, not just kernel launches.
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    logits, hidden = model(
        input_ids=masked_input_ids,
        attention_mask=attention_mask,
        t=t,
    )
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    if use_cuda:
        torch.cuda.synchronize()
    step_time = time.perf_counter() - t0

    if step >= n_warmup_steps:
        step_time_list.append(step_time)

    # Drop references so this step's tensors return to the caching allocator.
    # Deliberately no torch.cuda.empty_cache() here: releasing the cache every
    # step would force the next step to cudaMalloc inside the timed window.
    del logits, hidden, loss, batch

# Hand cached blocks back to the driver once, after timing is finished.
if use_cuda:
    torch.cuda.empty_cache()

    

In [None]:
# Each timed step processes a whole batch: BATCH_SIZE samples of CONTEXT_TOKENS
# token slots (the same normalisation used for FLOPs per token). Throughput
# must therefore count the batch, not a single sequence.
step_times = np.asarray(step_time_list)
tokens_per_step = BATCH_SIZE * CONTEXT_TOKENS
throughputs = tokens_per_step / step_times

print(f"Mean step time: {step_times.mean():.2f} sec")
print(f"Std  step time: {step_times.std():.2f} sec")
print(f"Mean throughput: {throughputs.mean():.2f} #tokens/sec")
print(f"Std  throughput: {throughputs.std():.2f} #tokens/sec")


### Longformer

In [14]:
# Build the FMLongformer model (and its trainer) from the YAML config; the
# vocabulary size is read off the tokenizer the config points at.
cfg_dict = load_cfg("./config/instacart_longformer.yaml")
tokenizer_path = f"./{cfg_dict['model']['tokenizer']}"
tk = Tokenizer.from_file(tokenizer_path)
vocab_size = tk.get_vocab_size()

cfg = FMConfig(
    vocab_size=vocab_size,
    dataset=cfg_dict["dataset"],
    trainer=cfg_dict["trainer"],
    **cfg_dict["model"],
)
model = build_model(cfg, "FMLongformer", device)
trainer = build_trainer(cfg, model, tk, device)

In [None]:
# Shuffled loader over the training split, plus one batch for FLOP profiling.
dataloader = DataLoader(
    dataset=train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
batch = next(iter(dataloader))

In [8]:
# Count FLOPs of one forward pass on the profiling batch. Autograd is disabled:
# only forward FLOPs are measured here. Without no_grad, `logits`/`h` would keep
# the whole autograd graph (and its saved activations) alive for as long as
# they stay in scope, eating GPU memory the timing loop below needs.
with torch.no_grad(), FlopCounterMode(display=False) as flop_counter:
    logits, h = model(
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["t"].to(device),
    )

total_flops = flop_counter.get_total_flops()
print(f"Total FLOPs: {total_flops / 1e9:.2f} GFLOPs")

Total FLOPs: 1356.14 GFLOPs


In [None]:
# Normalise the forward FLOP count by the token slots in the profiled batch
# (BATCH_SIZE samples x CONTEXT_TOKENS slots; padding positions, if any, are
# counted). Reported in MFLOPs: e.g. ~1356 GFLOPs / 32768 tokens is ~0.04
# GFLOPs, which .2f precision would reduce to a single significant digit.
flops_per_token = total_flops / (CONTEXT_TOKENS * BATCH_SIZE)
print(f"FLOPs per token: {flops_per_token / 1e6:.2f} MFLOPs")

In [None]:
# Total parameter count of the model (trainable and frozen alike).
num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {num_params / 1e6:.2f} M")

In [None]:
# Benchmark one masked-LM training step (forward + loss + backward) per
# iteration. Data loading, host->device copies and masking stay outside the
# timed window. The first `n_warmup_steps` iterations run but are not recorded,
# so one-off costs (CUDA initialisation, allocator growth) don't skew the stats.
n_warmup_steps = 3
n_timed_steps = 30
use_cuda = device.type == "cuda"  # torch.cuda.synchronize() raises without a GPU
loss_fn = CrossEntropyLoss(ignore_index=-100)  # loop-invariant; build once

step_time_list = []
for step in tqdm(range(n_warmup_steps + n_timed_steps)):
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    t = batch["t"].to(device)

    # Mask 20% of tokens. NOTE(review): assumes `labels` is -100 at unmasked
    # positions so they match ignore_index — confirm in random_masking.
    masked_input_ids, labels = random_masking(
        input_ids.clone(), tk, 0.2
    )

    # Without this every backward would accumulate into the previous steps'
    # .grad tensors; clear them (outside the timer) as a training loop would.
    model.zero_grad(set_to_none=True)

    # CUDA work is asynchronous: synchronize on both ends so the interval
    # measures completed GPU work, not just kernel launches.
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    logits, hidden = model(
        input_ids=masked_input_ids,
        attention_mask=attention_mask,
        t=t,
    )
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    if use_cuda:
        torch.cuda.synchronize()
    step_time = time.perf_counter() - t0

    if step >= n_warmup_steps:
        step_time_list.append(step_time)

    # Drop references so this step's tensors return to the caching allocator.
    # Deliberately no torch.cuda.empty_cache() here: releasing the cache every
    # step would force the next step to cudaMalloc inside the timed window.
    del logits, hidden, loss, batch

# Hand cached blocks back to the driver once, after timing is finished.
if use_cuda:
    torch.cuda.empty_cache()


In [None]:
# Each timed step processes a whole batch: BATCH_SIZE samples of CONTEXT_TOKENS
# token slots (the same normalisation used for FLOPs per token). Throughput
# must therefore count the batch, not a single sequence.
step_times = np.asarray(step_time_list)
tokens_per_step = BATCH_SIZE * CONTEXT_TOKENS
throughputs = tokens_per_step / step_times

print(f"Mean step time: {step_times.mean():.2f} sec")
print(f"Std  step time: {step_times.std():.2f} sec")
print(f"Mean throughput: {throughputs.mean():.2f} #tokens/sec")
print(f"Std  throughput: {throughputs.std():.2f} #tokens/sec")

### Base

In [None]:
# Set-structured view of Instacart: up to max_seq=64 sets (keyed by
# `order_number`) per user, each holding up to max_set_size=32 `product_id`
# tokens; column `t` is carried along and used as the time feature.
# NOTE(review): `tk` at this point is whichever tokenizer was loaded last (the
# Longformer config's); the base config's tokenizer is only loaded in the next
# cell. Confirm both configs point at the same tokenizer file.
instacart = SeqSet(
    tokenizer=tk,
    data_root="./dataset/instacart/data",
    data_folder="./dataset/instacart/data/instacart.parquet",
    max_seq=64,
    max_set_size=32,
    downstream_task_cohort=None,
    outcome_vars=None,
    time_operation=lambda row: row["t"],
    seq_id_col="user_id",
    set_id_col="order_number",
    token_col="product_id",
    additional_cols=["t"],
)

# 90/10 train/validation split; the seeded generator makes it reproducible.
train, valid = random_split(
    instacart,
    [0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Build the FMBase model (and its trainer) from the YAML config; the vocabulary
# size is read off the tokenizer the config points at.
cfg_dict = load_cfg("./config/instacart_base.yaml")
tokenizer_path = f"./{cfg_dict['model']['tokenizer']}"
tk = Tokenizer.from_file(tokenizer_path)
vocab_size = tk.get_vocab_size()

cfg = FMConfig(
    vocab_size=vocab_size,
    dataset=cfg_dict["dataset"],
    trainer=cfg_dict["trainer"],
    **cfg_dict["model"],
)
model = build_model(cfg, "FMBase", device)
trainer = build_trainer(cfg, model, tk, device)

In [None]:
# Shuffled loader over the (set-structured) training split, plus one batch for
# FLOP profiling.
dataloader = DataLoader(
    dataset=train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
batch = next(iter(dataloader))

In [None]:
# Count FLOPs of one forward pass on the profiling batch. Autograd is disabled:
# only forward FLOPs are measured here. Without no_grad, `logits`/`h` would keep
# the whole autograd graph (and its saved activations) alive for as long as
# they stay in scope. All inputs go to the notebook-level `device`, as in the
# other models' cells (a plain nn.Module has no `.device` attribute).
with torch.no_grad(), FlopCounterMode(display=False) as flop_counter:
    logits, h = model(
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["set_attention_mask"].to(device),
        batch["t"].to(device),
    )

total_flops = flop_counter.get_total_flops()
print(f"Total FLOPs: {total_flops / 1e9:.2f} GFLOPs")

In [None]:
# Normalise the forward FLOP count by the token slots in the profiled batch:
# BATCH_SIZE samples x (max_seq=64 sets x max_set_size=32 items = 2048 =
# CONTEXT_TOKENS) slots; padding positions, if any, are counted. Reported in
# MFLOPs: a GFLOPs figure at .2f precision keeps almost no significant digits.
flops_per_token = total_flops / (CONTEXT_TOKENS * BATCH_SIZE)
print(f"FLOPs per token: {flops_per_token / 1e6:.2f} MFLOPs")


In [None]:
# Total parameter count of the model (trainable and frozen alike).
num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {num_params / 1e6:.2f} M")

In [None]:
# Benchmark one masked-LM training step (forward + loss + backward) per
# iteration. Data loading, host->device copies and masking stay outside the
# timed window. The first `n_warmup_steps` iterations run but are not recorded,
# so one-off costs (CUDA initialisation, allocator growth) don't skew the stats.
n_warmup_steps = 3
n_timed_steps = 30
use_cuda = device.type == "cuda"  # torch.cuda.synchronize() raises without a GPU
loss_fn = CrossEntropyLoss(ignore_index=-100)  # loop-invariant; build once

step_time_list = []
for step in tqdm(range(n_warmup_steps + n_timed_steps)):
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    set_attention_mask = batch["set_attention_mask"].to(device)
    t = batch["t"].to(device)

    # Mask 20% of tokens. NOTE(review): assumes `labels` is -100 at unmasked
    # positions so they match ignore_index — confirm in random_masking.
    masked_input_ids, labels = random_masking(
        input_ids.clone(), tk, 0.2
    )

    # Without this every backward would accumulate into the previous steps'
    # .grad tensors; clear them (outside the timer) as a training loop would.
    model.zero_grad(set_to_none=True)

    # CUDA work is asynchronous: synchronize on both ends so the interval
    # measures completed GPU work, not just kernel launches.
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    logits, hidden = model(
        input_ids=masked_input_ids,
        attention_mask=attention_mask,
        set_attention_mask=set_attention_mask,
        t=t,
    )
    loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    loss.backward()
    if use_cuda:
        torch.cuda.synchronize()
    step_time = time.perf_counter() - t0

    if step >= n_warmup_steps:
        step_time_list.append(step_time)

    # Drop references so this step's tensors return to the caching allocator.
    # Deliberately no torch.cuda.empty_cache() here: releasing the cache every
    # step would force the next step to cudaMalloc inside the timed window.
    del logits, hidden, loss, batch

# Hand cached blocks back to the driver once, after timing is finished.
if use_cuda:
    torch.cuda.empty_cache()

In [None]:
# Each timed step processes a whole batch: BATCH_SIZE samples of CONTEXT_TOKENS
# token slots (for SeqSet: max_seq=64 sets x max_set_size=32 items = 2048).
# Throughput must therefore count the batch, not a single sample.
step_times = np.asarray(step_time_list)
tokens_per_step = BATCH_SIZE * CONTEXT_TOKENS
throughputs = tokens_per_step / step_times

print(f"Mean step time: {step_times.mean():.2f} sec")
print(f"Std  step time: {step_times.std():.2f} sec")
print(f"Mean throughput: {throughputs.mean():.2f} #tokens/sec")
print(f"Std  throughput: {throughputs.std():.2f} #tokens/sec")