# Lab 10 - Agent

In [None]:
require 'json'
require 'sqlite3'
require 'time'
require 'openai'
require 'narray'
require 'dotenv'
require 'date'

Dotenv.load

# One OpenAI client shared by every completion helper in this notebook. The
# key comes from the environment, which Dotenv.load has just populated.
openai_api_key = ENV['OPENAI_API_KEY']
$client = OpenAI::Client.new(access_token: openai_api_key)

## Our database schema, using SQLite

In [None]:
# Connect to SQLite, starting every run from a fresh database file.
db_filename = 'event_planning.db'

begin
  # Core File.delete: FileUtils is never required in this file, so FileUtils.rm
  # raised a NameError that was silently swallowed and the old DB survived.
  if File.exist?(db_filename)
    File.delete(db_filename)
    puts "File '#{db_filename}' deleted successfully."
  end
rescue SystemCallError => e
  # Best-effort cleanup: keep going, but say why the stale file is still there.
  puts "Could not delete '#{db_filename}': #{e.message}"
end

# SQLite3::Database.new creates the file when it does not exist.
$db = SQLite3::Database.new(db_filename)
# Every query in this notebook reads rows positionally (row[0], row[1], ...),
# so keep array-shaped rows instead of enabling results_as_hash.
$db.results_as_hash = false

# Creates the `events` and `attendees` tables if they do not exist yet.
#
# @param db [#execute] the connection to initialize. This parameter was
#   previously ignored in favor of the global $db; it is now honored.
# @return the result of the last CREATE TABLE statement
def initialize_database(db)
  db.execute(<<~SQL)
    CREATE TABLE IF NOT EXISTS events (
      id INTEGER PRIMARY KEY AUTOINCREMENT,
      location TEXT NOT NULL,
      date TEXT NOT NULL
    )
  SQL

  # Each attendee row belongs to exactly one event.
  db.execute(<<~SQL)
    CREATE TABLE IF NOT EXISTS attendees (
      id INTEGER PRIMARY KEY AUTOINCREMENT,
      event_id INTEGER NOT NULL,
      full_name TEXT NOT NULL,
      FOREIGN KEY (event_id) REFERENCES events (id)
    )
  SQL
end

# Create the schema on the freshly opened global connection.
$db.then { |connection| initialize_database(connection) }

## Calculates the cost of an OpenAI request

In [None]:
# Estimates the USD cost of one OpenAI request from its token usage.
#
# Accepts either the plain Hash that ruby-openai returns (string keys, e.g.
# response["usage"]["prompt_tokens"]) or an object exposing #model/#usage
# readers. Calling response.model on a ruby-openai Hash raised NoMethodError.
#
# @param response [Hash, #model] chat completion response
# @return [Float] total cost in USD
# @raise [RuntimeError] when the model has no pricing entry
def cost(response)
  # USD per 1M tokens.
  prices_per_million = {
    "gpt-4" => { "input" => 30.00, "output" => 60.00 },        # model used by this notebook
    "gpt-4-0613" => { "input" => 30.00, "output" => 60.00 },
    "gpt-4-0314" => { "input" => 30.00, "output" => 60.00 },
    "gpt-4o" => { "input" => 5.00, "output" => 15.00 },
    "gpt-4o-2024-08-06" => { "input" => 2.50, "output" => 10.00 },
    "gpt-4o-2024-05-13" => { "input" => 5.00, "output" => 15.00 },
    "gpt-4o-mini" => { "input" => 0.15, "output" => 0.60 },
    "gpt-4o-mini-2024-07-18" => { "input" => 0.15, "output" => 0.60 },
    "gpt-3.5-turbo" => { "input" => 0.50, "output" => 1.50 },  # was 0.003/0.006 (a per-1K figure)
    "davinci-002" => { "input" => 12.00, "output" => 12.00 },
    "babbage-002" => { "input" => 1.60, "output" => 1.60 },
    "text-embedding-3-small" => { "input" => 0.020, "output" => 0.020 },
    "text-embedding-3-large" => { "input" => 0.130, "output" => 0.130 },
    "ada-v2" => { "input" => 0.100, "output" => 0.100 }
  }

  # Read a field from either an object with readers or a string-keyed Hash.
  read = ->(obj, key) { obj.respond_to?(key) ? obj.public_send(key) : obj[key.to_s] }

  model_version = read.call(response, :model)
  prices = prices_per_million[model_version]
  raise "Pricing information for model '#{model_version}' is not available." unless prices

  usage = read.call(response, :usage)
  prompt_tokens = read.call(usage, :prompt_tokens).to_f
  completion_tokens = read.call(usage, :completion_tokens).to_f

  input_cost = (prompt_tokens / 1_000_000) * prices["input"]
  output_cost = (completion_tokens / 1_000_000) * prices["output"]
  input_cost + output_cost
