In [29]:
import boto3, json, os
from pathlib import Path
from config import toolConfig
from aind_data_access_api.document_db import MetadataDbClient
from botocore.exceptions import ClientError

In [30]:
client = boto3.client("bedrock-runtime", region_name="us-west-2")
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
bedrock = boto3.client(
    service_name="bedrock-runtime",
    region_name = 'us-west-2'
)

In [31]:
def get_completion(prompt, system_prompt=None, prefill=None):
    
    messages = [{"role": "user", "content": [{"text": prompt}]}]
    
    inference_config = {
        "temperature": 0,
        "maxTokens": 4000
    }
    toolConfig = {
                  "tools": [
                      {
                      "toolSpec": {
                          "name": "print_summary",
                          "description": "Print summary of JSON File.",
                          "inputSchema": {
                              "json": {
                                  "type": "object",
                                  "properties": {
                                      "filter": {
                                          "type": "string",
                                          "description": "A MongoDB query to pass to the function"
                                      }
                                  },
                                  "required": ["filter"]
                              }
                          }
                      }
                      }
                  ]
                }

    converse_api_params = {
        "modelId": model_id,
        "messages" : messages,
        "inferenceConfig": inference_config,
        "toolConfig": toolConfig
    }
    
    if system_prompt:
        converse_api_params["system"] = [{"text": system_prompt}]
    if prefill:
        messages.append({"role": "assistant", "content": [{"text": prefill}]})
        print(prefill)
    try:
        response = bedrock.converse(**converse_api_params)
        print(response)
        
        response_message = response['output']['message']
        
        response_content_blocks = response_message['content']
        
        #Assistant reply including tool use 
        messages.append({"role": "assistant", "content": response_content_blocks})
        
        for content_block in response_content_blocks:
            if 'toolUse' in content_block:
                print("Stop Reason:", response['stopReason'])
                
                tool_use = response_content_blocks[-1]
                tool_id = tool_use['toolUse']['toolUseId']
                tool_name = tool_use['toolUse']['name']
                tool_inputs = tool_use['toolUse']['input']

                #tool_use_block = content_block['toolUse']
                #tool_use_name = tool_use_block['name']
                
                print(f"Using tool {tool_name}")
                
                if tool_name == 'one_doc_retrieval':
                    filter_query_s = tool_inputs['filter'] # filter query stored as a string instead of dictionary
                    filter_query = json.loads(filter_query_s)
                    retrieved_info = one_doc_retrieval(filter_query) #retrieved info type, dictionary
                    
                    if type(retrieved_info) == list:
                        retrieved_info = {item['_id']:item for item in retrieved_info}

                        
                    
                    print(type(retrieved_info))
                    
                    tool_response = {
                                        "role": "user",
                                        "content": [
                                            {
                                                "toolResult": {
                                                    "toolUseId": tool_id,
                                                    "content": [
                                                        {
                                                            "json": retrieved_info
                                                            }
                                                    ],
                                                    'status':'success'
                                                }
                                            }
                                        ]
                                    }
                    
                    messages.append(tool_response)
                    
                    converse_api_params = {
                                                "modelId": model_id,
                                                "messages": messages,
                                                "inferenceConfig": inference_config,
                                                "toolConfig": toolConfig 
                                            }

                    final_response = bedrock.converse(**converse_api_params) 
                    print(final_response)
                    final_response_text = final_response['output']['message']['content'][0]['text']
                    print(final_response_text)
                    
                    #eturn messages
                    
                    #return retrieved_info
                    
                #return messages
                
        
        #return response_message
        #return messages

        
    except ClientError as err:
        message = err.response['Error']['Message']
        print(f"A client error occured: {message}")

In [19]:
f = open('procedures_schema (1).json')
procedures_schema = json.load(f)

In [20]:
f = open('../ref/subject_609281_metadata.json')
sample_metadata= json.load(f)

In [26]:
user_question = "What procedures occured in the sample metadata?"

In [27]:
prompt_template = f"""
    I will provide you with a schema that contains information about the accepted inputs of variable names in a JSON file.
    The schema is provided in a specified format and each file corresponds to a different section of an experiment.
    Procedure schema: {procedures_schema}
    
    I provide you with a sample, filled out metadata schema. It contains the fields as provided in the metadata schema, and may contain missing information. 
    Sample metadata: {sample_metadata}
    
    Your task is to read the user's question, which will adhere to certain guidelines or formats. 
    You maybe prompted to create a NOSQL MongoDB query that parses through a document structured like the sample metadata.
    You maybe prompted to determine missing information in the sample metadata.
    
    Here are some examples:
    Input: Give me the query to find subject's whose breeding group is Chat-IRES-Cre_Jax006410
    Output: "subject.breeding_info.breeding_group": "Chat-IRES-Cre_Jax006410"
    
    Note: You have to just return the query, nothing else. Provide the query in curly brackets, appropirately place quotation marks. Remove unnecessary symbols like slashes.
    
    Input: {user_question}
    """
PREFILL = '<query>'

In [28]:
print(get_completion(prompt_template))

{'ResponseMetadata': {'RequestId': '4099accc-a127-4948-989e-e7f1b84cd484', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Mon, 19 Aug 2024 22:25:21 GMT', 'content-type': 'application/json', 'content-length': '320', 'connection': 'keep-alive', 'x-amzn-requestid': '4099accc-a127-4948-989e-e7f1b84cd484'}, 'RetryAttempts': 0}, 'output': {'message': {'role': 'assistant', 'content': [{'text': "The sample metadata does not contain any information about procedures performed on the subject. The 'procedures' field is set to None."}]}}, 'stopReason': 'end_turn', 'usage': {'inputTokens': 1833, 'outputTokens': 29, 'totalTokens': 1862}, 'metrics': {'latencyMs': 1295}}
None
