In [None]:
!pip install langchain langchain-openai

In [10]:
import dotenv
dotenv.load_dotenv()

True

In [None]:
import os
import json
from typing import Dict, Any
import requests
from langchain.tools import BaseTool
from langchain_core.utils.function_calling import format_tool_to_openai_function
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

class AmadeusFlightOffersTool(BaseTool):
    """Tool to retrieve flight offers from Amadeus API."""
    
    name: str = "amadeus_flight_offers"
    description: str = "Retrieves flight offers from the Amadeus API."
    
    @classmethod
    def extract_parameters_with_llm(
        cls, 
        query: str, 
        function_spec: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Extract API parameters using LLM based on function specification.
        """
        # Prepare a detailed description of parameters
        parameters_details = "\n\t".join([
            f"- {param}: {cls._get_parameter_description(param)} "
            f"(Type: {details.get('type', 'unknown')}, "
            f"Required: {'Yes' if param in function_spec['parameters'].get('required', []) else 'No'})"
            for param, details in function_spec['parameters']['properties'].items()
        ])

        # Create the system prompt
        system_prompt = f"""
        You are an expert at extracting structured parameters for an API flight search function.

        API Function Specification Details:
        Name: {function_spec.get('name', 'Unknown')}
        Description: {function_spec.get('description', 'No description')}

        Parameters:
        {parameters_details}

        Extraction Guidelines:
        1. Carefully analyze the user query to extract values for each parameter
        2. Match parameters exactly as specified in the function specification
        3. Use exact IATA codes for locations
        4. Use YYYY-MM-DD format for dates
        5. By default, assume 1 adult traveler and a maximum of 5 flight offers
        
        Output Instructions:
        - Return a valid JSON object with extracted parameters
        - Only include parameters you can confidently extract
        - Ensure type compatibility
        - If unsure about a parameter, do not include it
        """

        print(f"System Prompt:\n{system_prompt}")

        # Create the prompt template
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{query}")
        ])

        # Initialize the LLM
        llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        
        # Create the chain
        chain = prompt | llm
        
        # Generate the response
        response = chain.invoke({
            "query": query
        })
        
        # Extract and parse the parameters
        try:
            # Try to parse the LLM response as JSON
            extracted_params = json.loads(response.content)
            
            # Validate required parameters
            required_params = function_spec['parameters'].get('required', [])
            for param in required_params:
                if param not in extracted_params:
                    raise ValueError(f"Missing required parameter: {param}")
            
            return extracted_params
        
        except json.JSONDecodeError:
            # Fallback to minimal defaults
            # TODO RATTAN - Add more sophisticated fallbacks
            return {
                "adults": 1,
                "max": 5
            }
    
    @staticmethod
    def _get_parameter_description(param: str) -> str:
        """
        Provide detailed descriptions for each parameter.
        """
        descriptions = {
            'originLocationCode': 'City/airport IATA code from which the traveler will depart (e.g., BOS for Boston)',
            'destinationLocationCode': 'City/airport IATA code to which the traveler is going (e.g., PAR for Paris)',
            'departureDate': 'Date of departure in ISO 8601 YYYY-MM-DD format (e.g., 2017-12-25)',
            'adults': 'Number of adult travelers (age 12 or older)',
            'max': 'Maximum number of flight offers to return (must be >= 1, default 250)'
        }
        return descriptions.get(param, 'No description available')
    
    def _run(
        self, 
        originLocationCode: str, 
        destinationLocationCode: str, 
        departureDate: str, 
        adults: int, 
        max: int = 5
    ) -> Dict[str, Any]:
        """
        Execute the API call to retrieve flight offers.
        """
        # Retrieve API credentials from environment variables
        client_id = os.getenv('AMADEUS_CLIENT_ID')
        client_secret = os.getenv('AMADEUS_CLIENT_SECRET')
        
        if not client_id or not client_secret:
            raise ValueError("Amadeus API credentials not found. Set AMADEUS_CLIENT_ID and AMADEUS_CLIENT_SECRET.")
        
        # First, get an access token
        token_url = 'https://test.api.amadeus.com/v1/security/oauth2/token'
        token_data = {
            'grant_type': 'client_credentials',
            'client_id': client_id,
            'client_secret': client_secret
        }
        
        try:
            # Get access token
            token_response = requests.post(token_url, data=token_data)
            token_response.raise_for_status()
            access_token = token_response.json()['access_token']
            
            # Prepare API call
            api_url = 'https://test.api.amadeus.com/v2/shopping/flight-offers'
            headers = {
                'Authorization': f'Bearer {access_token}'
            }
            
            params = {
                'originLocationCode': originLocationCode,
                'destinationLocationCode': destinationLocationCode,
                'departureDate': departureDate,
                'adults': adults,
                'max': max
            }
            
            # Make the API call
            response = requests.get(api_url, headers=headers, params=params)
            response.raise_for_status()
            
            return response.json()
        
        except requests.RequestException as e:
            return {"error": str(e)}
    
    def _arun(self, *args, **kwargs):
        """
        Async run method - not implemented for this tool.
        """
        raise NotImplementedError("Async run is not supported for this tool.")

In [40]:
flight_tool = AmadeusFlightOffersTool()
    
openai_function = format_tool_to_openai_function(flight_tool)
    
query = "Book me a flight from New York to Delhi on 30th November 2024?"
    
params = AmadeusFlightOffersTool.extract_parameters_with_llm(
    query, 
    openai_function
)
    
print("Extracted Parameters:")
print(json.dumps(params, indent=2))
    
try:
    results = flight_tool._run(**params)
    print("\nFlight Offers:")
    print(json.dumps(results, indent=2))
except Exception as e:
    print(f"Error: {e}")

System Prompt:

        You are an expert at extracting structured parameters for an API flight search function.

        API Function Specification Details:
        Name: amadeus_flight_offers
        Description: Retrieves flight offers from the Amadeus API.

        Parameters:
        - originLocationCode: City/airport IATA code from which the traveler will depart (e.g., BOS for Boston) (Type: string, Required: Yes)
		- destinationLocationCode: City/airport IATA code to which the traveler is going (e.g., PAR for Paris) (Type: string, Required: Yes)
		- departureDate: Date of departure in ISO 8601 YYYY-MM-DD format (e.g., 2017-12-25) (Type: string, Required: Yes)
		- adults: Number of adult travelers (age 12 or older) (Type: integer, Required: Yes)
		- max: Maximum number of flight offers to return (must be >= 1, default 250) (Type: integer, Required: No)

        Extraction Guidelines:
        1. Carefully analyze the user query to extract values for each parameter
        2. Match