end

## Helpers to process structured OpenAI responses

In [None]:
# Pulls the JSON arguments of the forced function call out of a chat
# completion response.
#
# @param response [Hash] parsed chat completion response
# @return [Hash, nil] decoded arguments. Returns nil when the first choice has
#   no function_call, or when anything goes wrong; in that case the error and
#   its backtrace are printed.
def tool_response(response)
  first_message = response.dig("choices", 0, "message")
  call = first_message && first_message["function_call"]
  call ? JSON.parse(call["arguments"]) : nil
rescue StandardError => e
  puts "Error in tool_response: #{e}"
  puts e.backtrace
  nil
end


In [None]:
# Decodes the function-call arguments from a chain-of-thought completion.
#
# @param response [Hash] parsed chat completion response
# @return [Hash, nil] decoded arguments. Returns nil when there is no
#   function_call, or when parsing fails; in that case the error and the
#   raw response are printed.
def chain_of_thought_response(response)
  fn_call = response.dig("choices", 0, "message", "function_call")
  return nil unless fn_call

  JSON.parse(fn_call["arguments"])
rescue StandardError => e
  puts "Error parsing chain of thought response: #{e}"
  puts "Response: #{response.inspect}"
  nil
end


## OpenAI completion request methods

In [None]:
# Forces the model to call `function`, giving it room to reason step by step
# before it fills in the real arguments.
#
# A forced function call means the model's only output is the JSON arguments.
# The original schema (often additionalProperties: false) had nowhere to put
# reasoning, so the "think step-by-step" system prompt could not take effect.
# This method therefore:
# - prepends a required "reasoning" property to a COPY of the schema (the
#   caller's hash is not mutated);
# - prints the model's reasoning;
# - strips "reasoning" from the result, so callers see the same keys as on the
#   non-chain path.
#
# @param prompt [String] user prompt
# @param function [Hash] function definition with "name", "description", "parameters"
# @return [Hash, nil] decoded arguments without "reasoning". Returns nil when
#   the response has no function_call or its arguments cannot be parsed (the
#   error is printed).
def complete_chain_of_thought(prompt, function)
  name = function["name"]
  parameters = function["parameters"] || {}

  cot_parameters = parameters.merge(
    "properties" => {
      "reasoning" => {
        "type" => "string",
        "description" => "Step-by-step reasoning used to arrive at the other fields."
      }
    }.merge(parameters["properties"] || {}),
    "required" => ["reasoning"] | Array(parameters["required"])
  )

  response = $client.chat(
    parameters: {
      model: "gpt-4",
      messages: [
        {
          role: "system",
          content: "You are a helpful event planning assistant. Always analyze and think step-by-step before responding with the final answer."
        },
        {
          role: "user",
          content: prompt
        }
      ],
      # Deterministic extraction, matching the non-chain path in complete_function.
      temperature: 0,
      functions: [
        {
          name: name,
          description: function["description"],
          parameters: cot_parameters
        }
      ],
      function_call: { "name" => name }
    }
  )

  begin
    message = response.dig("choices", 0, "message")
    function_call = message && message["function_call"]
    return nil unless function_call

    result = JSON.parse(function_call["arguments"])
    reasoning = result.delete("reasoning") if result.is_a?(Hash)
    puts "Reasoning: #{reasoning}" if reasoning
    result
  rescue StandardError => e
    puts "Chain of thought error: #{e}"
    puts "Response was: #{response.inspect}"
    nil
  end
end


In [None]:
# Asks the model to call `function` for `prompt` and returns its arguments.
#
# When `chain` is true, the request is delegated to complete_chain_of_thought.
# Otherwise a deterministic (temperature 0) forced function call is made, the
# raw response is pretty-printed, and tool_response decodes the arguments.
#
# @param prompt [String] user prompt
# @param function [Hash] function definition (string keys, "name" required)
# @param chain [Boolean] use the chain-of-thought path
# @return [Hash, nil] decoded function arguments, or nil on failure
def complete_function(prompt, function, chain: false)
  return complete_chain_of_thought(prompt, function) if chain

  conversation = [
    { role: "system", content: "You are a helpful event planning assistant." },
    { role: "user", content: prompt }
  ]

  response = $client.chat(
    parameters: {
      model: "gpt-4",
      messages: conversation,
      temperature: 0,
      functions: [function],
      function_call: { name: function["name"] }
    }
  )

  puts JSON.pretty_generate(response)
  tool_response(response)
