In [None]:
import dask.dataframe as dd
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
from tqdm import tqdm
import os

# -----------------------
# CONFIG
# -----------------------
# Model and device selection.
MODEL_NAME = "google/flan-t5-large"  # encoder-decoder model
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Input dataset and the single-file output used when CHUNKED_SAVE is off.
PARQUET_FILE = "natural_instructions_sample_balanced.parquet"
OUTPUT_FILE = "llm_inference_results.parquet"

# Generation / batching knobs.
TARGET_BATCH_SIZE = 4    # adjust based on GPU memory
MAX_NEW_TOKENS = 128
N_PARTITIONS = 500       # number of Dask partitions

# When enabled, results are written once per partition instead of being
# accumulated in memory for one final write.
CHUNKED_SAVE = True
# The model name contains a "/", which is replaced so the directory name is a
# single path component.
SAVE_DIR = f"results_chunks_{MODEL_NAME.replace('/', '_')}"
os.makedirs(SAVE_DIR, exist_ok=True)

# Columns read from the dataset (targets included).
cols = [
    "id",
    "task_name",
    "task_family",
    "definition",
    "inputs",
    "targets",
]

# -----------------------
# 1. Load model & tokenizer
# -----------------------
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# device_map="auto" automatically maps layers across CPU/GPU; the model is
# switched to evaluation mode right away since it is only used for inference.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, device_map="auto").eval()

# -----------------------
# 2. Read Parquet lazily with Dask
# -----------------------
print("Reading dataset lazily with Dask...")
# Only the columns listed in `cols` are loaded, then the frame is split into
# N_PARTITIONS pieces so each piece can be processed on its own later.
ddf = dd.read_parquet(PARQUET_FILE, columns=cols)
ddf = ddf.repartition(npartitions=N_PARTITIONS)
total_rows = len(ddf)
print(f"Dataset has {total_rows:,} rows across {ddf.npartitions} partitions")

# -----------------------
# 3. Batch inference per partition
# -----------------------
# Each partition is materialised on its own, pushed through the model in
# batches of at most TARGET_BATCH_SIZE, and (when CHUNKED_SAVE is set) written
# to its own parquet file so results never accumulate across partitions.
all_results = []

for part_idx in tqdm(range(ddf.npartitions), desc="Processing partitions"):
    chunk_df = ddf.get_partition(part_idx).compute()

    if len(chunk_df) == 0:
        continue

    # Concatenate definition + inputs only for LLM input
    llm_texts = (chunk_df["definition"] + " " + chunk_df["inputs"]).tolist()
    # Convert the rows to plain dicts once per partition; looking fields up
    # with chunk_df.iloc[...] for every row and column is far slower.
    row_records = chunk_df[cols].to_dict("records")

    i = 0
    while i < len(llm_texts):
        # Every batch starts from the target size again; a reduction caused by
        # an OOM only applies to the batch that triggered it.
        batch_size = TARGET_BATCH_SIZE
        batch_success = False

        while not batch_success:
            batch_texts = llm_texts[i:i + batch_size]

            try:
                # Tokenize
                inputs = tokenizer(
                    batch_texts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(DEVICE)

                # Run model & measure latency
                start_time = time.time()
                outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
                end_time = time.time()
            except RuntimeError as e:
                # Anything that is not an OOM is a real error. An OOM on a
                # single example cannot be fixed by shrinking the batch, so it
                # is re-raised as well instead of being retried forever.
                if "out of memory" not in str(e) or len(batch_texts) <= 1:
                    raise
                batch_size = max(1, len(batch_texts) // 2)
                torch.cuda.empty_cache()
                print(f"OOM detected. Reducing batch size to {batch_size}")
                continue

            # Decode
            batch_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # Wall-clock time of the generate call divided evenly over the
            # examples in the batch.
            batch_latency = (end_time - start_time) / len(batch_texts)

            # Save results keeping definition & inputs separate, plus targets
            batch_rows = row_records[i:i + len(batch_texts)]
            all_results.extend(
                {**row, "output_text": output_text, "latency_sec": batch_latency}
                for row, output_text in zip(batch_rows, batch_outputs)
            )

            batch_success = True
            i += len(batch_texts)

    # Optionally save per partition to avoid large RAM usage
    if CHUNKED_SAVE:
        part_file = os.path.join(SAVE_DIR, f"results_part_{part_idx}.parquet")
        pd.DataFrame(all_results).to_parquet(part_file, index=False)
        all_results = []  # reset for next partition

# -----------------------
# 4. If CHUNKED_SAVE=False, save all results at once
# -----------------------
# With chunked saving enabled every partition has already been written to
# SAVE_DIR, so there is nothing left to persist here.
if not CHUNKED_SAVE:
    results_df = pd.DataFrame(data=all_results)
    results_df.to_parquet(path=OUTPUT_FILE, index=False)
    print(f"Saved results to {OUTPUT_FILE}")

print("Inference completed.")


Loading model and tokenizer...


2025-11-03 07:21:29.861612: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762154489.885778    1256 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762154489.895844    1256 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-03 07:21:30.101898: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Reading dataset lazily with Dask...
Dataset has 98,417 rows across 500 partitions


Processing partitions:   0%|          | 2/500 [01:20<5:27:49, 39.50s/it]