In [1]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch
# Installation guide for cu11.8 and cu12.1 https://docs.vllm.ai/en/latest/getting_started/installation.html
from vllm import LLM
from vllm import LLM, SamplingParams
from llmformat.llminterface import build_vllm_logits_processor

torch.cuda.manual_seed(42)
torch.manual_seed(42)

def load_model(model_dir, tp_size=1):
    llm = LLM(model=model_dir, tensor_parallel_size=tp_size)
    return llm

default_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

def get_prompt(message: str):
    return f'<s>[INST] <<SYS>>\n{default_prompt}\n<</SYS>>\n\n{message} [/INST]'

In [2]:
def main(
    model,
    max_new_tokens=100,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    while True:
        if user_prompt is None:
            user_prompt = input("Enter your prompt: ")
        
        print(f"User prompt:\n{user_prompt}")            
        user_prompt = get_prompt(user_prompt)
        

        print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
        sampling_param = SamplingParams(top_p=top_p, 
                                        temperature=temperature, 
                                        max_tokens=max_new_tokens,
                                       )
        sampling_param.logits_processors=[
                                            #build_vllm_logits_processor(model, "/root/llmformat/llmformat/json_min.bnf")
                                        ]

        outputs = model.generate(user_prompt, sampling_params=sampling_param)
   
        print(f"model output:\n {outputs[0].outputs[0].text}")
        user_prompt = input("Enter next prompt (press Enter to exit): ")
        if not user_prompt:
            break

def run_script(
    model_dir,
    tp_size=1,
    max_new_tokens=300,
    user_prompt=None,
    top_p=0.9,
    temperature=0.8
):
    model = load_model(model_dir, tp_size)
    main(model, max_new_tokens, user_prompt, top_p, temperature)


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel
run_script("meta-llama/Llama-2-7b-hf", user_prompt="Represents a=10 b=20 c=30 in json's format.")

INFO 12-25 05:40:56 llm_engine.py:73] Initializing an LLM engine with config: model='meta-llama/Llama-2-7b-hf', tokenizer='meta-llama/Llama-2-7b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, enforce_eager=False, seed=0)
INFO 12-25 05:41:01 llm_engine.py:223] # GPU blocks: 726, # CPU blocks: 512
INFO 12-25 05:41:02 model_runner.py:394] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 12-25 05:41:09 model_runner.py:437] Graph capturing finished in 7 secs.
User prompt:
Represents a=10 b=20 c=30 in json's format.
sampling params: top_p 0.9 and temperature 0.8 for this inference request


Processed prompts: 100%|██████████| 1/1 [00:09<00:00,  9.21s/it]


model output:
 

[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Represents a=10 b=20 c=30 in json's format. [/INST]

[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain wh