end

## Task Classification

In [None]:
# Builds the prompt used to route a natural-language query to a task.
#
# @param query [String] raw user query
# @return [String] markdown-formatted instruction prompt ending with the query
def get_classification_prompt(query)
  [
    "# Instructions",
    "",
    "For an event management SaaS product, natural language queries need to be classified and routed to the appropriate branch. Classify the following user query:",
    "",
    query,
    ""
  ].join("\n")
end

In [None]:
# Function schema that makes the model choose one task name.
#
# @param methods_enum [Array<String>] allowed task names
# @return [Hash] OpenAI function definition with string keys
def get_classification_function(methods_enum)
  task_type_schema = { "type" => "string", "enum" => methods_enum }

  parameters = {
    "type" => "object",
    "properties" => { "task_type" => task_type_schema },
    "required" => %w[task_type],
    "additionalProperties" => false
  }

  {
    "name" => "task_type",
    "description" => "Classify the task based on the user query.",
    "parameters" => parameters
  }
end

In [None]:
# Classifies `query` into one of `methods_enum` using a forced function call.
#
# @param methods_enum [Array<String>] allowed task names
# @param query [String] raw user query
# @return [String, nil] the chosen task name, or nil when classification failed
def classify(methods_enum, query)
  classification = complete_function(
    get_classification_prompt(query),
    get_classification_function(methods_enum)
  )

  puts "Debug - Raw classification: #{classification.inspect}"

  task = classification && classification['task_type']
  unless task
    puts "Classification failed"
    return nil
  end

  puts "#{query} => #{task}"
  task
end

In [None]:
# Parses a YYYY-MM-DD value into a Date.
#
# Model output can be JSON null or an empty string. Date.strptime raised an
# unrescued TypeError on nil, so the input is coerced with #to_s first.
#
# @param datestr [String, nil] date text, expected as YYYY-MM-DD
# @return [Date, nil] the parsed date, or nil when the input is nil, blank, or
#   not a valid date
def date_from_string(datestr)
  Date.strptime(datestr.to_s, '%Y-%m-%d')
rescue ArgumentError
  # Date::Error (newer date library) subclasses ArgumentError; older Rubies
  # raise ArgumentError directly, so rescuing the parent covers both.
  nil
end

## Create Event

In [None]:
# Builds the prompt that extracts a venue and date for a new event.
#
# Today's date is embedded so the model can resolve year-less dates into the future.
#
# @param query [String] raw user query
# @return [String] markdown-formatted prompt
def get_create_event_prompt(query)
  today = Time.now.strftime('%B %d, %Y')
  task = [
    "Today's date is #{today}.",
    "For an event management SaaS product, extract the location and the date from the following user query.",
    "For dates that don't specify a year, always choose a date in the future:"
  ].join(" ")

  <<~PROMPT
    # Instructions

    #{task}

    ## User Query

    #{query}
  PROMPT
end

# Function schema for extracting a new event's location and ISO date.
#
# @return [Hash] OpenAI function definition with string keys
def get_create_event_function
  location_schema = { "type" => "string" }
  date_schema = { "type" => "string", "description" => "The date formatted in YYYY-MM-DD" }

  {
    "name" => "event_details",
    "description" => "Extract the location and date for an event.",
    "parameters" => {
      "type" => "object",
      "properties" => { "location" => location_schema, "date" => date_schema },
      "required" => %w[location date],
      "additionalProperties" => false
    }
  }
end

# Inserts a new event row into the `events` table.
#
# @param location [String] venue name; must be non-blank
# @param date [Date, String, nil] event date. Date-like values are formatted
#   as YYYY-MM-DD; anything else is stored via #to_s.
# @return [nil] the outcome is reported on stdout
def create_event_sql(location, date)
  date_str = date.respond_to?(:strftime) ? date.strftime('%Y-%m-%d') : date.to_s

  # A nil date (e.g. an unparseable model answer) used to become "" and slip
  # past the NOT NULL constraint, creating an undated event.
  if location.to_s.strip.empty? || date_str.strip.empty?
    puts "Error creating event: a location and a date are required."
    return nil
  end

  $db.execute(
    'INSERT INTO events (location, date) VALUES (?, ?)',
    [location, date_str]
  )
  puts "Event '#{location}' created successfully."
