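### Smoke test for `HierRewardDataset`

Builds a one-row toy dataset, writes the instruction templates to a JSON file, constructs `HierRewardDataset`, and prints the prompt/chosen text it yields for each example.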

In [1]:
import json

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from hierdpo.datasets import HierRewardDataset

# A single toy example: each key is a dataset column and each list holds
# one value per row (so this builds a one-row dataset).
data = dict(
    conversation=[
        "The patient reports feeling a bit unwell today, experiencing headache and fever. I asked if they had a cough.",
    ],
    soap_note=[
        "The patient shows typical symptoms of a cold, suggesting rest and fluid intake.",
    ],
    homework=[
        "Continue to monitor temperature, stay hydrated, and rest.",
    ],
)

# Prompt templates keyed by dataset column; the "{}" placeholder is filled
# with that column's text (e.g. the "homework" template wraps the homework text).
instruction_templates = {
    "conversation": "Based on the following conversation, create a detailed medical SOAP note: {}",
    "soap_note": "Based on the following SOAP note, provide medical diagnosis suggestions: {}",
    "homework": "Based on the patient's diagnosis, generate appropriate SOAP note: {}",
}

# HierRewardDataset is given the path of this JSON file (see its constructor call),
# so persist the templates as UTF-8 JSON; ensure_ascii=False keeps non-ASCII text readable.
instruction_template_file = "instruction_templates_en.json"
with open(instruction_template_file, "w", encoding="utf-8") as template_fh:
    json.dump(instruction_templates, template_fh, ensure_ascii=False, indent=4)

# Hub id of the tokenizer; its chat template formats the prompts printed further down.
TOKENIZER_NAME = "OpenRLHF/Llama-3-8b-sft-mixture"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Wrap the in-memory toy columns as a (single-row) Hugging Face Dataset.
dataset = Dataset.from_dict(data)

class Strategy:
    """Lightweight mock of a training strategy for this notebook.

    Only carries the argument namespace, exposed as ``self.args``; the
    ``HierRewardDataset`` built later receives this object as ``strategy``.
    """

    def __init__(self, args):
        # Store the reference as-is (no copy), so later edits to ``args`` are visible here.
        self.args = args

class Args:
    """Minimal stand-in for the argument namespace that ``Strategy`` carries.

    The defaults match the column names of the toy ``data`` dict
    (``conversation``, ``soap_note``, ``homework``). Override them to test
    against a dataset whose columns are named differently.

    Args:
        conv_key: Column holding the conversation text.
        soap_key: Column holding the reference SOAP note.
        hw_key: Column holding the homework text.
        apply_chat_template: Flag read by the dataset; presumably toggles
            wrapping prompts in the tokenizer's chat template (TODO confirm
            against HierRewardDataset).
    """

    def __init__(
        self,
        conv_key="conversation",
        soap_key="soap_note",
        hw_key="homework",
        apply_chat_template=True,
    ):
        self.conv_key = conv_key
        self.soap_key = soap_key
        self.hw_key = hw_key
        self.apply_chat_template = apply_chat_template

# Wire the mock arguments into the mock strategy handed to HierRewardDataset below.
args = Args()
strategy = Strategy(args=args)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show the raw Jinja chat template; the prompts printed later follow this format.
chat_template_source = tokenizer.chat_template
chat_template_source

"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|start_header_id|>' + message['role'] + '<|end_header_id|>' + '\n' + message['content'] + '<|eot_id|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}{% endif %}"

In [3]:
# Build the dataset under test from the toy Dataset, the tokenizer, the mock
# strategy, and the path of the JSON instruction-template file written in cell 1.
# (HierRewardDataset is imported in the notebook's top import cell.)
dataset_obj = HierRewardDataset(
    dataset=dataset,
    tokenizer=tokenizer,
    strategy=strategy,
    instruction_template_file=instruction_template_file,
)

num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.
Map: 100%|██████████| 1/1 [00:00<00:00, 98.07 examples/s]


In [5]:
# Dump every field of each example. A separator line follows the SOAP DE pair
# and another follows the HW pair, matching the recorded output.
for idx in range(len(dataset_obj)):
    prompt_soap_de, chosen_soap_de, prompt_soap_in, chosen_soap_in, prompt_hw, chosen_hw = dataset_obj[idx]

    for label, text in (("Prompt SOAP DE", prompt_soap_de), ("Chosen SOAP DE", chosen_soap_de)):
        print(f"{label}:\n{text}\n")
    print("=" * 50)

    for label, text in (
        ("Prompt SOAP IN", prompt_soap_in),
        ("Chosen SOAP IN", chosen_soap_in),
        ("Prompt HW", prompt_hw),
        ("Chosen HW", chosen_hw),
    ):
        print(f"{label}:\n{text}\n")
    print("=" * 50)

Prompt SOAP DE:
<|start_header_id|>user<|end_header_id|>
Based on the following conversation, create a detailed medical SOAP note: The patient reports feeling a bit unwell today, experiencing headache and fever. I asked if they had a cough.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

Chosen SOAP DE:
<|start_header_id|>user<|end_header_id|>
Based on the following conversation, create a detailed medical SOAP note: The patient reports feeling a bit unwell today, experiencing headache and fever. I asked if they had a cough.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
The patient shows typical symptoms of a cold, suggesting rest and fluid intake.<|eot_id|>

Prompt SOAP IN:
<|start_header_id|>user<|end_header_id|>
Based on the patient's diagnosis, generate appropriate SOAO note: Continue to monitor temperature, stay hydrated, and rest.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

Chosen SOAP IN:
<|start_header_id|>user<|end_header_id|>
Based on the patien