# Initializing Claude

## Imports

In [1]:
import boto3, json, os
from pathlib import Path
from config import toolConfig
from aind_data_access_api.document_db import MetadataDbClient
from aind_data_access_api.document_db_ssh import DocumentDbSSHClient, DocumentDbSSHCredentials
from botocore.exceptions import ClientError

## Connecting to Bedrock

In [2]:
client = boto3.client("bedrock-runtime", region_name="us-west-2")
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
bedrock = boto3.client(
    service_name="bedrock-runtime",
    region_name = 'us-west-2'
)

# Loading metadata schema

In [3]:
folder = Path("../ref")
for name in os.listdir(folder):
    if ".json" in name:
        file = open(f'{folder}/{name}')
        name = name[:-5]
        globals()[name] = json.load(file)

In [4]:
schema_types = [rig_schema, procedures_schema, acquisition_schema, instrument_schema, session_schema, subject_schema, data_description_schema, processing_schema, subject_609281_metadata, metadata_schema]

In [5]:
metadata_schema = schema_types.pop()

In [6]:
sample_metadata = schema_types.pop()

# Connecting to DocDB API + Implementing tool use

In [7]:
API_GATEWAY_HOST = "api.allenneuraldynamics.org"
DATABASE = "metadata_index"
COLLECTION = "data_assets"

docdb_api_client = MetadataDbClient(
   host=API_GATEWAY_HOST,
   database=DATABASE,
   collection=COLLECTION,
)

In [8]:
def doc_retrieval(filter_query):
    '''
    Retrieves one document from DocDB. Ideal for queries for a specific document
    
    :param filter_query: MongoDB Query
    :return: JSON File
    '''
    limit = 1000
    paginate_batch_size = 100
    response = docdb_api_client.retrieve_docdb_records(
       filter_query=filter_query,
       limit=limit,
       paginate_batch_size=paginate_batch_size
    )
    return(response)

In [9]:
def projection_retrieval(filter_query, field_name_list):
    '''
    Retrieves one document from DocDB. Ideal for queries for a specific document
    
    :param filter_query: MongoDB Query
    :param field_name_list: List of field names to be retrieved from the projection
    :return: JSON File
    '''
    credentials = DocumentDbSSHCredentials()
    with DocumentDbSSHClient(credentials=credentials) as doc_db_client:
        filter = filter_query
        projection = {"subject.subject_id" : 1}
        for field_name in field_name_list:
            projection[field_name] = 1
        #count = doc_db_client.collection.count_documents(filter)
        response = list(doc_db_client.collection.find(filter=filter, projection=projection))        
    return response

# Implementing tool use

In [24]:
system_prompt = "You are a neuroscientist with extensive knowledge about processes involves in neuroscience research. You are also an expert in crafting NoSQL queries for MongoDB. You must only do document retrieval with the available tool if specified." 

