In [None]:
# load input text sentences
from pandas import read_json, Series
CTD_RE_V1 = read_json('../label_studio/export/CTD_RE_v1.json').set_index('id')
# Each row's `data` dict carries the raw sentence under 'text'; keep the export's ids as the index.
sentences = Series(data=[task_data['text'] for task_data in CTD_RE_V1['data']], index=CTD_RE_V1.index)

# load test sample ids: the first CSV row is a comma-separated list of integer task ids
from csv import reader
with open("test_output_2000/sampled_test_ids.csv", "r", newline="") as file:
    # Only the first row is consumed; `with` closes the file, no explicit close() needed.
    sampled_test_ids = [int(task_id) for task_id in next(reader(file, delimiter=","))]

# load base model
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Local Meta-Llama-3-8B checkpoint; set the BASE_MODEL_ID environment variable to use another path.
# (Earlier experiments used Llama 2 7B from /mnt/sdc/llama_hf/llama-2-7b-hf.)
base_model_id = os.environ.get("BASE_MODEL_ID", "/home/qyfeng/llama_hf/Meta-Llama-3-8B")
# 4-bit NF4 weights with double quantization; computation dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # must be the base model the LoRA adapters were fine-tuned from (Llama 3 8B)
    quantization_config=bnb_config,  # same quantization config as fine-tuning
    device_map="auto",
    # NOTE(review): allows executing Python code shipped with the checkpoint if its config asks
    # for it; Llama is supported natively by transformers, so this is likely unnecessary — confirm.
    trust_remote_code=True,
)

eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
)

In [None]:
from peft import PeftModel

# All evaluated checkpoints come from this one fine-tuning run.
FINETUNE_RUN_DIR = "finetuned_models/llama3-8b-CTD_RE_V1-finetune-r_8_la_32-prompt_v3-random_2000-claude_langchain"
CHECKPOINT_STEPS = (260, 450, 670)

# PeftModel.from_pretrained injects LoRA layers into `base_model` IN PLACE. Calling it once per
# checkpoint on the same base model re-creates the same "default" adapter each time, so all the
# wrappers share those layers and silently run whichever checkpoint was loaded last.
# Instead, attach every checkpoint to ONE PeftModel under its own adapter name.
ft_model = PeftModel.from_pretrained(
    base_model,
    f"{FINETUNE_RUN_DIR}/checkpoint-{CHECKPOINT_STEPS[0]}",
    adapter_name=f"checkpoint_{CHECKPOINT_STEPS[0]}",
)
for step in CHECKPOINT_STEPS[1:]:
    ft_model.load_adapter(f"{FINETUNE_RUN_DIR}/checkpoint-{step}", adapter_name=f"checkpoint_{step}")


class _CheckpointView:
    """Stand-in for a single checkpoint of the shared multi-adapter `ft_model`.

    Every attribute lookup first activates this view's adapter and then delegates to the
    PeftModel, so `view.eval()`, `view.generate(...)`, `view.device`, ... run with the right
    LoRA weights while only one copy of the base weights is in memory. All views share
    `ft_model`'s active-adapter state, so do not use two views concurrently.
    """

    def __init__(self, peft_model, adapter_name):
        self._peft_model = peft_model
        self._adapter_name = adapter_name

    def __getattr__(self, name):
        # Only called for attributes not found on the view itself, i.e. everything but the two above.
        self._peft_model.set_adapter(self._adapter_name)
        return getattr(self._peft_model, name)


ft_models = {step: _CheckpointView(ft_model, f"checkpoint_{step}") for step in CHECKPOINT_STEPS}
ft_model_260 = ft_models[260]
ft_model_450 = ft_models[450]
ft_model_670 = ft_models[670]

In [None]:
# Read-only sanity check. Do NOT call PeftModel.from_pretrained on `base_model` a second time:
# it is already wrapped, so that would re-inject LoRA layers and overwrite the adapter weights.
# Shows the adapter used by the model passed to inference, and every adapter attached to it.
ft_model_260.active_adapter, list(ft_model_260.peft_config)

