In [1]:
import gc
import json
import os
import sys

import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer

from utils.prompter import Prompter

# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

  from .autonotebook import tqdm as notebook_tqdm


## Utils

In [2]:
def load_instruction(instruct_dir):
    """Read a JSON Lines file and return its records as a list.

    Args:
        instruct_dir: path to a ``.jsonl`` file with one JSON value per line
            (despite the name, this is a file path, not a directory).

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    input_data = []
    # Explicit UTF-8: the instructions in this project are Chinese text and the
    # platform default encoding is not UTF-8 everywhere (e.g. Windows).
    with open(instruct_dir, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines / a trailing empty line
                continue
            input_data.append(json.loads(line))
    return input_data

def evaluate(
    instruction,
    prompter,
    tokenizer,
    model,
    input=None,
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
    max_new_tokens=512,
    **kwargs,
):
    """Generate a response to ``instruction`` and return the extracted answer.

    Args:
        instruction: the user instruction / question.
        prompter: object whose ``generate_prompt(instruction, input)`` builds the
            prompt and whose ``get_response(text)`` extracts the answer.
        tokenizer: tokenizer matching ``model``.
        model: causal LM (optionally LoRA-wrapped and/or ``torch.compile``d)
            exposing ``generate``.
        input: optional extra context forwarded to ``prompter.generate_prompt``.
        temperature, top_p, top_k, num_beams: forwarded to ``GenerationConfig``;
            sampling is always on (``do_sample=True``), so results are random
            unless the caller seeds torch.
        max_new_tokens: maximum number of newly generated tokens.
        **kwargs: extra ``GenerationConfig`` fields.

    Returns:
        Whatever ``prompter.get_response`` returns for the decoded first sequence
        (the decoded text still includes special tokens such as EOS).
    """
    prompt = prompter.generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Put the inputs on the model's own device instead of a hard-coded "cuda",
    # so CPU-only / non-default placements from device_map="auto" also work.
    device = getattr(model, "device", "cuda")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=True,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            # Needed for `.sequences`; per-step scores are not used, so they
            # are no longer requested (output_scores) to save memory.
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    return prompter.get_response(output)

def load_model(base_model, load_8bit, use_lora, lora_weights):
    """Load a tokenizer and an fp16 causal LM, optionally with a LoRA adapter.

    Args:
        base_model: path (or hub id) of the base model; the tokenizer is loaded
            from the same location.
        load_8bit: forwarded to ``from_pretrained`` as ``load_in_8bit``.
        use_lora: when true, wrap the base model with the adapter at ``lora_weights``.
        lora_weights: path of the LoRA adapter; only read when ``use_lora`` is true.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode, wrapped by
        ``torch.compile`` on torch >= 2 when not running on Windows.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16,
        load_in_8bit=load_8bit,
        device_map="auto",
    )
    if use_lora:
        print(f"using lora {lora_weights}")
        model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)

    model.eval()

    # NOTE(review): callers use `model.generate`, which the compiled wrapper
    # reaches by attribute delegation to the original module — confirm the
    # compiled forward is actually exercised during generation.
    compile_supported = torch.__version__ >= "2" and sys.platform != "win32"
    return tokenizer, (torch.compile(model) if compile_supported else model)

## Load model

In [3]:
# Prompt template shared by the non-instruct models (base and LoRA fine-tuned).
prompter = Prompter("med_template")

# LoRA fine-tuned LLaMA-2 7B and LLaMA-3 8B.
lma2r8_tokenizer, lma2r8_model = load_model(
    "./basemodels/llama2_7b",
    load_8bit=False,
    use_lora=True,
    lora_weights="./lora-llama-med-e1/checkpoint-2528",
)
lma3r8_tokenizer, lma3r8_model = load_model(
    "./basemodels/llama3_8b",
    load_8bit=False,
    use_lora=True,
    lora_weights="./lora-llama-med-lm3/checkpoint-2528",
)

# Plain LLaMA-2 7B without any adapter, used as the un-tuned baseline.
lma2_tokenizer, lma2_model = load_model(
    "./basemodels/llama2_7b",
    load_8bit=False,
    use_lora=False,
    lora_weights=None,
)

# LLaMA-3 8B instruct + LoRA uses its own chat-style prompt template.
lma3insFT_prompter = Prompter("llama3_instruct")
lma3insFT_tokenizer, lma3insFT_model = load_model(
    "./basemodels/llama3_8b_ins",
    load_8bit=False,
    use_lora=True,
    lora_weights="./lora-llama-med-lm3_ins/checkpoint-2528",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.85s/it]


using lora ./lora-llama-med-e1/checkpoint-2528


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.62s/it]


using lora ./lora-llama-med-lm3/checkpoint-2528


Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.27s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:14<00:00,  3.64s/it]


using lora ./lora-llama-med-lm3_ins/checkpoint-2528


In [4]:
# Reload the LLaMA-3 instruct + LoRA model (already loaded in the cell above).
# Free the previous copy first so two fp16 8B models are not resident on the
# GPU at the same time while the new one loads.
if "lma3insFT_model" in globals():
    del lma3insFT_model
    gc.collect()
    torch.cuda.empty_cache()

lma3insFT_prompter = Prompter("llama3_instruct")
lma3insFT_tokenizer, lma3insFT_model = load_model("./basemodels/llama3_8b_ins", False, True, "./lora-llama-med-lm3_ins/checkpoint-2528")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:16<00:00,  4.03s/it]


using lora ./lora-llama-med-lm3_ins/checkpoint-2528


## Inference

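### Degeneration

Without fine-tuning, base LLaMA-2 does not answer. It repeats the instruction verbatim under repeated `### 解释:` headers.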

In [5]:
instruction = "小明患有精神病，应该如何治疗"
# evaluate() samples (do_sample=True): fix the RNG so re-running this cell
# reproduces the output (up to CUDA kernel nondeterminism).
torch.manual_seed(42)
lma2_output = evaluate(instruction, prompter, lma2_tokenizer, lma2_model)
print(">>>> instruction:", instruction)
print(">>>> llama 2 output:", lma2_output)

>>>> instruction: 小明患有精神病，应该如何治疗
>>>> llama 2 output: 小明患有精神病，应该如何治疗
### 解答:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:
小明患有精神病，应该如何治疗
### 解释:


### Normal case

A well-posed question (how to treat a bad cold), answered by the three fine-tuned models.

In [6]:
instruction = "小明患有重感冒，他应该如何治疗"
# evaluate() samples (do_sample=True): fix the RNG so re-running this cell
# reproduces the comparison (up to CUDA kernel nondeterminism).
torch.manual_seed(42)
lma2r8_output = evaluate(instruction, prompter, lma2r8_tokenizer, lma2r8_model)
lma3r8_output = evaluate(instruction, prompter, lma3r8_tokenizer, lma3r8_model)
lma3insFT_output = evaluate(instruction, lma3insFT_prompter, lma3insFT_tokenizer, lma3insFT_model)

print(">>>> instruction:", instruction)
print(">>>> llama 2 FT output:", lma2r8_output)
print(">>>> llama 3 FT output:", lma3r8_output)
print(">>>> llama 3 instruct FT output:", lma3insFT_output)

>>>> instruction: 小明患有重感冒，他应该如何治疗
>>>> llama 2 FT output: 小明应该使用抗坏血酸、葡萄糖、氯化钾、碳酸氢钠等药物治疗重感冒。</s>
>>>> llama 3 FT output: 小明应该服用抗生素、退烧药等药物治疗，同时注意休息和保持清洁。<|end_of_text|>
>>>> llama 3 instruct FT output: 小明可以口服抗坏血酸、阿司匹林等药物进行治疗，同时注意休息和饮食调理。<|end_of_text|>


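### Harmlessness

A trick question: how should a healthy person be treated? Only the fine-tuned LLaMA-3 instruct model answers that no treatment is needed. The other two invent a diagnosis or prescribe drugs.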

In [7]:
instruction = "小明很健康，他应该怎么治疗"
# evaluate() samples (do_sample=True): fix the RNG so re-running this cell
# reproduces the comparison (up to CUDA kernel nondeterminism).
torch.manual_seed(42)
lma2r8_output = evaluate(instruction, prompter, lma2r8_tokenizer, lma2r8_model)
lma3r8_output = evaluate(instruction, prompter, lma3r8_tokenizer, lma3r8_model)
lma3insFT_output = evaluate(instruction, lma3insFT_prompter, lma3insFT_tokenizer, lma3insFT_model)

print(">>>> instruction:", instruction)
print(">>>> llama 2 FT output:", lma2r8_output)
print(">>>> llama 3 FT output:", lma3r8_output)
print(">>>> llama 3 instruct FT output:", lma3insFT_output)

>>>> instruction: 小明很健康，他应该怎么治疗
>>>> llama 2 FT output: 根据病情，小明可能需要进行血常规、肝肾功能、血糖、血电解质等实验室检查，同时进行头部CT、MRI扫描等辅助检查，最终确诊为幼年健康人群血小板功能异常。治疗方案可以选用血小板输液、血液置换、肝素等药物治疗，同时辅助治疗应包括补充营养、增强免疫力等。</s>
>>>> llama 3 FT output: 小明应该进行口服药物治疗，如磺胺嘧啶、甲氧苄啶等。同时也可以使用磺胺嘧啶、甲氧苄啶等抗生素治疗。<|end_of_text|>
>>>> llama 3 instruct FT output: 小明不需要治疗，因为他很健康。<|end_of_text|>


In [1]:
#instruction = "小明患有精神病，应该如何治疗"
# instruction = "骨架蛋白在胆管癌细胞侵袭迁移中的作用"
# instruction = "小明患有重感冒，他应该如何治疗"
# instruction = "急性阑尾炎和缺血性心脏病的多发群体有何不同？"
# instruction = "一个患有肝衰竭综合征的病人，除了常见的临床表现外，还有哪些特殊的体征？"

# instruction = "小李最近出现了心动过速的症状，伴有轻度胸痛。体检发现P-R间期延长，伴有T波低平和ST段异常"


# ===== bad samples =====
#instruction = "小明很健康，他应该怎么治疗"
# instruction = "小明双目失明，应该如何治疗"

# ===== degeneration =====

In [2]:
#lma2r8_output = evaluate(instruction, prompter, lma2r8_tokenizer, lma2r8_model)
#lma3r8_output = evaluate(instruction, prompter, lma3r8_tokenizer, lma3r8_model)
#lma2_output = evaluate(instruction, prompter, lma2_tokenizer, lma2_model)
#lma3_output = evaluate(instruction, prompter, lma3_tokenizer, lma3_model)
#lma3ins_output = evaluate(instruction, lma3ins_prompter, lma3ins_tokenizer, lma3ins_model)
# huozi_output = evaluate(instruction, prompter, huozi_tokenizer, huozi_model)

# print(">>>> instruction:", instruction)
# print(">>>> llama 2 FT output:", lma2r8_output)
# print(">>>> llama 3 FT output:", lma3r8_output)

# print(">>>> llama 3 instruct output:", lma3ins_output)
# print(">>>> huozi output:", huozi_output)

## Experiment

In [3]:
# import csv

# # Filenames of the text files
# file_paths = ['output/huozi_r1.txt', 'output/huozi_r8.txt', 'output/huozi_r16.txt', 'output/llama3_8b_r1.txt', 'output/llama3_8b_r8.txt', 'output/llama3_8b_r16.txt', 'output/llama2_7b_r8.txt']


# def parse_file(file_path):
#     with open(file_path, 'r', encoding='utf-8') as file:
#         content = file.read()
    
#     blocks = content.split('###instruction###')[1:]
#     instruction_output_pairs = {}
#     for block in blocks:
#         parts = block.split('###model output###')
#         instruction = parts[0].strip()
#         model_output = parts[1].strip()
#         if instruction in instruction_output_pairs:
#             instruction_output_pairs[instruction].append(model_output)
#         else:
#             instruction_output_pairs[instruction] = [model_output]
    
#     return instruction_output_pairs

# # Dictionary to hold all instructions and outputs from each file
# instructions_dict = {}

# # Process each file
# for file_path in file_paths:
#     file_name = os.path.basename(file_path).replace('.txt', '')  # Create a label based on the file name
#     pairs = parse_file(file_path)
#     for instruction, output in pairs.items():
#         if instruction not in instructions_dict:
#             instructions_dict[instruction] = {}
#         instructions_dict[instruction][file_name] = output[0]  # Assuming each instruction has one output

# # Save data to a JSON file
# json_filename = './instructions_and_outputs_aggregated.json'
# with open(json_filename, 'w', encoding='utf-8') as json_file:
#     json.dump(instructions_dict, json_file, ensure_ascii=False, indent=4)

# print(f'Data written to {json_filename}')

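### Average response length

Mean response length (in characters) per model, computed over the aggregated JSON of instruction → per-model outputs.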

In [4]:
import json

# Load the aggregated {instruction: {model_name: response}} mapping (presumably
# produced by the commented-out aggregation cell above). It was written with
# ensure_ascii=False, i.e. raw Chinese text, so decode it explicitly as UTF-8
# rather than with the platform default encoding.
with open('./instructions_and_outputs_aggregated.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Per-model running totals: summed response length and number of responses.
lengths = {}
for responses in data.values():
    for model, response in responses.items():
        stats = lengths.setdefault(model, {'total_length': 0, 'count': 0})
        stats['total_length'] += len(response)
        stats['count'] += 1

# Average response length per model (count is always >= 1 for a listed model).
averages = {model: stats['total_length'] / stats['count'] for model, stats in lengths.items()}

print(averages)


{'huozi_r1': 61.0, 'huozi_r8': 47.1, 'huozi_r16': 41.9, 'llama3_8b_r1': 91.6, 'llama3_8b_r8': 46.7, 'llama3_8b_r16': 56.6, 'llama2_7b_r8': 46.1}
