In [1]:
import torch
from peft import PeftModel
from transformers import LlamaTokenizer, LlamaForCausalLM

# Constants
MODEL = 'meta-llama/Llama-2-7b-chat-hf'
ADAPTER = 'rajeev-dw9/med_llama'
HF_TOKEN = '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'

# Function to perform inference
def generate_answer(model, tokenizer, prompt, max_new_tokens=512):
    with torch.no_grad():
        output_tensors = model.generate(
            input_ids=tokenizer(prompt, return_tensors="pt").input_ids.cuda(),
            max_new_tokens=max_new_tokens
        )[0]
    return tokenizer.decode(output_tensors, skip_special_tokens=True).split('### Answer')[-1]

# Load tokenizer
tokenizer = LlamaTokenizer.from_pretrained(MODEL, legacy=False, use_auth_token=HF_TOKEN)

# Load base model
base_model = LlamaForCausalLM.from_pretrained(
    MODEL,
    device_map='auto',
    load_in_8bit=True,
    torch_dtype=torch.float16,
    use_auth_token=HF_TOKEN,
)

# Load PEFT adapted model
model_A = PeftModel.from_pretrained(
    base_model, ADAPTER, torch_dtype=torch.float16, is_trainable=False
)

  from .autonotebook import tqdm as notebook_tqdm
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]


In [2]:
# Your prompt here
prompt = """
### Instruction
What is epilipsy? List genes responsible.
### Answer
"""

# Generate answer from base model
base_answer = generate_answer(base_model, tokenizer, prompt)

# Generate answer from PEFT adapted model
peft_answer = generate_answer(model_A, tokenizer, prompt)

print("Base Model Answer:", base_answer)
print("PEFT Adapted Model Answer:", peft_answer)


Base Model Answer: 
Epilepsy is a neurological disorder that affects the brain and is characterized by recurrent seizures. The genetic basis of epilepsy is complex and varied, with multiple genes and mechanisms involved. Here are some of the genes that have been implicated in epilepsy:

1. SCN1A - a gene that encodes the sodium channel protein Nav1.1, which is involved in the regulation of action potentials and is associated with familial epilepsy.
2. KCNQ2 - a gene that encodes the potassium channel protein Kv7.2, which is involved in the regulation of after-hyperpolarization and is associated with epileptic encephalopathy.
3. GRIN1 - a gene that encodes the glutamate receptor ion channel protein GRIN1, which is involved in the regulation of synaptic transmission and is associated with familial epilepsy.
4. CACNA1A - a gene that encodes the calcium channel protein CACNA1A, which is involved in the regulation of action potentials and is associated with familial epilepsy.
5. DEPDC1 - a 