In [None]:
# load helper functions
def format_relation(relations):
    """Render relation dicts in the prompt's triple notation.

    Each relation becomes
    ``((subject_name, subject_type), relation_phrase, (object_name, object_type)); ``
    and the pieces are concatenated in order, so a non-empty result always ends with
    ``"; "`` and an empty ``relations`` gives ``""``.
    """
    def _entity(entity):
        # "(name, type)" for one subject/object entity dict.
        return "(" + entity['entity_name'] + ", " + entity['entity_type'] + ")"

    return "".join(
        "(" + _entity(rel['subject_entity']) + ", "
        + rel['relation_phrase'] + ", "
        + _entity(rel['object_entity']) + ")" + "; "
        for rel in relations
    )
def formatting_func_v3(data_point):
    """Build the v3 relation-extraction prompt for one example.

    ``data_point`` needs an ``"input_sentence"`` and a ``"relations"`` list, which is
    rendered with ``format_relation`` (``generate_text`` passes ``[]`` at inference time).

    The exact text is significant, including the 4-space indent on every line after
    the first, the blank line before "### Input sentence:" that holds 4 spaces, and the
    trailing 4 spaces. It presumably must match the prompt used during fine-tuning
    (TODO confirm), so keep it byte-for-byte.
    """
    pad = "    "
    sentence = data_point["input_sentence"]
    relations_text = format_relation(data_point["relations"])
    prompt_lines = [
        "Given an input text sentence, extract fact relations.",
        pad + "Each fact relation describes a scientific observation or hypothesis and is in the format of a triple connecting two entities via a relation phrase: (subject_entity, relation_phrase, object_entity).",
        pad + "Each subject_entity or object_entity is a chemical compound or gene/protein and is in the format of a 2-tuple: (entity_name, entity_type). Depending on the type of the entity, the entity_type must be one of ['Chemical', 'Gene/Protein'].",
        pad + "The relation_phrase must be one of the following: ['increases', 'decreases', 'affects', 'binds'].",
        pad + "The extracted relations should be a semicolon-separated list of relations in the format of triples: ((entity_name, entity_type), relation_phrase, (entity_name, entity_type)).",
        pad,
        pad + "### Input sentence:",
        f"{pad}{sentence}",
        "",
        pad + "### Extracted relations:",
        f"{pad}{relations_text}",
        pad,
    ]
    return "\n".join(prompt_lines)

def generate_text(input_sentence, model, max_new_tokens=512, **generate_kwargs):
    """Generate relation-extraction output for one sentence with ``model``.

    Args:
        input_sentence: sentence inserted into the v3 prompt (with no gold relations).
        model: causal LM used for generation (e.g. the fine-tuned PEFT model).
        max_new_tokens: generation budget; defaults to the previously hard-coded 512.
        **generate_kwargs: extra keyword arguments forwarded to ``model.generate``
            (e.g. ``do_sample=False`` for deterministic decoding).

    Returns:
        The first generated sequence decoded with special tokens removed. The prompt
        tokens are not sliced off before decoding, so the text is expected to start
        with the prompt itself.
    """
    eval_example = {'input_sentence': input_sentence, 'relations': []}
    eval_prompt = formatting_func_v3(eval_example)
    # Send inputs to wherever the model reports its device instead of a hard-coded "cuda",
    # so this also works on CPU-only machines and with device_map="auto" placements.
    model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to(model.device)

    model.eval()
    with torch.no_grad():
        output_ids = model.generate(**model_input, max_new_tokens=max_new_tokens, **generate_kwargs)
    return eval_tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
import os
from tqdm import tqdm

# One <task_id>.txt per sampled test sentence, in a directory named after the checkpoint.
output_dir = ("test_output_2000/claude/"
              "llama3-8b-CTD_RE_V1-finetune-r_8_la_32-prompt_v3-random_2000-claude_langchain-checkpoint-260")
os.makedirs(output_dir, exist_ok=True)  # open(..., "w") raises if the directory is missing

for task_id in tqdm(sampled_test_ids):
    # Generate before opening the file so a failed generation does not leave an empty output file.
    generated = generate_text(sentences[task_id], ft_model_260)
    # Explicit UTF-8: the sentences may contain non-ASCII characters.
    with open(os.path.join(output_dir, f"{task_id}.txt"), "w", encoding="utf-8") as outfile:
        outfile.write(generated)