In [None]:
!pip install langchain==0.1.16 openai==1.14.2 langchain-openai==0.1.1 langchain-google-genai==1.0.3 langchain-community==0.0.34 faiss-gpu==1.7.2 sentence_transformers==2.7.0 huggingface-hub==0.23.0 transformers==4.40.2 accelerate==0.30.1 pandas==2.0.3
!CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.2.63

In [None]:
import ast
import os
import shutil
import tarfile
from datetime import datetime

import pandas as pd
import requests
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import hf_hub_download, logging as hf_logging
from langchain.agents.agent import AgentExecutor, AgentType
from langchain.agents.agent_toolkits import VectorStoreToolkit, VectorStoreInfo
from langchain.agents.agent_toolkits.openapi import planner
from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec
from langchain.agents.agent_toolkits.vectorstore.prompt import PREFIX as VECTORSTORE_AGENT_PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.chains.llm import LLMChain
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.output_parsers.json import parse_json_markdown
from langchain_community.llms import LlamaCpp
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory
from langchain_openai import ChatOpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, pipeline

In [None]:
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast

import yaml
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool, Tool

from langchain_community.agent_toolkits.openapi.planner_prompt import (
    API_CONTROLLER_TOOL_DESCRIPTION,
    API_CONTROLLER_TOOL_NAME,
    API_ORCHESTRATOR_PROMPT,
    API_PLANNER_PROMPT,
    API_PLANNER_TOOL_DESCRIPTION,
    API_PLANNER_TOOL_NAME,
    PARSING_GET_PROMPT,
    REQUESTS_GET_TOOL_DESCRIPTION,
    REQUESTS_POST_TOOL_DESCRIPTION,
)
from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain_community.llms import OpenAI
from langchain_community.tools.requests.tool import BaseRequestsTool
from langchain_community.utilities.requests import RequestsWrapper

In [None]:
# Prompt for the controller agent that executes a planned sequence of API calls.
# NOTE(review): presumably a customized copy of LangChain's planner-toolkit
# controller prompt (that constant is not imported above) — confirm before syncing upstream.
# {api_url}, {api_docs}, {tool_names} and {tool_descriptions} are bound as partial
# variables when the controller agent is built; {input} (the plan) and
# {agent_scratchpad} are filled in at run time.
API_CONTROLLER_PROMPT = """You are an agent that gets a sequence of API calls and given their documentation, should execute them and return the final response.
If you cannot complete them and run into issues, you should explain the issue. If you're unable to resolve an API call, you can retry the API call. When interacting with API objects, you should extract ids for inputs to other API calls but ids and names for outputs returned to the User.


Here is documentation on the API:
Base url: {api_url}
Endpoints:
{api_docs}


Here are tools to execute requests against the API: {tool_descriptions}


Starting below, you should follow this format:

Plan: the plan of API calls to execute
Thought: you should always think about what to do
Action: the action to take, MUST be exactly the name of one of the tools [{tool_names}]
Action Input: the input to the action
Observation: the output of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I am finished executing the plan (or, I cannot finish executing the plan without knowing some other information.)
Final Answer: the final output from executing the plan or missing information I'd need to re-plan correctly.


Begin!

Plan: {input}
Thought:
{agent_scratchpad}
"""

In [None]:
def get_athena_openapi_docs(target_vocabulary):
    """Build the OpenAPI 3.0 YAML document for the Athena ``/concepts`` search.

    ``target_vocabulary`` is interpolated into the title, summary and description.
    For ``'SNOMED'`` and ``'RxNorm'`` an extra required ``vocabulary`` query
    parameter listing the allowed vocabulary values is inserted between the
    ``standardConcept`` and ``invalidReason`` parameters; for any other value that
    slot is left as a whitespace-only line.

    Args:
        target_vocabulary: Vocabulary name to describe in the spec.

    Returns:
        The OpenAPI document as a YAML string.
    """
    # YAML for the extra query parameter of each supported vocabulary. The first
    # line is indented by the template below; the remaining lines carry their own
    # indentation so they align with the other parameter entries.
    snomed_parameter = """- name: vocabulary
          in: query
          description: Specify the vocabulary.
          required: true
          schema:
            type: string
          enum:
            - SNOMED"""
    rxnorm_parameter = """- name: vocabulary
          in: query
          description: Specify the vocabulary(ies).
          required: true
          style: form
          explode: true
          schema:
            type: array
            items:
              type: string
              enum:
                - RxNorm
                - RxNorm Extension
            minItems: 2"""

    vocabulary_parameter = ''
    for vocabulary_name, parameter_yaml in (('SNOMED', snomed_parameter), ('RxNorm', rxnorm_parameter)):
        if target_vocabulary == vocabulary_name:
            vocabulary_parameter = parameter_yaml
            break

    return f"""openapi: 3.0.3
servers:
  - url: https://athena.ohdsi.org/api/v1
info:
  title: Athena OHDSI API
  version: 1.0.0
  description: Athena OHDSI API for searching concepts in the {target_vocabulary} vocabulary
paths:
  /concepts:
    get:
      summary: Search {target_vocabulary} Concepts
      description: Endpoint for searching concepts in the {target_vocabulary} vocabulary based on a queried term
      parameters:
        - name: query
          in: query
          description: Pass a text query to search.
          required: true
          schema:
            type: string
          minLength: 1
        - name: pageSize
          in: query
          description: Specify the number of results per page.
          required: true
          schema:
            type: integer
          minimum: 15
          maximum: 15
        - name: page
          in: query
          description: Specify which page to query.
          required: true
          schema:
            type: integer
        - name: standardConcept
          in: query
          description: Specify the standard concept.
          required: true
          schema:
            type: string
          enum:
            - Standard
        {vocabulary_parameter}
        - name: invalidReason
          in: query
          description: Specify the invalid reason.
          required: true
          schema:
            type: string
          enum:
            - Valid
      responses:
        '200':
          description: Successful response
          content:
            application/json:
              schema:
                type: object
                properties:
                  number:
                    type: integer
                    description: The queried page
                    required: true
                  totalElements:
                    type: integer
                    description: Total amount of results
                    required: true
                  totalPages:
                    type: integer
                    description: Total amount of pages
                    required: true
                  content:
                    type: array
                    items:
                      type: object
                      properties:
                        name:
                          type: string
                          description: Concept name
                          required: true
                        id:
                          type: integer
                          description: Concept ID
                          required: true
                        code:
                          type: string
                          description: Concept code
                          required: true
                        className:
                          type: string
                          description: The class of the concept
                          required: true
                        domain:
                          type: string
                          description: The domain of the concept
                          required: true
                        invalidReason:
                          type: string
                          nullable: true
                          description: The reason the concept is invalid (if it is invalid)
                          required: true
                        score:
                          type: number
                          nullable: true
                          description: The score of the concept
                          required: true
                        standardConcept:
                          type: string
                          description: Standard concept
                          required: true
                        vocabulary:
                          type: string
                          description: The name of the vocabulary the concept is in
                          required: true"""

