In [None]:
"""
Boilerplate: BART inference on a dataset (CSV/JSON/TSV)

Features:
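- Loads CSV/TSV files with the `csv` module, JSON via `datasets` (falling back to a stdlib JSON reader), a plain text file (one example per line), or other formats via `datasets.load_dataset`
- Loads a BART-family model from Hugging Face (e.g., facebook/bart-large-cnn or facebook/bart-base)
- Tokenizes and runs batch generation on GPU if available
- Supports generation arguments: max_length, min_length, num_beams, do_sample, and (only with --do-sample) top_k, top_p, temperature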
- Writes outputs to CSV with original input + generated text

Usage examples:
python bart_inference_boilerplate.py \
  --model facebook/bart-large-cnn \
  --input-file input.csv \
  --input-column text \
  --output-file outputs.csv \
  --batch-size 8 \
  --max-length 128 \
  --num-beams 4

Requirements:
- transformers
- datasets
- torch

Install: pip install transformers datasets torch
"""

from typing import List, Dict, Optional
import argparse
import csv
import os
import sys
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the script's arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Passing an explicit list is
            needed inside a notebook, where ``sys.argv`` holds the kernel's
            own arguments instead of this script's.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits:
        Through ``parser.error`` (``SystemExit`` with status 2) when a
        size/count option is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="BART inference boilerplate")
    parser.add_argument("--model", type=str, required=True, help="Hugging Face model name (e.g., facebook/bart-large-cnn)")
    parser.add_argument("--input-file", type=str, required=True, help="Path to input dataset (csv/json/tsv or plain txt)")
    parser.add_argument("--input-column", type=str, default="text", help="Column name to read source text from for CSV/JSON/TSV datasets")
    parser.add_argument("--output-file", type=str, default="outputs.csv", help="CSV file to write results")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--max-length", type=int, default=142)
    parser.add_argument("--min-length", type=int, default=56)
    parser.add_argument("--num-beams", type=int, default=4)
    parser.add_argument("--device", type=str, default=None, help="'cpu' or 'cuda' or leave empty to auto-detect")
    parser.add_argument("--task", type=str, default="summarization", choices=["summarization", "generation"], help="Task type: summarization or generation")
    parser.add_argument("--do-sample", action="store_true", help="Enable sampling (stochastic generation)")
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max-input-length", type=int, default=1024, help="Truncate inputs longer than this (tokens)")
    parser.add_argument("--progress", action="store_true", help="Show simple progress")
    args = parser.parse_args(argv)

    # Reject values that cannot work downstream with a normal usage error,
    # rather than a traceback from deep inside batching or generation
    # (a negative batch size would otherwise silently process nothing).
    positive_options = (
        ("--batch-size", args.batch_size),
        ("--num-beams", args.num_beams),
        ("--max-length", args.max_length),
        ("--max-input-length", args.max_input_length),
    )
    for flag, value in positive_options:
        if value < 1:
            parser.error(f"{flag} must be >= 1, got {value}")
    return args


def _as_text(value) -> str:
    """Coerce a raw field value to the string that will be tokenized.

    ``None`` (a JSON ``null``, a key missing from some JSON rows, or a short
    CSV row) becomes ``""``; other non-strings (e.g. numbers) go through
    ``str()`` so the tokenizer never receives a non-string item.
    """
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _with_text(example: Dict, input_column: str, source: str) -> Dict:
    """Return a copy of ``example`` with its source text stored under ``"text"``.

    Note: when ``input_column`` is not ``"text"``, an existing ``"text"`` field
    of the example is overwritten (unchanged legacy behavior).

    Raises:
        ValueError: if ``input_column`` is not a key of ``example``.
    """
    if input_column not in example:
        raise ValueError(f"Input column '{input_column}' not found in {source}")
    return {**example, "text": _as_text(example[input_column])}


def _load_delimited(path: Path, input_column: str, delimiter: str) -> List[Dict]:
    """Read a CSV/TSV file that has a header row into a list of row dicts."""
    # utf-8-sig drops the byte-order mark that Excel and similar tools write;
    # with plain utf-8 the BOM is glued onto the first header name, so that
    # column could never be found.
    with open(path, newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        columns = reader.fieldnames or []
        # Check the header once, so the error names the real columns and is
        # not confused with a short row (whose missing values come back as None).
        if input_column not in columns:
            raise ValueError(
                f"Input column '{input_column}' not found in {path}. Columns: {list(columns)}"
            )
        return [_with_text(row, input_column, str(path)) for row in reader]


def _read_json_stdlib(path: Path) -> List:
    """Parse a JSON document (array or single object) or a JSON-lines file."""
    import json

    with open(path, encoding="utf-8-sig") as f:
        content = f.read()
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Not one JSON document: treat as JSON lines, skipping blank lines
        # (which json.loads would reject).
        return [json.loads(line) for line in content.splitlines() if line.strip()]
    return data if isinstance(data, list) else [data]


def _load_json(path: Path, input_column: str) -> List[Dict]:
    """Load a JSON/JSON-lines file, preferring `datasets` and falling back to stdlib."""
    try:
        examples = list(load_dataset("json", data_files=str(path))["train"])
    except Exception:
        # Deliberate best-effort fallback: `datasets` raises many unrelated
        # error types (schema inference, pyarrow parsing, missing package).
        # Only loading is guarded, so the column check below is never swallowed.
        examples = _read_json_stdlib(path)
    records = []
    for ex in examples:
        if not isinstance(ex, dict):
            raise ValueError(f"Expected JSON objects in {path}, got {type(ex).__name__}")
        records.append(_with_text(ex, input_column, "JSON dataset"))
    return records


def _load_txt(path: Path) -> List[Dict]:
    """Read a text file with one example per line (stripped), keeping the line index."""
    with open(path, encoding="utf-8-sig") as f:
        return [{"text": line.strip(), "_line": i} for i, line in enumerate(f)]


def _load_inferred(path: Path, input_column: str) -> List[Dict]:
    """Let `datasets.load_dataset` infer the format of any other path."""
    try:
        examples = list(load_dataset(str(path))["train"])
    except Exception as e:
        raise ValueError(f"Unsupported file format or failed to load: {e}") from e
    return [_with_text(ex, input_column, "inferred dataset") for ex in examples]


def load_inputs(input_file: str, input_column: str = "text") -> List[Dict[str, str]]:
    """Load inputs into a list of dicts whose ``"text"`` key holds the source text.

    Other columns/fields are kept alongside ``"text"``. The format is chosen by
    file suffix:

    * ``.csv`` / ``.tsv``: header row required; read with the ``csv`` module.
    * ``.json`` / ``.jsonl``: JSON array, single object or JSON lines.
    * ``.txt``: one example per line (whitespace-stripped), plus ``"_line"``.
    * anything else: handed to ``datasets.load_dataset`` to infer.

    Args:
        input_file: Path to the input file.
        input_column: Field holding the source text (ignored for ``.txt``).

    Returns:
        A non-empty list of record dicts.

    Raises:
        ValueError: if the column is missing, the format cannot be loaded, or
            no examples were found.
        FileNotFoundError: if a CSV/TSV/TXT/JSON file does not exist.
    """
    path = Path(input_file)
    suffix = path.suffix.lower()

    if suffix in {".csv", ".tsv"}:
        records = _load_delimited(path, input_column, "\t" if suffix == ".tsv" else ",")
    elif suffix in {".json", ".jsonl"}:
        records = _load_json(path, input_column)
    elif suffix == ".txt":
        records = _load_txt(path)
    else:
        records = _load_inferred(path, input_column)

    if not records:
        raise ValueError("No examples loaded from input file")
    return records


def chunked(iterable, size):
    """Yield consecutive slices of at most ``size`` items from a sequence.

    ``iterable`` must support ``len()`` and slicing (list, tuple, range, str);
    each yielded chunk has the same type as a slice of it. The last chunk may
    be shorter.

    Raises:
        ValueError: if ``size`` is less than 1. Because this is a generator,
            the error is raised on the first iteration. Previously a negative
            size silently yielded nothing and a zero size hit an opaque
            ``range()`` error.
    """
    if size < 1:
        raise ValueError(f"chunk size must be >= 1, got {size}")
    for start in range(0, len(iterable), size):
        yield iterable[start:start + size]


def generate_batch(
    texts: List[str],
    tokenizer: "AutoTokenizer",
    model: "AutoModelForSeq2SeqLM",
    device: torch.device,
    gen_kwargs: Dict,
    max_input_length: int = 1024,
) -> List[str]:
    """Tokenize ``texts``, run ``model.generate`` once and decode the outputs.

    Args:
        texts: Source strings for one batch.
        tokenizer: A Hugging Face tokenizer (called with padding/truncation and
            ``return_tensors="pt"``, then used for ``batch_decode``).
        model: A seq2seq model exposing ``generate``.
        device: Device the input tensors are moved to; assumed to match the
            model's device.
        gen_kwargs: Extra keyword arguments forwarded verbatim to ``generate``.
        max_input_length: Inputs are truncated to this many tokens.

    Returns:
        Decoded strings with special tokens removed, one per generated sequence
        (one per input unless ``gen_kwargs`` requests multiple return
        sequences). An empty ``texts`` yields ``[]`` without touching the model.
    """
    # An empty batch cannot be padded into a tensor; nothing to generate anyway.
    if not texts:
        return []

    # Tokenize with truncation; pad to the longest item in this batch.
    enc = tokenizer(
        texts,
        max_length=max_input_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    input_ids = enc.input_ids.to(device)
    attention_mask = enc.attention_mask.to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

In [None]:
args = parse_args()

# A non-positive batch size would make the batching loop below either raise an
# opaque range() error (0) or silently process nothing (negative), leaving
# `results` empty.
if args.batch_size < 1:
    raise ValueError(f"--batch-size must be >= 1, got {args.batch_size}")

# device selection: an explicit --device wins, otherwise prefer CUDA if present
if args.device:
    device = torch.device(args.device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer and model: {args.model} on device={device}")
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForSeq2SeqLM.from_pretrained(args.model)

# move model to device; on CUDA additionally try float16 for faster inference
model.to(device)
if device.type == "cuda":
    try:
        model.half()
        print("Converted model to float16 (half) for faster inference on CUDA")
    except Exception as exc:
        # Best-effort optimisation: keep running in full precision, but report
        # it instead of silently swallowing the failure.
        print(f"float16 conversion failed, continuing in float32: {exc}", file=sys.stderr)

records = load_inputs(args.input_file, args.input_column)
texts = [r["text"] for r in records]

# NOTE(review): args.task is parsed but not consulted anywhere here; both task
# choices currently run with the same generation settings.
gen_kwargs = {
    "max_length": args.max_length,
    "min_length": args.min_length,
    "num_beams": args.num_beams,
    "do_sample": args.do_sample,
}
# Sampling knobs are only forwarded when sampling is enabled.
if args.do_sample:
    gen_kwargs.update(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

batch_size = args.batch_size
results = []

total = len(texts)
for i, batch_idxs in enumerate(chunked(list(range(total)), batch_size)):
    batch_texts = [texts[j] for j in batch_idxs]
    outputs = generate_batch(batch_texts, tokenizer, model, device, gen_kwargs, args.max_input_length)
    # One decoded output per input row (num_return_sequences is not used here).
    for idx, out in zip(batch_idxs, outputs):
        results.append({**records[idx], "generated": out})
    if args.progress:
        print(f"Processed {min((i + 1) * batch_size, total)}/{total}")

# CSV columns: union of keys over ALL rows in first-seen order. JSON rows need
# not share the same fields; taking only results[0]'s keys dropped the rest.
out_keys = list(dict.fromkeys(k for r in results for k in r))
# deterministic column order: 'text' first, 'generated' last
if "text" in out_keys:
    out_keys = [k for k in out_keys if k != "text"]
    out_keys.insert(0, "text")
if "generated" in out_keys:
    out_keys = [k for k in out_keys if k != "generated"]
    out_keys.append("generated")

out_path = Path(args.output_file)
# Create the destination directory so a fresh output path does not fail only
# after the (possibly long) generation run has finished.
out_path.parent.mkdir(parents=True, exist_ok=True)

with open(out_path, "w", newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=out_keys)
    writer.writeheader()
    for r in results:
        # Missing keys and None values are written as empty cells.
        writer.writerow({k: ("" if r.get(k) is None else r[k]) for k in out_keys})

print(f"Wrote {len(results)} records to {args.output_file}")