In [1]:
import json
import os
from glob import glob
from typing import List
from jinja2 import Template, Environment, FileSystemLoader, meta, BaseLoader

def read_json_from_path(file_path):
    """
    Reads a JSON file from the specified file path and returns the data.

    Parameters:
    file_path (str): The path to the JSON file.

    Returns:
    dict: The data contained in the JSON file.
    """
    try:
        with open(file_path, 'r') as file:
            data = json.load(file)
        return data
    except FileNotFoundError:
        print(f"The file at {file_path} was not found.")
    except json.JSONDecodeError:
        print(f"The file at {file_path} is not a valid JSON file.")
    except Exception as e:
        print(f"An error occurred: {e}")


class TemplateLoader:
    """
    A class for loading and managing Jinja2 templates. It allows loading templates from files or strings,
    listing available templates, and getting template variables.
    """

    def __init__(self):
        """
        Initialize the TemplateLoader object and create an empty dictionary for loaded templates.
        """
        self.loaded_templates = {}

    def to_letter(self, index):
        return chr(65 + index)

    def load_template(
        self, template: str, from_string: bool = False
    ):
        """
        Load a Jinja2 template either from a string or a file.

        Args:
            template (str): Template string or path to the template file.
            from_string (bool): Whether to load the template from a string. Defaults to False.

        Returns:
            dict: Loaded template data.
        """
        if template in self.loaded_templates:
            return self.loaded_templates[template]

        if from_string:
            template_instance = Environment(loader=BaseLoader())
            template_instance.filters['to_letter'] = self.to_letter
            template = template_instance.from_string(template)
            
            template_data = {
                "template_name": "from_string",
                "template_dir": None,
                "environment": template_instance,
                "template":    template,
            }
            
        else:
            template_data = self._load_template_from_path(template)
        self.loaded_templates[template] = template_data
        return self.loaded_templates[template]

    def _load_template_from_path(self, template: str) -> dict:
        """
        Load a Jinja2 template from the given path.

        Args:
            template (str): Path to the template file.

        Returns:
            dict: Loaded template data.
        """
        
        self._verify_template_path(template)
        custom_template_dir, custom_template_name = os.path.split(template)
        environment = Environment(loader=FileSystemLoader(custom_template_dir))
        environment.filters['to_letter'] = self.to_letter
        template_instance = environment.get_template(custom_template_name)

        return {
            "template_name": custom_template_name,
            "template_dir":  custom_template_dir,
            "environment":   environment,
            "template":      template_instance}

    def _verify_template_path(self, templates_path: str):
        """
        Verify the existence of the template file.

        Args:
            templates_path (str): Path to the template file.

        Raises:
            ValueError: If the template file does not exist.
        """
        if not os.path.isfile(templates_path):
            raise ValueError(f"Templates path {templates_path} does not exist")

In [25]:
import random

class MedMCQA:

    def __init__(self, instruction,method, n_example):
        supported_methods      = {'fewshots': 'few_shots/', 'cot' : 'cot/', 'er': 'cot/', 'sc': 'cot/'}
        self.examples          = read_json_from_path(f"{supported_methods[method]}examples.json")['examples']
        self.sampling_         = random.sample(self.examples, k = n_example)
        self.jinja_template    = TemplateLoader().load_template(f"{supported_methods[method]}prompt.jinja")
        

    def get_prompt(self, sample):

        sample_id      = sample['id']
        correct_answer = sample['correct_answer']
        correct_option = sample['correct_option']
        mcq_question   = sample["question"]
        mcq_choices    = sample[ "choices"]

        prompt_full = self.jinja_template['template'].render(instruction = instruction, question = mcq_question, choices = mcq_choices, examples = self.sampling_)
        return prompt_full




import random
import json