In [None]:
# Adapted from LangChain
# Maximum number of characters of a serialized API response handed back to the
# agent; default for RequestsGetToolWithParsing.response_length.
MAX_RESPONSE_LENGTH = 5000

def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
    """Build an LLMChain that runs ``prompt`` on a default ``OpenAI()`` model."""
    from langchain.chains.llm import LLMChain

    default_llm = OpenAI()
    return LLMChain(prompt=prompt, llm=default_llm)

def _get_default_llm_chain_factory(
    prompt: BasePromptTemplate,
) -> Callable[[], Any]:
    """Return a zero-argument callable that builds the default LLMChain for ``prompt``.

    Used as a pydantic ``default_factory``, so the default chain (and its OpenAI
    model) is only built when no chain is passed explicitly.
    """
    chain_factory = partial(_get_default_llm_chain, prompt)
    return chain_factory

class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
    """Requests GET tool that returns a compact ``id``/``name`` view of the response.

    Adapted from LangChain's tool of the same name. Unlike upstream, ``_run`` does
    not ask ``llm_chain`` to extract information: it keeps only the ``id`` and
    ``name`` of every item in the response's ``content`` list, serializes them as
    JSON and truncates the result to ``response_length`` characters.
    """

    # Tool name shown to the agent.
    name: str = "requests_get"
    # Tool description shown to the agent.
    description = REQUESTS_GET_TOOL_DESCRIPTION
    # Maximum number of characters of the serialized response that is returned.
    response_length: int = MAX_RESPONSE_LENGTH
    # LLMChain kept for interface compatibility with LangChain; not used by _run.
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
    )

    def _run(self, text: str) -> str:
        """Execute the GET request described by the agent's JSON input.

        Args:
            text: JSON (optionally in a markdown code fence) with a ``url`` key and
                an optional ``params`` mapping.

        Returns:
            JSON list of ``{'id', 'name'}`` dicts, truncated to ``response_length``.

        Raises:
            json.JSONDecodeError: If the response body is not valid JSON.
            KeyError: If the input has no ``url``.
        """
        from langchain.output_parsers.json import parse_json_markdown

        data = parse_json_markdown(text)
        raw_response = self.requests_wrapper.get(data["url"], params=data.get("params"))
        # A text-mode RequestsWrapper returns the body as a string; decode it so the
        # projection below works with both text and JSON wrappers.
        if isinstance(raw_response, str):
            raw_response = json.loads(raw_response)
        concepts = [
            {'id': concept.get('id'), 'name': concept.get('name')}
            for concept in (raw_response.get('content') or [])
        ]
        response = json.dumps(concepts)[: self.response_length]
        print('API response: ', response)
        return response.strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not supported."""
        raise NotImplementedError()

def _create_api_controller_agent(
    api_url: str,
    api_docs: str,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
) -> Any:
    """Build a zero-shot agent that executes a plan of API calls.

    Args:
        api_url: Base URL inserted into the controller prompt.
        api_docs: Documentation of the endpoints the plan uses.
        requests_wrapper: HTTP wrapper used by the GET tool.
        llm: Model driving both the agent and the GET tool's parsing chain.

    Returns:
        A verbose AgentExecutor whose only tool is ``RequestsGetToolWithParsing``.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain

    get_tool = RequestsGetToolWithParsing(
        requests_wrapper=requests_wrapper,
        llm_chain=LLMChain(llm=llm, prompt=PARSING_GET_PROMPT),
        allow_dangerous_requests=True,
    )
    tools: List[BaseTool] = [get_tool]
    tool_names = [tool.name for tool in tools]
    tool_descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)

    controller_prompt = PromptTemplate(
        template=API_CONTROLLER_PROMPT,
        input_variables=["input", "agent_scratchpad"],
        partial_variables={
            "api_url": api_url,
            "api_docs": api_docs,
            "tool_names": ", ".join(tool_names),
            "tool_descriptions": tool_descriptions,
        },
    )
    controller_agent = ZeroShotAgent(
        llm_chain=LLMChain(llm=llm, prompt=controller_prompt),
        allowed_tools=tool_names,
    )
    return AgentExecutor.from_agent_and_tools(agent=controller_agent, tools=tools, verbose=True)

def _create_api_controller_tool(
    api_spec: ReducedOpenAPISpec,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
) -> Tool:
    """Expose controller as a tool.

    The tool is invoked with a plan from the planner, and dynamically
    creates a controller agent with relevant documentation only to
    constrain the context.

    Args:
        api_spec: Reduced OpenAPI spec; its first server URL is the base URL.
        requests_wrapper: HTTP wrapper passed to the controller agent.
        llm: Model driving the controller agent.

    Returns:
        A Tool whose function raises ValueError when the plan names an endpoint
        that is not documented in ``api_spec``.
    """

    base_url = api_spec.servers[0]["url"]

    def _create_and_run_api_controller_agent(plan_str: str) -> str:
        # Pull "METHOD /route" pairs out of the free-text plan.
        pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)*"
        matches = re.findall(pattern, plan_str)
        # Query strings are dropped; dict.fromkeys de-duplicates while keeping
        # first-seen order, so an endpoint the plan calls several times is
        # documented only once in the controller's (size-limited) prompt.
        endpoint_names = dict.fromkeys(
            "{method} {route}".format(method=method, route=route.split("?")[0])
            for method, route in matches
        )
        docs_str = ""
        for endpoint_name in endpoint_names:
            found_match = False
            for name, _, docs in api_spec.endpoints:
                # Path templates such as "/concepts/{id}" match any concrete value.
                regex_name = re.compile(re.sub(r"\{.*?\}", ".*", name))
                if regex_name.match(endpoint_name):
                    found_match = True
                    docs_str += f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
            if not found_match:
                raise ValueError(f"{endpoint_name} endpoint does not exist.")

        agent = _create_api_controller_agent(base_url, docs_str, requests_wrapper, llm)
        return agent.run(plan_str)

    return Tool(
        name=API_CONTROLLER_TOOL_NAME,
        func=_create_and_run_api_controller_agent,
        description=API_CONTROLLER_TOOL_DESCRIPTION,
    )

