In [19]:
# Make modules in the parent of the current working directory importable.
# Guarded so that re-running this cell does not append duplicate entries.
import sys
import os

_PARENT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

In [None]:
from typing import List, Optional
import pandas as pd
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from huggingface_hub import notebook_login

In [32]:
# Interactive Hugging Face Hub login (renders the token-entry widget shown below).
# NOTE(review): presumably required because the meta-llama checkpoint loaded later
# needs an authenticated download — confirm; for unattended runs a token supplied
# via the environment avoids this interactive step.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [27]:
# ── User-tunable parameters ────────────────────────────────────────────────────
# The three paths default to the original author's Windows checkout; set the
# JL_DATA_FILE / JL_PROMPT_FILE / JL_SAVE_DIR environment variables to run the
# notebook on another machine without editing this cell.
DATA_FILE: str = os.environ.get(
    "JL_DATA_FILE",
    r"C:\Users\norouzin\Desktop\JointLearning\datasets\expert_multi_task_data\test.csv",
)                                    # path to CSV dataset
PROMPT_FILE: str = os.environ.get(
    "JL_PROMPT_FILE",
    r"C:\Users\norouzin\Desktop\JointLearning\src\causal_pseudo_labeling\prompt.txt",
)                                    # path to prompt template (contains {{SENTENCE}})
SAMPLE_SIZE: Optional[int] = None    # e.g. 100 -> sample 100 rows, None -> all
USE_CHAT_TEMPLATE: bool = True       # toggle chat wrapping on/off
BATCH_SIZE: int = 1000               # prompts per batch
MAX_TOKENS: int = 512                # max new tokens to generate
GPU_MEMORY_UTILISATION: float = 0.90 # fraction of GPU RAM vLLM may allocate
RANDOM_SEED: int = 8642              # reproducible sampling
SAVE_DIR: str = os.environ.get(
    "JL_SAVE_DIR",
    r"C:\Users\norouzin\Desktop\JointLearning\predictions",
)                                    # directory to save results
SENTENCE_COLUMN: int = 0             # column index of the sentence to be predicted

In [28]:
def load_prompt_template(path: str) -> str:
    """Read the prompt template and verify it contains the ``{{SENTENCE}}`` placeholder.

    The file is decoded as ``utf-8-sig``: a leading byte-order mark (written by
    some Windows editors) is dropped instead of leaking into every prompt, and
    files without a BOM decode exactly as with plain ``utf-8``.

    Args:
        path: Path to the template text file.

    Returns:
        The template text, unchanged apart from BOM removal.

    Raises:
        ValueError: If the placeholder is missing. Without it every sentence would
            be rendered into the same prompt and the inputs silently ignored.
    """
    placeholder = "{{SENTENCE}}"
    with open(path, "r", encoding="utf-8-sig") as f:
        template = f.read()
    if placeholder not in template:
        raise ValueError(
            f"Prompt template {path!r} does not contain the {placeholder} placeholder."
        )
    return template
    

def load_sentences(path: str, column: int, sample_size: Optional[int],
                   random_state: Optional[int] = None) -> List[str]:
    """Load one CSV column as a list of strings, optionally subsampled.

    Args:
        path: CSV file readable by ``pandas.read_csv``.
        column: Positional index of the column holding the sentences.
        sample_size: If not ``None``, draw this many rows at random without
            replacement. Values larger than the number of rows are clamped, so
            every row is returned (in shuffled order) instead of raising.
        random_state: Seed for the subsample; ``None`` falls back to the
            notebook-level ``RANDOM_SEED``.

    Returns:
        The selected cells converted with ``astype(str)``. Note that missing
        cells therefore become the literal string ``"nan"``.
    """
    df = pd.read_csv(path)
    sentences = df.iloc[:, column].astype(str)

    if sample_size is not None:
        seed = RANDOM_SEED if random_state is None else random_state
        # Series.sample raises ValueError when n > len(series) without replacement.
        n = min(sample_size, len(sentences))
        sentences = sentences.sample(n=n, random_state=seed)

    return sentences.tolist()