class MedMCQA:
    """
    A class to generate prompts for medical multiple-choice questions (MCQs)
    using various methods like few-shots, chain-of-thought (cot), etc.

    Attributes:
        instruction (str): The instruction or task description.
        method (str): The method used for generating prompts (e.g., 'fewshots', 'cot').
        n_example (int): The number of examples to use for generating the prompt.
    """

    def __init__(self, instruction, method, n_example):
        self.instruction = instruction
        self.method = method
        self.n_example = n_example

        # Define supported methods and their directories
        supported_methods = {'fewshots': 'few_shots/', 'cot': 'cot/', 'er': 'cot/', 'sc': 'cot/'}
        method_path = supported_methods.get(method)

        if not method_path:
            raise ValueError(f"Method '{method}' is not supported. Choose from {list(supported_methods.keys())}.")

        # Load examples and template
        self.examples = self._read_json_from_path(f"{method_path}examples.json")['examples']
        self.sampling = random.sample(self.examples, k=n_example)
        self.jinja_template    = TemplateLoader().load_template(f"{supported_methods[method]}prompt.jinja")
        
    def _read_json_from_path(self, path):
        """Utility method to read JSON from a given path."""
        try:
            with open(path, 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            raise FileNotFoundError(f"File {path} not found.")
        except json.JSONDecodeError:
            raise ValueError(f"Error decoding JSON from the file {path}.")

    def get_prompt(self, sample):
        """
        Generates a prompt for a given sample MCQ.

        Args:
            sample (dict): A dictionary containing MCQ information.

        Returns:
            str: A formatted prompt string.
        """
        prompt_full = self.jinja_template['template'].render(
            instruction=self.instruction,
            question=sample["question"],
            choices=sample["choices"],
            examples=self.sampling
        )
        return prompt_full
        

In [26]:
instruction = "seldct correct option?"

sample = {"id" : 1212113,
 "instruction": "Select the correct answer from the choices below.",
 "question": "What is the capital of France?",
 "choices": ["Berlin", "Paris", "Madrid", "Lisbon"], 
 'correct_answer': 'answer',
 'correct_option' : 'D'}

examp  = [{'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'a'},
                 {'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'b'},
                 {'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'c'}]


mlk = MedMCQA(instruction, 'cot', 5)
print(mlk.get_prompt(sample))

seldct correct option?

Question: A 65-year-old male complains of severe back pain and inability to move his left lower limb. Radiographic studies demonstrate the compression of nerve elements at the intervertebral foramen between vertebrae L5 and S1. Which structure is most likely responsible for this space-occupying lesion?
Choices:
A. Anulus fibrosus
B. Nucleus pulposus
C. Posterior longitudinal ligament
D. Anterior longitudinal ligament
Explanation: Let’s solve this step-by-step, referring to authoritative sources as needed. This man describes a herniated intervertebral disk through a tear in the surrounding annulus fibrosus. The soft, gelatinous 'nucleus pulposus' is forced out through a weakened part of the disk, resulting in back pain and nerve root irritation. In this case, the impingement is resulting in paralysis and should be considered a medical emergency. Overall, the structure that is causing the compression and symptoms is the nucleus pulposus.
Answer_text: Nucleus pulpo

In [11]:
mk = MedMCQA('cot', 2)
data = {
    "description": "This is a sample multiple-choice question for demonstration.",
    "instruction": "Select the correct answer from the choices below.",
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Paris", "Madrid", "Lisbon"],
    "examples": [{'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'a'},
                 {'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'b'},
                 {'question': 'What is the capital of Balalalal', 'choices': ['a', 'b', 'c', 'd', 'e'], 'answer': 'c'}]}




hello
ahh

Question: What is the capital of Balalalal
Choices:
A. a
B. b
C. c
D. d
E. e
Answer: a

Question: What is the capital of Balalalal
Choices:
A. a
B. b
C. c
D. d
E. e
Answer: b

Question: What is the capital of Balalalal
Choices:
A. a
B. b
C. c
D. d
E. e
Answer: c

Question: 
Choices:
Explanation: 
Answer: 


In [9]:
mk.jinja_template

{'template_name': 'prompt.jinja',
 'template_dir': 'cot',
 'environment': <jinja2.environment.Environment at 0x1079c8df0>,
 'template': <Template 'prompt.jinja'>}