def _create_api_planner_tool(
    api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
    """Expose an LLM planner that turns a query into a plan of API calls.

    Args:
        api_spec: Reduced OpenAPI spec whose endpoints are listed in the prompt.
        llm: Model used by the planner chain.

    Returns:
        A Tool whose function runs the planner chain on its input.
    """
    from langchain.chains.llm import LLMChain

    # One "- <METHOD /path> <description>" bullet per line. rstrip() removes any
    # trailing newline a (YAML block) description may carry, and joining on "\n"
    # keeps endpoints whose descriptions lack one from running together.
    endpoints_str = "\n".join(
        f"- {name} {description}".rstrip() for name, description, _ in api_spec.endpoints
    )
    prompt = PromptTemplate(
        template=API_PLANNER_PROMPT,
        input_variables=["query"],
        partial_variables={"endpoints": endpoints_str},
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    tool = Tool(
        name=API_PLANNER_TOOL_NAME,
        description=API_PLANNER_TOOL_DESCRIPTION,
        func=chain.run,
    )
    return tool

In [None]:
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.vectorstores import VectorStore

In [None]:
class BaseAPIRequestTool(BaseModel):
    """Shared fields for tools that query an HTTP API.

    Mixed into ``APIRequestQAWithSourcesTool`` together with ``BaseTool``.
    """

    # LLM handle; defaults to an OpenAI model with temperature 0.
    # NOTE(review): not read by APIRequestQAWithSourcesTool._run — confirm it is needed.
    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
    # Endpoint URL the tool sends GET requests to.
    url: str = Field(exclude=True)
    # HTTP headers sent with every request.
    headers: dict = Field(exclude=True)
    # Fixed query-string parameters combined with the agent's query.
    data_params: dict = Field(exclude=True)
    # Key in the JSON response whose list of items is returned.
    response_key: str = Field(exclude=True)

    # Reuse BaseTool's pydantic configuration for this mixin.
    class Config(BaseTool.Config):
        pass


class APIRequestQAWithSourcesTool(BaseAPIRequestTool, BaseTool):
    """Tool that searches an HTTP API and returns the hits as ``id``/``name`` JSON.

    The agent's input is sent as the ``query`` parameter together with the fixed
    ``data_params``; the list stored under ``response_key`` in the JSON reply is
    reduced to ``{'id': str, 'name': ...}`` dicts.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Search the API for ``query`` and return the reduced hits as a JSON string."""
        params = {'query': query, **self.data_params}
        # The agent's query must win over any 'query' entry in the fixed
        # parameters; otherwise the tool would silently ignore its own input.
        params['query'] = query
        response = requests.get(self.url, params=params, headers=self.headers)
        items = response.json().get(self.response_key) or []
        results = [
            {
                'id': str(item.get('id')),
                'name': item.get('name'),
            }
            for item in items
        ]
        return json.dumps(results)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported."""
        raise NotImplementedError()

In [None]:
class BaseVectorStoreTool(BaseModel):
    """Shared fields for tools that look up candidates in a vector store.

    Mixed into ``VectorStoreQAWithSourcesTool`` together with ``BaseTool``.
    """

    # Vector store searched by the tool.
    vectorstore: VectorStore = Field(exclude=True)
    # LLM handle; defaults to an OpenAI model with temperature 0.
    # NOTE(review): not read by VectorStoreQAWithSourcesTool._run — confirm it is needed.
    llm: BaseLanguageModel = Field(default_factory=lambda: OpenAI(temperature=0))
    # Document-metadata key holding the candidate's concept ID.
    target_id_column: str = Field(exclude=True)
    # Document-metadata key holding the candidate's name (page_content is the fallback).
    target_candidate_column: str = Field(exclude=True)

    # Reuse BaseTool's pydantic configuration for this mixin.
    class Config(BaseTool.Config):
        pass


class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
    """Tool that returns the nearest vector-store entries as ``id``/``name`` JSON.

    Each hit is reduced to its ``target_id_column`` metadata value and its
    ``target_candidate_column`` metadata value, falling back to the document's
    ``page_content`` when that name is missing or empty.
    """

    # Number of nearest neighbours returned per query.
    top_k: int = 15

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return the ``top_k`` most similar entries for ``query`` as a JSON string."""
        results = self.vectorstore.similarity_search_with_score(query, k=self.top_k)

        candidates = [
            {
                'id': document.metadata.get(self.target_id_column),
                'name': document.metadata.get(self.target_candidate_column) or document.page_content,
            }
            for document, _score in results
        ]

        return json.dumps(candidates)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported."""
        raise NotImplementedError()

In [None]:
class LlamaCppWithUsage(LlamaCpp):
    """LlamaCpp LLM that records the token usage of its most recent call.

    ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` are overwritten
    (not accumulated) by every ``_call``. In non-streaming mode they come from the
    ``usage`` block returned by llama.cpp. In streaming mode no usage block is
    available, so they are estimated by re-tokenizing the prompt and the
    generated text.
    """

    # Token counts of the most recent ``_call``.
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Generate a completion for ``prompt`` and record its token usage.

        Returns:
            The generated text.
        """
        if self.streaming:
            combined_text_output = ''
            for chunk in self._stream(
                prompt=prompt,
                stop=stop,
                run_manager=run_manager,
                **kwargs,
            ):
                combined_text_output += chunk.text
            # Streamed chunks carry no usage block; estimate the counts with the
            # model's tokenizer instead of leaving values from an earlier call.
            # NOTE(review): the prompt is tokenized with the default BOS token and
            # the completion without it, presumably matching llama.cpp's own usage
            # counts — verify.
            self.prompt_tokens = len(self.client.tokenize(prompt.encode('utf-8')))
            self.completion_tokens = (
                len(self.client.tokenize(combined_text_output.encode('utf-8'), add_bos=False))
                if combined_text_output
                else 0
            )
            self.total_tokens = self.prompt_tokens + self.completion_tokens
            return combined_text_output

        params = {**self._get_parameters(stop), **kwargs}
        result = self.client(prompt=prompt, **params)
        # Tolerate responses without a usage block instead of raising KeyError.
        usage = result.get('usage') or {}
        self.prompt_tokens = usage.get('prompt_tokens', 0)
        self.completion_tokens = usage.get('completion_tokens', 0)
        self.total_tokens = usage.get('total_tokens', 0)
        return result['choices'][0]['text']


In [None]:
def call_baseline(df_row, tools, db, template, target_vocabulary, source_selection_column, source_description_column,
                   source_id_column, source_id_key, source_name_key, source_equivalence_key, target_id_column, target_candidate_column):
    """Map one source concept with a non-LLM baseline: the top-ranked search hit.

    With ``'athena-api'`` in ``tools`` the Athena concept search is queried; with
    ``'vs-query'`` the vector store ``db`` is searched. The first hit is labelled
    ``'WIDER'`` for SNOMED and ``'EQUAL'`` otherwise, then scored against the
    reference dicts in ``df_row[source_id_column]``: a match on concept ID (or a
    case-insensitive match on name) is a partial match, and an additional match
    on equivalence type makes it exact.

    Args:
        df_row: Row (pandas Series or mapping) with the source concept and references.
        tools: Container naming the baseline to run.
        db: Vector store exposing ``similarity_search_with_score`` (``'vs-query'`` only).
        template, source_description_column: Unused; kept for a signature parallel
            to ``call_llm_agent``.
        target_vocabulary: Target vocabulary name.
        source_selection_column: Column with the source concept text.
        source_id_column: Column holding the list of reference dicts.
        source_id_key, source_name_key, source_equivalence_key: Keys of the
            reference dicts for concept ID, name and equivalence type.
        target_id_column, target_candidate_column: Metadata keys of the vector-store
            documents for concept ID and name.

    Returns:
        11-tuple ``(llm_output, llm_prompt, llm_intermediate_steps, llm_concept,
        llm_name, llm_equivalence, partial_match, exact_match, prompt_tokens,
        completion_tokens, total_tokens)``; ``llm_output`` is ``'Error'`` when the
        search returns no hits. Token counts are always 0.

    Raises:
        ValueError: If ``tools`` contains neither ``'athena-api'`` nor ``'vs-query'``.
    """
    # Reuse a result already stored on the row by an earlier run.
    if 'llm_output' in df_row and df_row.llm_output != 'Error':
        return (df_row.llm_output, df_row.llm_prompt, df_row.llm_intermediate_steps, df_row.llm_concept, df_row.llm_name, df_row.llm_equivalence,
                df_row.partial_match, df_row.exact_match, df_row.prompt_tokens, df_row.completion_tokens, df_row.total_tokens)

    source_concept = df_row[source_selection_column].strip().capitalize()
    # The baseline cannot judge equivalence, so it assigns a fixed label per vocabulary.
    equivalence = 'WIDER' if target_vocabulary == 'SNOMED' else 'EQUAL'

    if 'athena-api' in tools:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36',
        }
        data_params = {
            'query': source_concept,
            'pageSize': 15,
            'page': 1,
            'standardConcept': 'Standard',
            'vocabulary': 'SNOMED' if target_vocabulary == 'SNOMED' else ['RxNorm', 'RxNorm Extension'],
            'invalidReason': 'Valid',
        }
        requests_response = requests.get('https://athena.ohdsi.org/api/v1/concepts', params=data_params, headers=headers)
        candidates = [
            {
                'id': str(concept.get('id')),
                'name': concept.get('name'),
                'equivalence': equivalence,
            }
            for concept in (requests_response.json().get('content') or [])
        ]
    elif 'vs-query' in tools:
        results = db.similarity_search_with_score(source_concept, k=15)
        candidates = [
            {
                'id': document.metadata.get(target_id_column),
                'name': document.metadata.get(target_candidate_column) or document.page_content,
                'equivalence': equivalence,
            }
            for document, _score in results
        ]
    else:
        raise ValueError("call_baseline requires 'athena-api' or 'vs-query' in tools, got {!r}".format(tools))

    # The baseline's answer is only the best-ranked candidate; no hits -> 'Error'.
    llm_output = json.dumps(candidates[0]) if candidates else ''
    print('Baseline output: ', llm_output)

    if not llm_output:
        return 'Error', '', '', None, None, None, False, False, 0, 0, 0

    # The answer is built locally, so its fields are read directly instead of
    # re-parsing the JSON string.
    top_candidate = candidates[0]
    llm_concept = [top_candidate['id']]
    llm_name = [top_candidate['name']]
    llm_equivalence = [top_candidate['equivalence']]

    partial_match, exact_match = False, False
    compared_concept = None
    candidate_id, candidate_name = llm_concept[0], llm_name[0]
    for reference in df_row[source_id_column]:
        reference_id = reference.get(source_id_key)
        reference_name = reference.get(source_name_key)
        # A case-insensitive name match counts as a match on the reference's own ID.
        if candidate_name and reference_name and candidate_name.lower() == reference_name.lower():
            compared_concept = reference_id
        else:
            compared_concept = candidate_id
        if compared_concept == reference_id:
            partial_match = True
            # Matching equivalence type upgrades the partial match to an exact one.
            if equivalence == reference.get(source_equivalence_key):
                partial_match, exact_match = False, True
                break

    if exact_match:
        llm_concept = [compared_concept]

    return llm_output, '', '', llm_concept, llm_name, llm_equivalence, partial_match, exact_match, 0, 0, 0

In [None]:
def filter_references(df_cell, key, values):
    """Return the dicts in ``df_cell`` whose ``key`` entry is one of ``values``.

    A dict without ``key`` is treated as having the value ``None``.
    """
    def _is_wanted(reference):
        return reference.get(key, None) in values

    return list(filter(_is_wanted, df_cell))

def call_llm_agent(df_row, agent_executor, template, target_vocabulary, source_selection_column, source_description_column,
                   source_id_column, source_id_key, source_name_key, source_equivalence_key):
    """Map one source concept with an LLM agent and score its answer against references.

    Args:
        df_row: Row (pandas Series or mapping) with the source concept and references.
        agent_executor: Object with ``invoke(prompt)`` returning a dict that holds
            ``'output'`` and optionally ``'intermediate_steps'``.
        template: Format string with ``{target_vocabulary}``, ``{source_concept}``
            and ``{source_description}`` placeholders.
        target_vocabulary: Target vocabulary name inserted into the prompt.
        source_selection_column: Column with the source concept text.
        source_description_column: Optional column with a free-text description.
        source_id_column: Column holding the list of reference dicts.
        source_id_key, source_name_key, source_equivalence_key: Keys of the
            reference dicts for concept ID, name and equivalence type.

    Returns:
        11-tuple ``(llm_output, llm_prompt, llm_intermediate_steps, llm_concept,
        llm_name, llm_equivalence, partial_match, exact_match, prompt_tokens,
        completion_tokens, total_tokens)``; ``llm_output`` is ``'Error'`` when the
        agent fails. Token counts are always 0.
    """
    # Reuse a result already stored on the row by an earlier run.
    if 'llm_output' in df_row and df_row.llm_output != 'Error':
        return (df_row.llm_output, df_row.llm_prompt, df_row.llm_intermediate_steps, df_row.llm_concept, df_row.llm_name, df_row.llm_equivalence,
                df_row.partial_match, df_row.exact_match, df_row.prompt_tokens, df_row.completion_tokens, df_row.total_tokens)

    # Set description
    source_description = ''
    if source_description_column and not pd.isna(df_row[source_description_column]):
        source_description = 'Description: {}'.format(str(df_row[source_description_column]).strip().capitalize())

    prompt = template.format(**{'target_vocabulary': target_vocabulary,
                                'source_concept': df_row[source_selection_column].strip().capitalize(),
                                'source_description': source_description,
                                })

    try:
        llm_output = agent_executor.invoke(prompt)
    except Exception as e:
        print('Exception', str(e))
        return 'Error', prompt, '', None, None, None, False, False, 0, 0, 0

    if not llm_output:
        return 'Error', prompt, '', None, None, None, False, False, 0, 0, 0

    llm_intermediate_steps = llm_output.get('intermediate_steps', '')
    llm_output = llm_output.get('output', '')

    # Preferred path: the agent answered with JSON (a dict or a list of dicts).
    llm_concept, llm_name, llm_equivalence = [], [], []
    try:
        llm_output_json = parse_json_markdown(llm_output)
        if isinstance(llm_output_json, dict):
            llm_output_json = [llm_output_json]
        for output in llm_output_json:
            llm_concept.append(output.get('id'))
            llm_name.append(output.get('name'))
            llm_equivalence.append(output.get('equivalence'))
    except Exception:
        # Best effort: answers that are not JSON fall through to the regex
        # extraction below.
        pass

    if not llm_concept:
        # Fallback: any run of six or more digits is taken as a candidate concept ID.
        llm_concept = re.findall(r'\b(\d{6,})\b', llm_output)

    if not llm_concept:
        return llm_output, prompt, llm_intermediate_steps, None, None, None, False, False, 0, 0, 0

    if not llm_equivalence:
        llm_equivalence = [
            match.upper()
            for match in re.findall('EQUAL|EQUIVALENT|NARROWER|WIDER|INEXACT', llm_output, re.IGNORECASE)
        ]

    partial_match, exact_match = False, False
    compared_concept = None
    # Iterate over every candidate ID. Names/equivalences are looked up by
    # position and padded with None, so regex-extracted IDs (which have no parsed
    # name) are still compared instead of being dropped by zip().
    for index, concept in enumerate(llm_concept):
        if exact_match:
            break
        name = llm_name[index] if index < len(llm_name) else None
        equivalence = llm_equivalence[index] if index < len(llm_equivalence) else None
        for reference in df_row[source_id_column]:
            reference_id = reference.get(source_id_key)
            reference_name = reference.get(source_name_key)
            # A case-insensitive name match counts as a match on the reference's own ID.
            if name and reference_name and name.lower() == reference_name.lower():
                compared_concept = reference_id
            else:
                compared_concept = concept
            if compared_concept == reference_id:
                partial_match = True
                # Matching equivalence type upgrades the partial match to an exact one.
                if equivalence == reference.get(source_equivalence_key):
                    partial_match, exact_match = False, True
                    break

    if exact_match:
        llm_concept = [compared_concept]

    return llm_output, prompt, llm_intermediate_steps, llm_concept, llm_name, llm_equivalence, partial_match, exact_match, 0, 0, 0

In [None]:
def download_file(url, local_directory, new_filename=None):
    """Download ``url`` into ``local_directory`` unless the file already exists.

    Args:
        url: Source URL; its last path segment is the default file name.
        local_directory: Target directory, created if missing.
        new_filename: Optional file name overriding the one taken from ``url``.

    Returns:
        Path of the local file (existing or freshly downloaded).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(local_directory, exist_ok=True)

    filename = new_filename or url.split('/')[-1]
    local_path = os.path.join(local_directory, filename)

    if not os.path.exists(local_path):
        # Stream into a temporary file and move it into place only when complete,
        # so an HTTP error page or an interrupted transfer is never mistaken for a
        # cached copy on the next call.
        partial_path = local_path + '.part'
        try:
            with requests.get(url, stream=True) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    shutil.copyfileobj(response.raw, f)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, local_path)

    return local_path

def safe_literal_eval(x):
    """Parse a Python-literal string such as a list serialized into a CSV cell.

    Returns:
        ``None`` for missing values (``None``/NaN); the value itself when it is
        already a list, dict, tuple or set; the parsed literal for a valid string;
        and ``[]`` when the input cannot be parsed.
    """
    if isinstance(x, (list, dict, tuple, set)):
        # Already parsed. pd.notna() on a container returns an array whose truth
        # value is ambiguous, which would otherwise turn valid data into [].
        return x
    try:
        return ast.literal_eval(x) if pd.notna(x) else None
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input.
        return []

def run_llm_agent_pipeline(
    llm_repo,
    llm_type,
    prompt_template,
    tools,
    data,
    target_vocabulary,
    source_concept_column,
    source_id_column,
    source_id_key,
    source_name_key,
    source_vocabulary_key,
    source_equivalence_key,
    column_names=None,
    column_converters=None,
    list_columns=None,
    filter_column=None,
    source_description_column=None,
    sample_size=None,
    max_retries=3,
    llm_file=None,
    n_gpu_layers=-1,
    n_ctx=512,
    temperature=0.0,
    api_key=None,
    vs_path=None,
    vs_embedding=None,
    target_id_column=None,
    target_candidate_column=None,
    run_baseline=False,
):
    """Map source concepts to a target vocabulary with an LLM agent and score the mappings.

    The pipeline:
      1. initializes the LLM backend selected by ``llm_type``;
      2. builds the agent tools and prompt selected by ``tools``;
      3. downloads the reference dataset, parses/filters it and optionally samples it;
      4. runs the agent (or ``call_baseline`` when ``run_baseline``) on every row,
         retrying only the rows whose output is ``'Error'``;
      5. writes per-row results to ``./output/agent-*.csv`` and a metrics summary
         to ``./output/agent-results-<timestamp>.txt``.

    Args:
        llm_repo: Model name (OpenAI / Google GenAI) or HuggingFace repo id.
        llm_type: One of ``'llamacpp'``, ``'openai'``, ``'hftransformers'``, ``'googlegenai'``.
        prompt_template: Instruction template passed to ``PromptTemplate.from_template``.
        tools: Tool names; the first of ``'athena-api'``, ``'athena-api-simple'``,
            ``'vs-query'`` (checked in that order) found in ``tools`` is used.
        data: URL of the tab-separated reference dataset.
        target_vocabulary: ``'SNOMED'`` or ``'RxNorm'``; also selects which dataset
            rows and reference mappings are kept.
        source_concept_column: Column name, or list of column names, used as the
            source term. Multiple columns are joined with spaces into an auxiliary
            column named by joining the column names with '-'.
        source_id_column: Column holding the list of reference mapping dicts.
        source_id_key, source_name_key, source_vocabulary_key, source_equivalence_key:
            Keys of the concept ID, name, vocabulary and equivalence in each reference dict.
        column_names: ``names`` passed to ``pd.read_csv`` (replacing the file's header row).
        column_converters: ``converters`` passed to ``pd.read_csv``.
        list_columns: Columns holding stringified lists, parsed with ``safe_literal_eval``.
        filter_column: If set, drop rows where this column is missing and rows left
            without any valid reference mapping.
        source_description_column: Optional column describing the source term.
        sample_size: Rows to sample (``random_state=0``), capped at the filtered
            dataset size; ``None`` uses all rows.
        max_retries: Extra attempts for rows whose agent output is ``'Error'``.
        llm_file: Model filename inside ``llm_repo`` (llamacpp).
        n_gpu_layers, n_ctx: LlamaCpp settings.
        temperature: Sampling temperature (llamacpp, openai, googlegenai).
        api_key: OpenAI / Google AI API key or HuggingFace token.
        vs_path: URL of the tar.gz FAISS vector store (``'vs-query'``).
        vs_embedding: Sentence-transformers model name for the vector store.
        target_id_column, target_candidate_column: Passed to the vector-store tool
            and to the baseline.
        run_baseline: Run ``call_baseline`` instead of the agent.

    Returns:
        ``None`` after a completed run, or an ``'Error: ...'`` string when the
        configuration is invalid (missing API key, unsupported ``llm_type`` or
        no supported tool).
    """
    import gc
    from collections import Counter

    # Avoid shared mutable default arguments
    column_names = [] if column_names is None else column_names
    column_converters = {} if column_converters is None else column_converters
    list_columns = [] if list_columns is None else list_columns
    # Accept a single column name as well as a list of column names
    if isinstance(source_concept_column, str):
        source_concept_column = [source_concept_column]

    # LLM initialization
    if llm_type == 'llamacpp':
        # Download model
        hf_logging.set_verbosity_error()
        llm_path = hf_hub_download(repo_id=llm_repo, filename=llm_file, local_dir='./models', token=api_key)

        callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        # Initialize local LlamaCpp LLM
        # TODO: check token usage with streaming and agent
        llm = LlamaCppWithUsage(
            model_path=llm_path,
            n_gpu_layers=n_gpu_layers,
            n_ctx=n_ctx, # Llama 3 = 8192
            temperature=temperature,
            callback_manager=callback_manager,
            verbose=True,
        )
        llm_repo = '{}-{}'.format(llm_repo.replace('/', '-'), llm_file)

    elif llm_type == 'openai':
        if not api_key:
            return 'Error: Invalid OpenAI API key.'
        # Initialize remote LLM
        llm = ChatOpenAI(openai_api_key=api_key, temperature=temperature, model=llm_repo)

    elif llm_type == 'hftransformers':
        tokenizer = AutoTokenizer.from_pretrained(llm_repo, token=api_key)

        model = AutoModelForCausalLM.from_pretrained(
            llm_repo,
            device_map='auto',
            offload_folder='offload',
            offload_state_dict=True,
            torch_dtype=torch.bfloat16,
            token=api_key,
        )

        pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)

        # Initialize local HF Transformers LLM
        llm = HuggingFacePipeline(pipeline=pipe, verbose=True)

        llm_repo = '{}'.format(llm_repo.replace('/', '-'))

    elif llm_type == 'googlegenai':
        if not api_key:
            return 'Error: Invalid Google GenAI API key.'
        # Initialize remote LLM
        llm = ChatGoogleGenerativeAI(google_api_key=api_key, temperature=temperature, model=llm_repo,
                                     safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH})

    else:
        # Without this, `llm` would be undefined and fail later with a NameError
        return 'Error: Unsupported LLM type: {}.'.format(llm_type)

    agent_tools = []
    db = None
    embedding_model = None

    # OHDSI Athena Web API (docs)
    if 'athena-api' in tools:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36',
        }
        requests_wrapper = RequestsWrapper(headers=headers, response_content_type='json')

        # Get API docs, according to target vocabulary
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # yaml.safe_load would suffice if the spec is plain YAML — confirm
        raw_athena_api_spec = yaml.load(get_athena_openapi_docs(target_vocabulary), Loader=yaml.Loader)
        athena_api_spec = reduce_openapi_spec(raw_athena_api_spec)

        # Initialize API tools
        agent_tools.extend([
            _create_api_planner_tool(athena_api_spec, llm),
            _create_api_controller_tool(athena_api_spec, requests_wrapper, llm),
        ])

        # Set agent prompt
        agent_prompt = PromptTemplate(
            template=API_ORCHESTRATOR_PROMPT,
            input_variables=["input", "agent_scratchpad"],
            partial_variables={
                "tool_names": ", ".join([tool.name for tool in agent_tools]),
                "tool_descriptions": "\n".join(
                    [f"{tool.name}: {tool.description}" for tool in agent_tools]
                ),
            },
        )

    # Simplified Athena API request tool (request-only)
    elif 'athena-api-simple' in tools:
        # Set API information
        api_template = (
            "Useful for when you need to query the API for {name} concepts."
            "Whenever you need information about {description} you should ALWAYS use this."
            "Input MUST be the query term only."
            "Output is a json serialized list of dictionaries with keys `id` and `name`, with the ID's and names of the concepts. "
        )
        api_template = api_template.format(name=target_vocabulary, description='{} concepts'.format(target_vocabulary))

        # Initialize API tools
        api_with_sources_tool = APIRequestQAWithSourcesTool(
            name='{}_concepts'.format(target_vocabulary),
            description=api_template,
            llm=llm,
            url='https://athena.ohdsi.org/api/v1/concepts',
            headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'},
            data_params={
                'pageSize': 15,
                'page': 1,
                'standardConcept': 'Standard',
                'vocabulary': 'SNOMED' if target_vocabulary == 'SNOMED' else ['RxNorm', 'RxNorm Extension'],
                'invalidReason': 'Valid',
            },
            response_key='content',
        )

        # Set agent prompt and tools
        agent_tools = [api_with_sources_tool]

        API_AGENT_PREFIX = (
            'You are an agent designed to perform tasks with the help of an API.'
            'You have access to tools for interacting with the API, and the input to the tools is the query to perform against the API.'
        )

        agent_prompt = ZeroShotAgent.create_prompt(agent_tools, prefix=API_AGENT_PREFIX)

    # Vector store querying tool
    elif 'vs-query' in tools:
        # Create vector store folder
        os.makedirs('./vectorstore', exist_ok=True)

        # Download VS data
        vs_path = download_file(vs_path, './data', new_filename='vectorstore.tar.gz')

        # Extract tar.gz file
        with tarfile.open(vs_path, 'r:gz') as tar:
            # Top-level folder of the archive: first path component of the first
            # member, so archives whose first entry is a file inside the folder
            # also work ('.' when members sit at the archive root)
            vs_folder = next((part for part in tar.getnames()[0].split('/') if part not in ('', '.')), '.')
            # The archive comes from a URL: use the 'data' extraction filter
            # (rejects absolute paths and path traversal) where Python supports it
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
            tar.extractall(path='./vectorstore', **extract_kwargs)

        # Load embedding model
        embedding_model = SentenceTransformerEmbeddings(model_name=vs_embedding)

        # Load extracted vector file
        # NOTE(review): allow_dangerous_deserialization unpickles the downloaded
        # archive — only point vs_path at trusted sources
        db = FAISS.load_local('./vectorstore/{}'.format(vs_folder), embedding_model, allow_dangerous_deserialization=True)

        # Check vector store
        print('Vector store loaded: {} documents'.format(db.index.ntotal))

        # Set vector store information
        vs_template = (
            "Useful for when you need to query the vector store for {name} concepts."
            "Whenever you need information about {description} you should ALWAYS use this."
            "Input MUST be the query term only."
            "Output is a json serialized list of dictionaries with keys `id` and `name`, with the ID's and names of the concepts. "
        )
        vs_template = vs_template.format(name=target_vocabulary, description='{} concepts'.format(target_vocabulary))

        # Initialize vector store tools
        qa_with_sources_tool = VectorStoreQAWithSourcesTool(
            name='{}_concepts'.format(target_vocabulary),
            description=vs_template,
            vectorstore=db,
            llm=llm,
            target_id_column=target_id_column,
            target_candidate_column=target_candidate_column,
        )

        # Set agent prompt and tools
        agent_tools = [qa_with_sources_tool]

        VECTORSTORE_AGENT_PREFIX = (
            'You are an agent designed to perform tasks with the help of a vector store.'
            'You have access to tools for interacting with the vector store, and the input to the tools is the query to perform against the vector store.'
        )

        agent_prompt = ZeroShotAgent.create_prompt(agent_tools, prefix=VECTORSTORE_AGENT_PREFIX)

    else:
        # Without this, `agent_prompt` would be undefined below
        return 'Error: No supported agent tool in {} (expected athena-api, athena-api-simple or vs-query).'.format(tools)

    # Initialize agent with tools
    agent = ZeroShotAgent(
        llm_chain=LLMChain(llm=llm, prompt=agent_prompt, memory=None),
        allowed_tools=[tool.name for tool in agent_tools],
    )
    agent_executor = AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=agent_tools,
        callback_manager=None,
        verbose=True,
        return_intermediate_steps=True,
    )

    # Set instruction template
    template = PromptTemplate.from_template(prompt_template)

    # Download dataset (tab-separated despite the .csv name)
    dataset_path = download_file(data, './data', new_filename='dataset.csv')

    # Load dataframe
    data_df = pd.read_csv(dataset_path, sep='\t',
                          names=column_names, converters=column_converters, header=0)

    # Column to use for LLM selection (first column for VS querying)
    source_selection_column = source_concept_column[0]

    # Allow composite columns for querying
    if len(source_concept_column) > 1:
        data_df['-'.join(source_concept_column)] = data_df.apply(
            lambda row: ' '.join('' if pd.isna(row[col]) else str(row[col])
            for col in source_concept_column), axis=1)

    source_concept_column = '-'.join(source_concept_column)

    # Preprocess list columns
    for list_column in list_columns:
        data_df[list_column] = data_df[list_column].apply(safe_literal_eval)

    # Filter by target vocabulary: IDs starting with '06' go to RxNorm, the rest
    # to SNOMED (presumably the source's medication group — confirm).
    # .copy() so later column assignments act on a real frame, not a slice view
    if target_vocabulary == 'SNOMED':
        data_df = data_df[data_df.ID.str[:2] != '06'].copy()
    elif target_vocabulary == 'RxNorm':
        data_df = data_df[data_df.ID.str[:2] == '06'].copy()

    original_df_size = len(data_df)

    # Filter dataframe (non-empty rows, reference vocabulary, valid equivalence)
    if filter_column:
        data_df = data_df[~data_df[filter_column].isna()].copy()
        valid_equivalences = ['EQUAL', 'EQUIVALENT', 'NARROWER', 'WIDER', 'INEXACT']
        data_df[source_id_column] = data_df[source_id_column].apply(lambda x: filter_references(x, source_equivalence_key, valid_equivalences))
        if target_vocabulary == 'SNOMED':
            data_df[source_id_column] = data_df[source_id_column].apply(lambda x: filter_references(x, source_vocabulary_key, ['SNOMED']))
        elif target_vocabulary == 'RxNorm':
            data_df[source_id_column] = data_df[source_id_column].apply(lambda x: filter_references(x, source_vocabulary_key, ['RxNorm', 'RxNorm Extension']))
        # Keep only rows with non-empty reference mappings (comparison keeps a
        # boolean mask even when the frame is empty)
        data_df = data_df[data_df[source_id_column].apply(len) > 0]

    filtered_df_size = len(data_df)

    # Take sample from dataset (if None, don't sample); cap at the rows available
    if sample_size:
        if sample_size > filtered_df_size:
            print('Requested sample size {} exceeds filtered dataset size {}; using all rows.'.format(
                sample_size, filtered_df_size))
        data_df = data_df.sample(n=min(sample_size, filtered_df_size), random_state=0)
    sample_size = len(data_df)

    result_columns = ['llm_output', 'llm_prompt', 'llm_intermediate_steps', 'llm_concept', 'llm_name', 'llm_equivalence', 'partial_match', 'exact_match',
                      'prompt_tokens', 'completion_tokens', 'total_tokens']

    if data_df.empty:
        # Nothing to evaluate; still add the result columns so the output file
        # and the metrics below are well-formed
        data_df = data_df.reindex(columns=list(data_df.columns) + result_columns)
    elif run_baseline:
        equivalences = [d[source_equivalence_key] for sublist in data_df[source_id_column] for d in sublist]
        equivalences_frequencies = Counter(equivalences)
        print(equivalences_frequencies)

        data_df[result_columns] = data_df.apply(
            call_baseline, args=(tools, db, template, target_vocabulary, source_selection_column, source_description_column,
                                 source_id_column, source_id_key, source_name_key, source_equivalence_key, target_id_column, target_candidate_column),
            axis=1, result_type='expand')
    else:
        agent_args = (agent_executor, template, target_vocabulary, source_selection_column, source_description_column,
                      source_id_column, source_id_key, source_name_key, source_equivalence_key)
        # Invoke agent for each row in dataframe
        data_df[result_columns] = data_df.apply(call_llm_agent, args=agent_args, axis=1, result_type='expand')
        # Retry only the rows that returned an error, keeping successful results
        for _ in range(max_retries):
            error_mask = data_df['llm_output'].isin(['Error'])
            if not error_mask.any():
                break
            retried = data_df[error_mask].apply(call_llm_agent, args=agent_args, axis=1, result_type='expand')
            # Name the expanded columns so .loc aligns them with the targets
            retried.columns = result_columns
            data_df.loc[error_mask, result_columns] = retried

    # Create output folder
    os.makedirs('./output', exist_ok=True)

    # Remove auxiliary column (source concept)
    if source_selection_column != source_concept_column:
        data_df = data_df.drop(source_concept_column, axis=1)

    # Write result dataframe to file
    run_timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    tools = '-'.join(tools)
    data_df.to_csv('./output/agent-{}-{}-{}-{}-{}.csv'.format(
        target_vocabulary, source_concept_column, llm_repo, tools, run_timestamp), sep='\t', index=False)

    # Calculate match metrics
    partial_matches = int(data_df['partial_match'].eq(True).sum())
    exact_matches = int(data_df['exact_match'].eq(True).sum())
    prompt_tokens_sum = data_df['prompt_tokens'].sum()
    completion_tokens_sum = data_df['completion_tokens'].sum()
    total_tokens_sum = data_df['total_tokens'].sum()

    # Print results (run parameters and match metrics)
    run_results = [
        'Agent results - {}'.format(run_timestamp),
        '====================================================',
        'Parameters:',
        'Target vocabulary: {}'.format(target_vocabulary),
        'LLM type: {}'.format(llm_type),
        'LLM repo: {}'.format(llm_repo),
        'LLM name: {}'.format(llm_file),
        'Temperature: {}'.format(temperature),
        'Agent tools: {}'.format(tools),
        'Embedding model: {}'.format(vs_embedding),
        'Source concept column(s): {}'.format(source_concept_column),
        'Source description column: {}'.format(source_description_column),
        '====================================================',
        'Match metrics',
        'Original dataset size = {}'.format(original_df_size),
        'Filtered dataset size ({}) = {}'.format(filter_column, filtered_df_size),
        'Sampled dataset size = {}'.format(sample_size),
        'Prompt tokens sum = {}'.format(prompt_tokens_sum),
        'Completion tokens sum = {}'.format(completion_tokens_sum),
        'Total tokens sum = {}'.format(total_tokens_sum),
        'Partial matches (only concept) = {}/{} ({:.3f})'.format(
            partial_matches, sample_size, partial_matches / sample_size if sample_size != 0 else 0),
        'Exact matches (concept and equivalence) = {}/{} ({:.3f})'.format(
            exact_matches, sample_size, exact_matches / sample_size if sample_size != 0 else 0),
    ]

    run_results = '\n'.join(run_results)
    print(run_results)

    # Write match metrics to file
    with open('./output/agent-results-{}.txt'.format(run_timestamp), 'w') as f:
        f.write(run_results)

    # Unload models from memory: drop every local reference to the LLM, the
    # agent/tools that wrap it and the vector store, then collect and release
    # cached GPU memory
    llm = model = pipe = agent = agent_executor = agent_tools = None
    api_with_sources_tool = qa_with_sources_tool = None
    embedding_model = None
    db = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Instruction template for the API-backed agent. Placeholders: {target_vocabulary},
# {source_concept}, {source_description}; doubled braces render literal JSON braces.
API_AGENT_PROMPT = '''You are an agent tasked with finding the closest matching term in the {target_vocabulary} terminology for a given source term.

To do this, follow these steps:

1. **Query the API**: Use the API endpoint to search for terms in the {target_vocabulary} terminology. You MUST call the API at least once.
2. **Analyze Results**: Review the results from the API. Since the source term comes from a different terminology, there will rarely be an exact match in {target_vocabulary}.
3. **Refine Search**: If the results do not match the source term adequately, modify the source term using synonyms, variations, or other changes. Query the API again with the new term.
4. **Repeat as Necessary**: Continue refining your queries and using the API until you find a suitable match.
5. **Obligatory Use of API**: You MUST find the match using the API. Do not generate a match without querying the API.
6. **Return the Result in JSON**: Once you find a suitable match, return the result in the following JSON format:
{{
    "id": "<numeric code>",
    "name": "<target name>",
    "equivalence": "<equivalence type>"
}}

- id: The target's numeric ID.
- name: The selected target term's name.
- equivalence: MUST be one of the following values: EQUAL, EQUIVALENT, WIDER, NARROWER or INEXACT.

Important considerations:
- The API is the only resource available. It is both sufficient and necessary to use it for finding matches.
- Ensure the JSON values are copied exactly as returned by the API.
- The source term is "{source_concept}"
{source_description}
'''

In [None]:
# Instruction template for the vector-store-backed agent. Placeholders: {target_vocabulary},
# {source_concept}, {source_description}; doubled braces render literal JSON braces.
VS_AGENT_PROMPT = '''You are an agent tasked with finding the closest matching term in the {target_vocabulary} terminology for a given source term.

To do this, follow these steps:

1. **Query the vector store**: Use the vector store to search for terms in the {target_vocabulary} terminology. You MUST query the vector store at least once.
2. **Analyze Results**: Review the results from the vector store. Since the source term comes from a different terminology, there will rarely be an exact match in {target_vocabulary}.
3. **Refine Search**: If the results do not match the source term adequately, modify the source term using synonyms, variations, or other changes. Query the vector store again with the new term.
4. **Repeat as Necessary**: Continue refining your queries and querying the vector store until you find a suitable match.
5. **Obligatory Use of Vector Store**: You MUST find the match using the vector store. Do not generate a match without querying the vector store.
6. **Return the Result in JSON**: Once you find a suitable match, return the result in the following JSON format:
{{
    "id": "<numeric code>",
    "name": "<target name>",
    "equivalence": "<equivalence type>"
}}

- id: The target's numeric ID.
- name: The selected target term's name.
- equivalence: MUST be one of the following values: EQUAL, EQUIVALENT, WIDER, NARROWER or INEXACT.

Important considerations:
- The vector store is the only resource available. It is both sufficient and necessary to use it for finding matches.
- Ensure the JSON values are copied exactly as returned by the vector store.
- The source term is "{source_concept}"
{source_description}
'''

In [None]:
# Example usage
# The API key is read from the environment instead of being pasted into the notebook
# (e.g. OPENAI_API_KEY; use GOOGLE_API_KEY or HF_TOKEN for the other backends).
run_llm_agent_pipeline(llm_repo='gpt-4o-2024-05-13', # gpt-4o-2024-05-13, gemini-1.5-pro-latest or HuggingFace repo
                       llm_type='openai', # openai, googlegenai, llamacpp or hftransformers
                       prompt_template=API_AGENT_PROMPT, # API_AGENT_PROMPT or VS_AGENT_PROMPT
                       tools=['athena-api-simple'], # athena-api-simple or vs-query
                       data='<SIGTAP with annotated correspondences.csv>', # From reference dataset
                       target_vocabulary='SNOMED', # SNOMED or RxNorm
                       source_concept_column=['Name', 'Description'],
                       source_id_column='Annotations',
                       source_id_key='conceptId',
                       source_name_key='conceptName',
                       source_vocabulary_key='vocabulary_id',
                       source_equivalence_key='equivalence',
                       column_names=['ID', 'Name', 'Description', 'sourceCode', 'Annotations'],
                       column_converters={'ID': str, 'sourceCode': str},
                       list_columns=['Annotations'],
                       filter_column='Annotations',
                       source_description_column='Description', # None, if no description column
                       sample_size=50, # None, for no sampling
                       max_retries=1,
                       llm_file=None, # HuggingFace filename for LLM (for llm_type llamacpp and hftransformers)
                       n_gpu_layers=-1,
                       n_ctx=32768,
                       temperature=0.0,
                       api_key=os.environ.get('OPENAI_API_KEY'), # OpenAI, Google AI, or HuggingFace API key
                       vs_path=None, # URL for compressed FAISS vector store in tar.gz (vs-query tool)
                       vs_embedding=None, # HuggingFace repo for embedding model (vs-query tool)
                       target_id_column='concept_id',
                       target_candidate_column='concept_name',
                       run_baseline=False, # Whether to run in baseline mode (default=False)
                       )