In [2]:
import vllm
import argparse
import json
from tqdm import tqdm
from pathlib import Path

def generate(model, prompts, batch_size=32, **kwargs):
    """Run vLLM generation over ``prompts`` in fixed-size batches.

    Args:
        model: model name or path handed to ``vllm.LLM``.
        prompts: sliceable sequence of records; each record's ``"prompt"``
            value is the text submitted to the model.
        batch_size: number of prompts submitted per ``generate`` call.
        **kwargs: forwarded verbatim to ``vllm.SamplingParams``
            (e.g. ``temperature``, ``top_p``).

    Returns:
        A flat list of the per-prompt results returned by the engine,
        concatenated in batch order.
    """
    sampling = vllm.SamplingParams(**kwargs)
    engine = vllm.LLM(model=model)

    batches = [
        prompts[start:start + batch_size]
        for start in range(0, len(prompts), batch_size)
    ]

    results = []
    for chunk in tqdm(batches, desc="batch", total=len(batches)):
        texts = [record["prompt"] for record in chunk]
        results.extend(engine.generate(texts, sampling))
    return results

def read_prompts(prompt_file: Path):
    """Load prompt records from a JSON Lines file.

    Each non-blank line is decoded with ``json.loads``. Blank or
    whitespace-only lines are skipped rather than raising ``JSONDecodeError``;
    this includes the trailing empty line many tools emit.

    Args:
        prompt_file: path (``str`` or ``Path``) of the file to read.

    Returns:
        A list of the decoded values, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. A
            pretty-printed (multi-line) JSON document is therefore rejected.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(prompt_file, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def produce_output(output_file: Path, outputs):
    """Write ``outputs`` to ``output_file`` as one JSON document.

    The file is opened (created or truncated) before serialization, so a
    ``TypeError`` from a non-JSON-serializable ``outputs`` leaves it empty.
    """
    with open(output_file, mode="w") as sink:
        serialized = json.dumps(outputs)
        sink.write(serialized)

    

tensor([[1.1501, 0.0680, 1.2978],
        [1.6387, 1.3281, 1.3818],
        [1.7905, 0.1017, 1.6345],
        [0.5013, 0.3602, 1.7003],
        [0.8587, 0.9017, 0.8180]], device='cuda:0')


In [None]:
# Inputs: read_prompts parses this file as JSON Lines (one object with a "prompt"
# key per line) despite the .json extension -- TODO(review): confirm the format.
prompt_file = "./nate-dataset/sp14/pairs.json"
model_dir = "/work/arjunguha-research-group/arjun/models/starcoderbase-1b"
# The original cell wrote to `args.output_file`, but no `args` exists in a notebook.
# TODO(review): confirm the intended destination path.
output_file = "./completions.json"

prompts = read_prompts(prompt_file)
outputs = generate(model_dir, prompts, 32, temperature=0.8, top_p=0.9)

# vLLM RequestOutput objects are not JSON-serializable. Pair each result with its
# input record (generate concatenates batches in prompt order) and keep only the
# generated texts, so any extra fields on the input records are preserved.
assert len(outputs) == len(prompts), "expected one generation result per prompt"
records = [
    {**prompt, "completions": [completion.text for completion in result.outputs]}
    for prompt, result in zip(prompts, outputs)
]
produce_output(output_file, records)