<a href="https://colab.research.google.com/github/sohaiba/travel_agent_assistant/blob/main/Travel_Agent_Asssistant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def reset_thread():
    client.beta.threads.delete(st.session_state["thread"].id)
    st.session_state = {}

In [61]:
!pip install streamlit
!pip install plotly
!pip install --upgrade openai



In [62]:
import os
import json
from openai import OpenAI
import time
from google.colab import userdata

# API keys are stored in Google Colab's Secret Manager
OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

client = OpenAI()

In [63]:
# Define the main function
def run_assistant():
  # Creating an assistant with specific instructions and tools
  assistant = client.beta.assistants.create(
      instructions="You are a helpful travel assistant that can write and execute code, and has access to a digital map to display information",
      model="gpt-3.5-turbo-1106",
      tools=[
          {
              "type": "function",
              "function": {
                  "name": "update_map",
                  "parameters": {
                      "type": "object",
                      "properties": {
                          "longitude": {"type": "number"},
                          "latitude": {"type": "number"},
                          "zoom": {"type": "integer"}
                      },
                      "required": ["longitude", "latitude", "zoom"]
                  }
              }
          },
          {
              "type": "function",
              "function": {
                  "name": "add_marker",
                  "parameters": {
                      "type": "object",
                      "properties": {
                          "longitude": {"type": "number"},
                          "latitude": {"type": "number"},
                          "label": { "type": "string"}
                      },
                  "required": ["longitude", "latitude", "label"]
                  }
              }
          }
      ]
  )
  return assistant.id

In [64]:
assistant_id = run_assistant()

In [65]:
# Define functions
def update_map(latitude, longitude, zoom):
    #Move the Plotly map to give coordinates
    st.session_state["map"] = {
        "latitude": latitude,
        "longitude": longitude,
        "zoom": zoom,
    }
    return "Map updated"

def add_marker(latitude, longitude, label):
    #Add markers on the Plotly map
    st.session_state["markers"] = {
        "latitude": latitude,
        "longitude": longitude,
        "label": label,
    }
    return "Marker added"

In [66]:
# Map available functions
available_functions = {
    "update_map": update_map,
    "add_marker": add_marker
}

In [67]:
import streamlit as st
#Store global variables in session and reuse through the app
def initialize_session_state():
    if st.session_state is None:
      st.session_state = {}

    #Initialize an empty conversation
    if "conversation" not in st.session_state:
      st.session_state["conversation"] = []

    #Initialize map coordinates
    if "map" not in st.session_state:
      st.session_state["map"] = {
        "latitude": 30.3753,
        "longitude": 69.3451,
        "zoom": 16,
    }

    #Prepare reference to OpenAI Objects
    if "assistant" not in st.session_state:
      st.session_state["assistant"] = client.beta.assistants.retrieve(assistant_id)
      st.session_state["thread"] = client.beta.threads.create()
      st.session_state["run"] = None

    if "markers" not in st.session_state:
      st.session_state["markers"] = None

In [68]:
initialize_session_state()
st.session_state