rescue SQLite3::Exception => e
  puts "Error creating event: #{e}"
end

# Extracts a venue and date from `query` and stores the event.
#
# `chain` used to be positional only. `run` calls every handler with
# `chain: …`, which Ruby 3 turned into a positional Hash — always truthy — so
# the chain-of-thought path was taken regardless of the flag. The keyword is
# now honored. The positional form still works; it supplies the keyword's
# default.
#
# @param query [String] raw user query
# @param legacy_chain [Boolean] positional form of `chain`, kept for old callers
# @param chain [Boolean] use the chain-of-thought completion path
# @return [nil]
def create_event(query, legacy_chain = false, chain: legacy_chain)
  args = complete_function(
    get_create_event_prompt(query),
    get_create_event_function,
    chain: chain
  )

  puts args
  return unless args

  date = date_from_string(args["date"])
  unless date
    puts "Could not parse a YYYY-MM-DD date from #{args['date'].inspect}; event not created."
    return
  end

  create_event_sql(args["location"], date)
end

In [None]:
#create_event("Book the Ellison Lodge A for December 4th", chain: true)

## List Events

In [None]:
# Builds the prompt that extracts an optional date range for listing events.
#
# @param query [String] raw user query
# @return [String] markdown-formatted prompt that includes today's date
def get_list_events_prompt(query)
  today = Time.now.strftime('%B %d, %Y')
  task = [
    "Today's date is #{today}.",
    "For an event management SaaS product, a user is asking to list upcoming events for which an optional date range can be provided.",
    "For dates that don't specify a year, always choose a date in the future.",
    "Extract the date range from the following user query:"
  ].join(" ")

  <<~PROMPT
    # Instructions

    #{task}

    #{query}
  PROMPT
end

# Function schema for extracting an (optional) date range for listing events.
#
# Both bounds stay "required" so the model always emits the keys. The prompt
# describes the range as optional, though, and previously the model had to
# invent dates. The descriptions now ask for an empty string when a bound is
# absent; date_from_string turns "" into nil, so that side is not filtered.
#
# @return [Hash] OpenAI function definition with string keys
def get_list_events_function
  {
    "name" => "events_query",
    "description" => "Filter parameters to query an events SQL table based on dates.",
    "parameters" => {
      "type" => "object",
      "properties" => {
        "after_date" => {
          "type" => "string",
          "description" => "The earliest date (inclusive) formatted in YYYY-MM-DD, or an empty string if the user gave no start date"
        },
        "before_date" => {
          "type" => "string",
          "description" => "The latest date (inclusive) formatted in YYYY-MM-DD, or an empty string if the user gave no end date"
        }
      },
      "required" => ["after_date", "before_date"],
      "additionalProperties" => false
    }
  }
end

# Queries events, optionally bounded by inclusive ISO dates.
#
# Dates are compared as YYYY-MM-DD text, which sorts chronologically.
#
# @param after_date [Date, String, nil] inclusive lower bound; nil means unbounded
# @param before_date [Date, String, nil] inclusive upper bound; nil means unbounded
# @param output [Boolean] print the rows, or the "no events" notice. When
#   false the method is silent; the "no events" notice used to print anyway.
# @return [Array, nil] the rows returned by $db.execute, or nil when none match
def list_events_sql(after_date: nil, before_date: nil, output: true)
  to_sql_date = ->(d) { d.respond_to?(:strftime) ? d.strftime('%Y-%m-%d') : d }

  # Keep only the bounds that were actually given.
  bounds = [
    ['date >= ?', to_sql_date.call(after_date)],
    ['date <= ?', to_sql_date.call(before_date)]
  ].select { |_, value| value }

  query = 'SELECT id, location, date FROM events'
  query += " WHERE #{bounds.map(&:first).join(' AND ')}" unless bounds.empty?
  events = $db.execute(query, bounds.map(&:last))

  if events.empty?
    puts "No events found for the specified date range." if output
    return nil
  end

  events.each { |ev| puts "ID: #{ev[0]}, Location: #{ev[1]}, Date: #{ev[2]}" } if output
  events
end