In [25]:
def get_completion(prompt, system_prompt=system_prompt, prefill=None):
    
    messages = [{"role": "user", "content": [{"text": prompt}]}]
    
    inference_config = {
        "temperature": 0,
        "maxTokens": 4000
    }
    toolConfig = {
                  "tools": [
                      {
                      "toolSpec": {
                          "name": "doc_retrieval",
                          "description": "Retrieve one document from docDB.",
                          "inputSchema": {
                              "json": {
                                  "type": "object",
                                  "properties": {
                                      "filter": {
                                          "type": "string",
                                          "description": "A MongoDB query to pass to the function"
                                      }
                                  },
                                  "required": ["filter"]
                              }
                          }
                      }
                      },
                      {
                      "toolSpec": {
                          "name": "projection_retrieval",
                          "description": "Retrieve multiple documents from docDB with only specific field information.",
                          "inputSchema": {
                              "json": {
                                  "type": "object",
                                  "properties": {
                                      "filter": {
                                          "type": "string",
                                          "description": "A MongoDB query to pass to the function"
                                      },
                                      "fieldNameList": {
                                          "type": "string",
                                          "description": "A MongoDB query to pass to the function"
                                      }
                                  },
                                  "required": ["filter"]
                              }
                          }
                      }
                      }
                  ]
                }

    converse_api_params = {
        "modelId": model_id,
        "messages" : messages,
        "inferenceConfig": inference_config,
        "toolConfig": toolConfig
    }
    
    if system_prompt:
        converse_api_params["system"] = [{"text": system_prompt}]
    if prefill:
        messages.append({"role": "assistant", "content": [{"text": prefill}]})
        print(prefill)
    try:
        response = bedrock.converse(**converse_api_params)
        print(response)
        
        response_message = response['output']['message']
        
        response_content_blocks = response_message['content']
        
        #Assistant reply including tool use 
        messages.append({"role": "assistant", "content": response_content_blocks})
        
        for content_block in response_content_blocks:
            if 'toolUse' in content_block:
                print("Stop Reason:", response['stopReason'])
                
                tool_use = response_content_blocks[-1]
                tool_id = tool_use['toolUse']['toolUseId']
                tool_name = tool_use['toolUse']['name']
                tool_inputs = tool_use['toolUse']['input']

                #tool_use_block = content_block['toolUse']
                #tool_use_name = tool_use_block['name']
                
                print(f"Using tool {tool_name}")
                
                if tool_name == 'doc_retrieval':
                    filter_query_s = tool_inputs['filter'] # filter query stored as a string instead of dictionary
                    filter_query = json.loads(filter_query_s)
                    retrieved_info = doc_retrieval(filter_query) #retrieved info type, dictionary
                    
                    if type(retrieved_info) == list:
                        retrieved_info = {item['_id']:item for item in retrieved_info}

                        
                    
                    print(type(retrieved_info))
                    
                    tool_response = {
                                        "role": "user",
                                        "content": [
                                            {
                                                "toolResult": {
                                                    "toolUseId": tool_id,
                                                    "content": [
                                                        {
                                                            "json": retrieved_info
                                                            }
                                                    ],
                                                    'status':'success'
                                                }
                                            }
                                        ]
                                    }
                    
                    messages.append(tool_response)
                    
                    converse_api_params = {
                                                "modelId": model_id,
                                                "messages": messages,
                                                "inferenceConfig": inference_config,
                                                "toolConfig": toolConfig 
                                            }

                    final_response = bedrock.converse(**converse_api_params) 
                    final_response_text = final_response['output']['message']['content'][0]['text']
                    return(final_response_text)
                    
                    #eturn messages
                    
                    #return retrieved_info
                    
                #return messages
                
        
        #return response_message
        #return messages

        
    except ClientError as err:
        message = err.response['Error']['Message']
        print(f"A client error occured: {message}")

In [30]:
user_question = "Give me an example query. What's the query used? Can you infer how the experiment was carried out? " 

In [31]:
prompt_template = f"""
    I will provide you with a schema that contains information about the accepted inputs of variable names in a JSON file.
    The schema is provided in a specified format and each file corresponds to a different section of an experiment.
    List of schemas: {schema_types}
    
    The Metadata schema shows how the different schema types are arranged, and how to appropriatel access them. 
    For example, in order to access something within the procedures field, you will have to start the query with "procedures."
    Metadata schema: {metadata_schema}
    
    I provide you with a sample, filled out metadata schema. It may contain missing information but serves as a reference to what a metadata file looks like. 
    You can use it as a guide to better structure your queries. Word like "false" and "null" has to be in quotes.
    Sample metadata: {sample_metadata}
    
    Your task is to read the user's question, which will adhere to certain guidelines or formats. 
    You maybe prompted to create a NOSQL MongoDB query that parses through a document structured like the sample metadata.
    You maybe prompted to determine missing information in the sample metadata.
    You maybe prompted to retrieve information from an external database, the information will be stored in json files. 
    
    Here are some examples:
    Input: Give me the query to find subject's whose breeding group is Chat-IRES-Cre_Jax006410
    Output: "subject.breeding_info.breeding_group": "Chat-IRES-Cre_Jax006410"
    
    Note: Provide the query in curly brackets, appropirately place quotation marks.
    Do not call tool use if not specified. 
    If there are instructions provided after document retrieval, apply the instructions on the returned output (retrieved document).
    
    Along with the answer, tell me a step by step process of what you are reasoning in tags. 
    Including how the query was formulated.
    
    Input: {user_question}
    """

In [32]:
get_completion(prompt_template)