def _strip_duplicate_bos(prompts: List[str], tokenizer) -> List[str]:
    """Drop a textual leading BOS when the tokenizer would prepend BOS again on encode.

    NOTE(review): some chat templates (e.g. Llama 3's) render the BOS token as
    text, and vLLM appears to tokenize string prompts with special tokens enabled,
    which would give the model two BOS tokens. Confirm for the installed versions.
    """
    bos_text = getattr(tokenizer, "bos_token", None)
    bos_id = getattr(tokenizer, "bos_token_id", None)
    if not bos_text or bos_id is None:
        return prompts
    auto_ids = tokenizer.encode("", add_special_tokens=True)
    if not auto_ids or auto_ids[0] != bos_id:
        # Encoding does not add BOS itself, so the template's BOS must stay.
        return prompts
    return [p[len(bos_text):] if p.startswith(bos_text) else p for p in prompts]


def build_prompts(template: str, sentences: List[str], use_chat: bool,
                  tokenizer: Optional["AutoTokenizer"]) -> List[str]:
    """Return one formatted prompt per sentence, ready for vLLM.

    Args:
        template: Text containing the ``{{SENTENCE}}`` placeholder.
        sentences: Values substituted for the placeholder, in order.
        use_chat: Wrap each filled template as a single user message and render
            it with the tokenizer's chat template (with a generation prompt).
        tokenizer: Hugging Face tokenizer used for chat rendering. When ``None``
            the plain filled template is returned even if ``use_chat`` is True.

    Returns:
        Prompts in the same order as ``sentences``.
    """
    filled = [template.replace("{{SENTENCE}}", s) for s in sentences]
    if not (use_chat and tokenizer is not None):
        # Fallback: plain string replacement (user supplies full template).
        return filled

    rendered = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for text in filled
    ]
    return _strip_duplicate_bos(rendered, tokenizer)

# Llama 3 8B Instruct

In [None]:
# Hugging Face Hub model id, shared by the chat-template tokenizer and the vLLM
# engine below so that both refer to the same checkpoint.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

In [None]:
# Read the template and the (optionally subsampled) input sentences.
print("Loading prompt template …")
prompt_template = load_prompt_template(PROMPT_FILE)

print("Loading sentences …")
sentences = load_sentences(DATA_FILE, SENTENCE_COLUMN, SAMPLE_SIZE)
print(f"Total sentences to annotate: {len(sentences):,}")

# A tokenizer is only needed to render the chat template; skip loading it otherwise.
if USE_CHAT_TEMPLATE:
    print("Initialising tokenizer for chat template …")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
else:
    tokenizer = None

prompts = build_prompts(
    template=prompt_template,
    sentences=sentences,
    use_chat=USE_CHAT_TEMPLATE,
    tokenizer=tokenizer,
)

# Greedy decoding (temperature=0.0); top_p=1.0 applies no nucleus truncation.
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=1.0,
    max_tokens=MAX_TOKENS,
)

print("Initialising vLLM … (this may take a moment)")
llm = LLM(
    model=MODEL_NAME,
    dtype="float16",
    # NOTE(review): trust_remote_code runs Python shipped with the model repo;
    # it is presumably unnecessary for this checkpoint — confirm before keeping it.
    trust_remote_code=True,
    gpu_memory_utilization=GPU_MEMORY_UTILISATION,
)

num_prompts = len(prompts)
# Clamp so that a non-positive BATCH_SIZE cannot make range() raise.
batch_size = max(1, BATCH_SIZE)
print(f"Total prompts to process: {num_prompts:,}")
print(f"Batch size: {BATCH_SIZE:,}")

print("Running inference …")

# Submit prompts in chunks of BATCH_SIZE. LLM.generate returns outputs in the
# same order as its input prompts, so concatenating chunks preserves alignment
# with `sentences`. `outputs` keeps every RequestOutput for later cells.
outputs = []
results: List[str] = []
for start in range(0, num_prompts, batch_size):
    batch_outputs = llm.generate(prompts[start:start + batch_size], sampling_params)
    outputs.extend(batch_outputs)
    # Keep only the first completion of each request.
    results.extend(out.outputs[0].text for out in batch_outputs)

assert len(results) == num_prompts, (
    f"expected {num_prompts} completions, got {len(results)}"
)
print("Inference complete.")

Loading prompt template …
Loading sentences …
Total sentences to annotate: 452