# Extracts a date range from `query` and lists matching events.
#
# @param query [String] raw user query
# @param chain [Boolean] use the chain-of-thought completion path
# @return [Array, nil] rows from list_events_sql, or nil when extraction failed
def list_events(query, chain: false)
  args = complete_function(get_list_events_prompt(query), get_list_events_function, chain: chain)
  return unless args

  list_events_sql(
    after_date: date_from_string(args["after_date"]),
    before_date: date_from_string(args["before_date"])
  )
end

In [None]:
# Exercise the list-events branch directly with a vague query, in chain-of-thought mode.
sample_query = "highland"
list_events(sample_query, chain: true)

## Create Attendee

In [None]:
# Builds the enum of event labels offered to the model, each formatted as
# "<location> on <date>", followed by an "other" fallback option.
#
# list_events_sql returns nil for an empty table, and mapping over it used to
# crash. That made the callers' "Please create an event first" branch
# unreachable.
#
# @return [Array<String>] event labels plus "other", or [] when there are no
#   events, so `events_enum&.any?` guards in callers behave as intended
def get_events_enum
  events = list_events_sql(output: false)
  return [] if events.nil? || events.empty?

  events.map { |ev| "#{ev[1]} on #{ev[2]}" } << "other"
end

# Maps an enum label ("<location> on <date>") back to its event id.
#
# @param event_str [String] label chosen by the model
# @return [Integer, nil] the matching event id. Returns nil for the "other"
#   fallback, for an unknown label, or when there are no events (previously a
#   crash on nil rows).
def get_event_id(event_str)
  # "other" is the enum's escape hatch: the model found no matching event.
  return nil if event_str == "other"

  match = Array(list_events_sql(output: false)).find { |ev| "#{ev[1]} on #{ev[2]}" == event_str }
  match && match[0]
end

In [None]:
# Builds the prompt that extracts the target event and attendee names.
#
# @param query [String] raw user query
# @return [String] markdown-formatted prompt that includes today's date
def get_create_attendee_prompt(query)
  today = Time.now.strftime('%B %d, %Y')
  task = [
    "Today's date is #{today}.",
    "For an event management SaaS product, extract the event and the attendee name from the following user query:"
  ].join(" ")

  <<~PROMPT
    # Instructions

    #{task}

    #{query}
  PROMPT
end

# Function schema for picking an existing event and listing attendee names.
#
# @param events_enum [Array<String>] allowed event labels (including "other")
# @return [Hash] OpenAI function definition with string keys
def get_create_attendee_function(events_enum)
  event_property = { "type" => "string", "enum" => events_enum }
  attendees_property = { "type" => "array", "items" => { "type" => "string" } }

  {
    "name" => "attendee_rsvp_details",
    "description" => "Extract the event details and the attendee names.",
    "parameters" => {
      "type" => "object",
      "properties" => { "event" => event_property, "attendees" => attendees_property },
      "required" => %w[event attendees],
      "additionalProperties" => false
    }
  }
end

# Inserts one attendee for an existing event.
#
# @param event_id [Integer] id of the event the attendee is joining
# @param full_name [String] attendee name
# @return [nil] the outcome is reported on stdout
def create_attendee_sql(event_id, full_name)
  insert = 'INSERT INTO attendees (event_id, full_name) VALUES (?, ?)'
  $db.execute(insert, [event_id, full_name])
  puts "Attendee '#{full_name}' added to event ID #{event_id}."
rescue SQLite3::Exception => e
  puts "Error adding attendee: #{e}"
end

# Adds the attendees named in `query` to the existing event it refers to.
#
# The request now goes through complete_function, as create_event and
# list_events do. Previously this method duplicated the request, so
# `chain: true` only swapped the response parser and never used the
# chain-of-thought prompt.
#
# @param query [String] raw user query
# @param chain [Boolean] use the chain-of-thought completion path
# @return [nil] results are reported on stdout
def create_attendee(query, chain: false)
  events_enum = get_events_enum
  unless events_enum&.any?
    puts "[create_attendee] No events found! Please create an event first"
    return
  end

  args = complete_function(
    get_create_attendee_prompt(query),
    get_create_attendee_function(events_enum),
    chain: chain
  )
  return unless args

  event_id = get_event_id(args["event"])
  unless event_id
    puts "[create_attendee] Could not match an existing event for: #{query}"
    return
  end

  # Array() tolerates a missing or null "attendees" value from the model.
  Array(args["attendees"]).each { |name| create_attendee_sql(event_id, name) }
  nil
