In [1]:
import pandas as pd
import numpy as np
import torch
import time
from torch.utils.flop_counter import FlopCounterMode
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from src.utils.data_utils import random_masking

%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader, random_split
from tokenizers import Tokenizer
from src.utils.data_utils import SeqSet, Seq, random_masking

In [3]:
tk = Tokenizer.from_file(
    "./dataset/instacart/data/tk.json"
)

In [4]:
from src.models import FMConfig
from src.utils.model_utils import build_model
from src.utils.train_utils import (
    load_cfg,
    build_trainer,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


### bert

In [5]:
instacart = Seq(
    tokenizer=tk,
    data_root="/hpc/group/engelhardlab/ms1008/instacart",
    data_folder="/hpc/group/engelhardlab/ms1008/instacart/instacart.parquet",
    max_seq=2048,
    downstream_task_cohort=None,
    outcome_vars=None,
    time_operation=lambda x: x["t"],
    seq_id_col="user_id",
    set_id_col="order_number",
    token_col="product_id",
    additional_cols=["t"]
)
train, valid = random_split(
    dataset=instacart,
    lengths=[0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [6]:
dataloader = DataLoader(train, batch_size=16, shuffle=True)
batch = next(iter(dataloader))

In [7]:
cfg_dict = load_cfg("./config/instacart_bert.yaml")
tk = Tokenizer.from_file(f"./{cfg_dict['model']['tokenizer']}")

cfg = FMConfig(
    vocab_size=tk.get_vocab_size(),
    dataset=cfg_dict["dataset"],
    trainer=cfg_dict["trainer"],
    **cfg_dict["model"],
)
model = build_model(cfg, "FMBert", device)
trainer = build_trainer(cfg, model, tk, device)

In [8]:
from torch.utils.flop_counter import FlopCounterMode

with FlopCounterMode(display=False) as flop_counter:
    logits, h = model(
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["set_attention_mask"].to(model.device),
        batch["t"].to(device),
    )
    
total_flops = flop_counter.get_total_flops()
print(f"Total FLOPs: {total_flops / 1e9:.2f} GFLOPs")

100%|██████████| 30/30 [00:09<00:00,  3.03it/s]

Mean FLOPs: 6285.06 GFLOPs
Std  FLOPs: 0.00 GFLOPs





In [9]:
num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(f"Total number of parameters: {num_params / 1e6:.2f} M")

Total number of parameters: 77.11 M


In [10]:
step_time_list = []
for _ in tqdm(range(30)):
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    t = batch["t"].to(device)
    
    masked_input_ids, labels = random_masking(
        input_ids.clone(), tk, 0.2
    )
    
    torch.cuda.synchronize()
    t0 = time.time()
    logits, _ = model(
        input_ids=masked_input_ids,
        attention_mask=attention_mask,
        t=t,
    )
    loss = CrossEntropyLoss(ignore_index=-100)(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )
    loss.backward()
    torch.cuda.synchronize()
    t1 = time.time()
    step_time = t1 - t0

    step_time_list.append(step_time)
    
    del logits, _, batch
    torch.cuda.empty_cache()

    

100%|██████████| 30/30 [00:21<00:00,  1.40it/s]


In [11]:
print(f"Mean step time: {np.mean(step_time_list):.2f} sec")
print(f"Std  step time: {np.std(step_time_list):.2f} sec")
print(f"Mean throughput: {(2048 / np.array(step_time_list)).mean():.2f} #tokens/sec")
print(f"Std  throughput: {(2048 / np.array(step_time_list)).std():.2f} #tokens/sec")

Mean step time: 0.54 sec
Std  step time: 0.05 sec
Mean throughput: 3815.27 #tokens/sec
Std  throughput: 303.95 #tokens/sec


### base

In [12]:
instacart = SeqSet(
    tokenizer=tk,
    data_root="/hpc/group/engelhardlab/ms1008/instacart",
    data_folder="/hpc/group/engelhardlab/ms1008/instacart/instacart.parquet",
    max_seq=64,
    max_set_size=32,
    downstream_task_cohort=None,
    outcome_vars=None,
    time_operation=lambda x: x["t"],
    seq_id_col="user_id",
    set_id_col="order_number",
    token_col="product_id",
    additional_cols=["t"]
)
train, valid = random_split(
    dataset=instacart,
    lengths=[0.9, 0.1],
    generator=torch.Generator().manual_seed(42),
)

In [13]:
cfg_dict = load_cfg("./config/instacart_base.yaml")
tk = Tokenizer.from_file(f"./{cfg_dict['model']['tokenizer']}")

cfg = FMConfig(
    vocab_size=tk.get_vocab_size(),
    dataset=cfg_dict["dataset"],
    trainer=cfg_dict["trainer"],
    **cfg_dict["model"],
)
model = build_model(cfg, "FMBase", device)
trainer = build_trainer(cfg, model, tk, device)

In [14]:
dataloader = DataLoader(train, batch_size=16, shuffle=True)
batch = next(iter(dataloader))

In [15]:
from torch.utils.flop_counter import FlopCounterMode

with FlopCounterMode(display=False) as flop_counter:
    logits, h = model(
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["set_attention_mask"].to(model.device),
        batch["t"].to(device),
    )
    
total_flops = flop_counter.get_total_flops()
print(f"Total FLOPs: {total_flops / 1e9:.2f} GFLOPs")

Total FLOPs: 5155.62 GFLOPs


In [16]:
num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(f"Total number of parameters: {num_params / 1e6:.2f} M")

Total number of parameters: 119.62 M


In [17]:
step_time_list = []
for _ in tqdm(range(30)):
    batch = next(iter(dataloader))

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    set_attention_mask = batch["set_attention_mask"].to(device)
    t = batch["t"].to(device)
    
    masked_input_ids, labels = random_masking(
        input_ids.clone(), tk, 0.2
    )
    
    torch.cuda.synchronize()
    t0 = time.time()
    logits, _ = model(
        input_ids=masked_input_ids,
        attention_mask=attention_mask,
        set_attention_mask=set_attention_mask,
        t=t,
    )
    loss = CrossEntropyLoss(ignore_index=-100)(
        logits.view(-1, logits.size(-1)), labels.view(-1)
    )
    loss.backward()
    torch.cuda.synchronize()
    t1 = time.time()
    step_time = t1 - t0

    step_time_list.append(step_time)
    
    del logits, _, batch
    torch.cuda.empty_cache()

100%|██████████| 30/30 [00:19<00:00,  1.57it/s]


In [18]:
print(f"Mean step time: {np.mean(step_time_list):.2f} sec")
print(f"Std  step time: {np.std(step_time_list):.2f} sec")
print(f"Mean throughput: {(2048 / np.array(step_time_list)).mean():.2f} #tokens/sec")
print(f"Std  throughput: {(2048 / np.array(step_time_list)).std():.2f} #tokens/sec")

Mean step time: 0.40 sec
Std  step time: 0.01 sec
Mean throughput: 5099.06 #tokens/sec
Std  throughput: 95.67 #tokens/sec