{'conversation': [('assistant',
   "In Paris, there are many iconic landmarks and attractions to visit, such as the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and Montmartre. You may also want to take a leisurely stroll along the Seine River or explore the charming neighborhoods and cafes. Let me know if you'd like more details about any of these suggestions or if there's anything specific you're interested in!"),
  ('user', 'Suggest what I should do in Paris?'),
  ('assistant',
   "We're all set to explore Paris! What would you like to do first?"),
  ('user', "Let's go to Paris")],
 'map': {'latitude': 48.8566, 'longitude': 2.3522, 'zoom': 12},
 'assistant': Assistant(id='asst_WQkoiHUNGIdzoIrLQcSovfOd', created_at=1704800855, description=None, file_ids=[], instructions='You are a helpful travel assistant that can write and execute code, and has access to a digital map to display information', metadata={}, model='gpt-3.5-turbo-1106', name=None, object='assistant', tools=[To

In [69]:
# Adding a user message to the thread
user_messsage = "Take me to Germany!"

def create_message(user_messsage):
  messages = client.beta.threads.messages.create(
        thread_id=st.session_state["thread"].id,
        role="user",
        content=user_messsage
  )
  return messages.id

message_id = create_message(user_messsage)

In [70]:
def create_run():
  st.session_state["run"] = client.beta.threads.runs.create(
      thread_id=st.session_state["thread"].id,
      assistant_id=st.session_state["assistant"].id
  )
create_run()

In [71]:
def poll_run_status():
  # Loop until the run completes or requires action
  while True:
    run = client.beta.threads.runs.retrieve(thread_id=st.session_state["thread"].id, run_id=st.session_state["run"].id)

    # Add run steps retrieval here
    run_steps = client.beta.threads.runs.steps.list(thread_id=st.session_state["thread"].id, run_id=st.session_state["run"].id)
    print("Run Steps:", run_steps)

    if run.status == "requires_action":
        tool_calls = run.required_action.submit_tool_outputs.tool_calls
        tool_outputs = []

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)

            if function_name in available_functions:
                function_to_call = available_functions[function_name]
                output = function_to_call(**function_args)
                tool_outputs.append({
                    "tool_call_id": tool_call.id,
                    "output": output,
                })

        # Submit tool outputs and update the run
        client.beta.threads.runs.submit_tool_outputs(
            thread_id=st.session_state["thread"].id,
            run_id=st.session_state["run"].id,
            tool_outputs=tool_outputs
        )

    elif run.status == "completed":
        # List the messages to get the response
        messages = client.beta.threads.messages.list(thread_id=st.session_state["thread"].id)
        for message in messages.data:
            role_label = "User" if message.role == "user" else "Assistant"
            message_content = message.content[0].text.value
            print(f"{role_label}: {message_content}\n")
        break  # Exit the loop after processing the completed run

    elif run.status == "failed":
        print("Run failed.")
        break

    elif run.status in ["in_progress", "queued"]:
        print(f"Run is {run.status}. Waiting...")
        time.sleep(0.1)  # Wait for 0.1 seconds before checking again

    else:
        print(f"Unexpected status: {run.status}")
        break

poll_run_status()

Run Steps: SyncCursorPage[RunStep](data=[RunStep(id='step_zepIv8AHwHNKEwiYWOzOPjVP', assistant_id='asst_WQkoiHUNGIdzoIrLQcSovfOd', cancelled_at=None, completed_at=None, created_at=1704810381, expired_at=None, failed_at=None, last_error=None, metadata=None, object='thread.run.step', run_id='run_p5c6XN2Ap08SBlkGACQjbonX', status='in_progress', step_details=ToolCallsStepDetails(tool_calls=[], type='tool_calls'), thread_id='thread_zLKR1CQrHa3oZ1rKaV8Dwfjp', type='tool_calls', expires_at=1704810979)], object='list', first_id='step_zepIv8AHwHNKEwiYWOzOPjVP', last_id='step_zepIv8AHwHNKEwiYWOzOPjVP', has_more=False)
Run is in_progress. Waiting...
Run Steps: SyncCursorPage[RunStep](data=[RunStep(id='step_zepIv8AHwHNKEwiYWOzOPjVP', assistant_id='asst_WQkoiHUNGIdzoIrLQcSovfOd', cancelled_at=None, completed_at=None, created_at=1704810381, expired_at=None, failed_at=None, last_error=None, metadata=None, object='thread.run.step', run_id='run_p5c6XN2Ap08SBlkGACQjbonX', status='in_progress', step_deta

In [72]:
def add_chat_msg():
  st.session_state["conversation"] = [
      (m.role, m.content[0].text.value) for m in client.beta.threads.messages.list(st.session_state["thread"].id).data
  ]

add_chat_msg()

In [79]:
st.session_state["conversation"]

[('assistant',
  "We've arrived in Germany! Is there a specific city or landmark you'd like to visit while we're here? Let me know how I can assist you further!"),
 ('user', 'Take me to Germany!'),
 ('assistant',
  "In Paris, there are many iconic landmarks and attractions to visit, such as the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and Montmartre. You may also want to take a leisurely stroll along the Seine River or explore the charming neighborhoods and cafes. Let me know if you'd like more details about any of these suggestions or if there's anything specific you're interested in!"),
 ('user', 'Suggest what I should do in Paris?'),
 ('assistant',
  "We're all set to explore Paris! What would you like to do first?"),
 ('user', "Let's go to Paris")]

In [59]:
#Create UI for Assistant
import plotly.graph_objects as go

initialize_session_state()

def on_text_input():
  #Push message to conversation
  message_id = create_message(st.session_state["input_user_msg"])
  create_run()
  poll_run_status()
  add_chat_msg()


st.title("Wanderlust")
left_col, right_col = st.columns(2)

with left_col:
  st.subheader("Conversation")
  for role, message in st.session_state["conversation"]:
    with st.chat_message(role):
      st.write(message)

with right_col:
  fig = go.Figure(go.Scattermapbox(mode="markers",))
  fig.update_layout(
    mapbox = dict(
        accesstoken = userdata.get('MAPBOX_TOKEN'),
        center = go.layout.mapbox.Center(
            lat = st.session_state["map"]["latitude"],
            lon = st.session_state["map"]["longitude"],
        ),
        zoom = st.session_state["map"]["zoom"],
    ),
    margin = dict(l=0, r=0, t=0, b=0),
  )
  if st.session_state["markers"] is not None:
    fig.add_trace(
        go.Scattermapbox(
            mode="markers",
            marker=go.scattermapbox.Marker(
                size=24,
                color="red",
            ),
            lat = st.session_state["markers"]["lat"],
            lon =st.session_state["markers"]["lon"],
            label = st.session_state["markers"]["label"],
        )
    )
  st.plotly_chart(fig, config={"displayModeBar": False}, use_container_width=True, key="plotly")

st.chat_input(
    placeholder = "Ask your question here",
    key = "input_user_msg",
    on_submit = on_text_input,
)