{'ResponseMetadata': {'RequestId': '3ff2ea54-072e-4c00-8a4e-dddec4353d79', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Wed, 21 Aug 2024 21:57:12 GMT', 'content-type': 'application/json', 'content-length': '1676', 'connection': 'keep-alive', 'x-amzn-requestid': '3ff2ea54-072e-4c00-8a4e-dddec4353d79'}, 'RetryAttempts': 1}, 'output': {'message': {'role': 'assistant', 'content': [{'text': '<reasoning>\nStep 1: Understand the request - You have asked for an example query and to infer how the experiment was carried out based on the query.\nStep 2: Analyze the metadata schema - The metadata schema contains various fields related to the experiment, such as subject information, data description, procedures, session details, rig setup, processing steps, acquisition details, and instrument information.\nStep 3: Formulate an example query - Let\'s query for the subject\'s genotype and species.\n</reasoning>\n\nExample query:\n{\n  "subject.genotype": "Slc17a6-IRES-Cre/wt",\n  "subject.species.

## Using tool to formulate a response

In [37]:
messages_list = []

In [38]:
def get_completion(prompt, system_prompt=None, prefill=None):
    inference_config = {
        "temperature": 0,
        "maxTokens": 200
    }
    converse_api_params = {
        "modelId": model_id,
        "messages": [{"role": "user", "content": [{"text": prompt}]}],
        "inferenceConfig": inference_config,
        "toolConfig": toolConfig
    }
    
    message_list.append(f"User: {prompt}\n")
    
    if system_prompt:
        converse_api_params["system"] = [{"text": system_prompt}]
    if prefill:
        converse_api_params["messages"].append({"role": "assistant", "content": [{"text": prefill}]})
        print(prefill)
    try:
        response = bedrock.converse(**converse_api_params)
        
        print("Stop Reason:", response['stopReason'])
        
        response_message = response['output']['message']
        print(response_message)
        
        response_content_blocks = response_message['content']
        for content_block in response_content_blocks:
            if 'toolUse' in content_block:
                tool_use_block = content_block['toolUse']
                tool_use_name = tool_use_block['name']
                
                print(f"Using tool {tool_use_name}")
                
                if tool_use_name == 'docdb_retrieval':
                    filter_query_s = tool_use_block['input']['filter'] # filter query stored as a string instead of dictionary
                    filter_query = json.loads(filter_query_s)
                    retrieved_info = docdb_retrieval(filter_query)
                    message_list.append(f"Assistant: {retrieved_info}\n")
                    return retrieved_info
        return response_message

    except ClientError as err:
        message = err.response['Error']['Message']
        print(f"A client error occured: {message}")

In [39]:
tool_response = {
    "role": "user",
    "content": [
        {
            "toolResult": {
                "toolUseId": tool_id,
                "content": [
                    {
                        "text": wiki_result
                        }
                ]
            }
        }
    ]
}

NameError: name 'tool_id' is not defined

### Implementing tool use for procedure queries only

In [486]:
prompt_template = f"""
    I will provide you with a schema that contains information about the accepted inputs of variable names in a JSON file.
    The schema is provided in a specified format and each file corresponds to a different section of an experiment.
    Schema: {metadata_schema}
    
    I provide you with a sample, filled out metadata schema. It contains the fields as provided in the metadata schema, and may contain missing information. 
    Sample metadata: {sample_metadata}
    
    Your task is to read the user's question, which will adhere to certain guidelines or formats. The questions will only be about the subject section in the sample metadata.
    You maybe prompted to create a NOSQL MongoDB query that parses through a document structured like the sample metadata.
    You maybe prompted to determine missing information in the sample metadata.
    You maybe prompted to retrieve information from an external database. 
    
    Here are some examples:
    Input: Give me the query to find subject's whose breeding group is Chat-IRES-Cre_Jax006410
    Output: "subject.breeding_info.breeding_group": "Chat-IRES-Cre_Jax006410"
    
    Note: You have to just return the query, nothing else. Provide the query in curly brackets, appropirately place quotation marks. 
    Do not call tool use if not specified.
    
    Input: {user_question}
    """

In [487]:
user_question = "I want all assets whose name includes the string smartspim"

In [488]:
print(get_completion(prompt_template))

A client error occured: The model returned the following errors: Input is too long for requested model.
None


In [531]:
type(docdb_retrieval({"subject.genotype": "Gad2-IRES-Cre/Gad2-IRES-Cre"}))

dict