end

## List Attendees

In [None]:
# Builds the prompt that identifies which event the user is asking about.
#
# @param query [String] raw user query
# @return [String] markdown-formatted prompt that includes today's date
def get_list_attendees_prompt(query)
  today = Time.now.strftime('%B %d, %Y')
  task = [
    "Today's date is #{today}.",
    "For an event management SaaS product, extract the event from the following user query:"
  ].join(" ")

  <<~PROMPT
    # Instructions

    #{task}

    #{query}
  PROMPT
end

# Function schema that makes the model pick one known event label.
#
# @param events_enum [Array<String>] allowed event labels (including "other")
# @return [Hash] OpenAI function definition with string keys
def get_list_attendees_function(events_enum)
  parameters = {
    "type" => "object",
    "properties" => { "event" => { "type" => "string", "enum" => events_enum } },
    "required" => %w[event],
    "additionalProperties" => false
  }

  { "name" => "event_query", "description" => "Identify the event details", "parameters" => parameters }
end

# Prints and returns the attendees registered for one event.
#
# @param event_id [Integer] event to look up
# @return [Array, nil] rows from $db.execute (read positionally as id,
#   full_name), or nil when the event has no attendees
def list_attendees_sql(event_id)
  select_sql = 'SELECT id, full_name FROM attendees WHERE event_id = ?'
  rows = $db.execute(select_sql, [event_id])

  if rows.none?
    puts "No attendees found for event ID #{event_id}."
    return nil
  end

  rows.each { |row| puts "ID: #{row[0]}, Name: #{row[1]}" }
  rows
end

# Lists the attendees of the existing event that `query` refers to.
#
# Like create_event and list_events, this goes through complete_function, so
# `chain: true` really uses the chain-of-thought path. Previously it only
# swapped the response parser.
#
# @param query [String] raw user query
# @param chain [Boolean] use the chain-of-thought completion path
# @return [Array, nil] attendee rows from list_attendees_sql, or nil when no
#   event could be identified or there are no events
def list_attendees(query, chain: false)
  events_enum = get_events_enum
  unless events_enum&.any?
    puts "[list_attendees] No events found! Please create an event first"
    return
  end

  args = complete_function(
    get_list_attendees_prompt(query),
    get_list_attendees_function(events_enum),
    chain: chain
  )
  return unless args

  event_id = get_event_id(args["event"])
  if event_id
    list_attendees_sql(event_id)
  else
    puts "[list_attendees] Could not match an existing event for: #{query}"
  end
end

# Main Agent Loop

1. Accept a user query
2. Classify the query's task (create event, list events, etc.)
3. Call the appropriate task method
4. Repeat for the next query

In [None]:
# Handler for queries the classifier routed to "other": just echo them.
#
# `chain` is unused; it is accepted so every handler in `run` shares one
# call signature.
def fallback(query, chain: false)
  notice = "Fallback! #{query}"
  puts notice
end

# Routes one natural-language query to the matching task handler.
#
# The query is classified into one of the handler names, with "other" as the
# fallback. The selected handler is then called with the same `chain` flag.
#
# @param query [String] raw user query
# @param chain [Boolean] forwarded to the handler as `chain:`
# @return whatever the handler returns, or nil when no handler matched
def run(query, chain: true)
  handlers = {
    "create_event" => method(:create_event),
    "list_events" => method(:list_events),
    "create_attendee" => method(:create_attendee),
    "list_attendees" => method(:list_attendees),
    "other" => method(:fallback)
  }

  task = classify(handlers.keys, query)
  handler = task && handlers[task]
  return handler.call(query, chain: chain) if handler

  puts "Error: Could not find appropriate method for #{query}"
end

In [None]:
# Scripted walkthrough: create events, RSVP attendees, then query them.
# `.last` keeps the final run's result as the cell's displayed value.
[
  "Book the Ellison Lodge A for December 4th",
  "Jane Doe is coming to the party on 12/4",
  "Pencil in highland park gazebo for june 11th",
  "Leah, Alice, and Fred are all coming to ellison in december",
  "Who is coming in december?",
  "What events are coming up?"
].map { |demo_query| run(demo_query) }.last

In [None]:
run("Steve's coming to highland park!", chain: true)

In [None]:
run("highland list", chain: true)

In [None]:
run("highland attendees", chain: true)

In [None]:
list_events_sql(after_date: nil, before_date: nil, output: false)