# LLM Testing

Here we test the LLM and check that it works as expected with the Ollama setup.

Before running this notebook, install Ollama and the LLaMA 3.1 model (`llama3.1:8b`) as explained in `setup_instruction.md`.

### First we do a simple call

This checks that the LLM is correctly installed and that we can query it from Python.

In [3]:
# imports
import ollama
import json
from datetime import date

In [4]:
def simple_call(prompt):
    """
    Function to test a simple call to the LLM model using Ollama.

    Args:
        prompt (str): The prompt to send to the LLM.

    Returns:
        str: The response from the LLM.
    """
    response = ollama.chat(
        model="llama3.1:8b",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        options={
            "temperature": 0.2
        }
    )

    return response["message"]["content"]


In [5]:
# simple test to check if ollama and LLM calls work
prompt = "What color is a strawberry?"
response = simple_call(prompt)
print("Prompt:", prompt)
print("LLM Response:", response)

Prompt: What color is a strawberry?
LLM Response: A strawberry is typically red in color, although some varieties may have a slightly yellow or white tint to them. The exact shade of red can vary depending on the ripeness and type of strawberry!
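`simple_call` needs a running Ollama server, so it cannot be checked in isolation. A minimal sketch, assuming we are free to pass the chat function in as a parameter: `simple_call_with`, its `chat_fn` argument and the `fake_chat` stub are hypothetical names, not part of Ollama. The stub only mimics the `response["message"]["content"]` shape that `simple_call` reads, so the message construction can be exercised without the model.

```python
from typing import Any, Callable


def simple_call_with(prompt: str, chat_fn: Callable[..., Any], model: str = "llama3.1:8b") -> str:
    """Same request as simple_call, but the chat function is injected (hypothetical helper)."""
    response = chat_fn(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        options={"temperature": 0.2},
    )
    # read the reply text the same way simple_call does
    return response["message"]["content"]


def fake_chat(model: str, messages: list, options: dict) -> dict:
    """Stub with the same response shape as an Ollama chat reply; it echoes the user prompt."""
    last_user_message = messages[-1]["content"]
    return {"message": {"role": "assistant", "content": f"[{model}] {last_user_message}"}}


print(simple_call_with("What color is a strawberry?", fake_chat))
```

With the server running, `simple_call_with(prompt, ollama.chat)` should send the same request as `simple_call(prompt)`.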


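### Next we implement the Q&A and chat capabilities

The user describes the analysis they want in free text. The LLM extracts a structured request (time frame, dates, area and analysis type) as JSON, and the request is validated before it is accepted; if validation fails, the user is asked again.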
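The JSON structure requested in the prompt and the allowed lists in `validate_analysis_request` describe the same six fields. A minimal sketch, assuming Python 3.9+, of writing that schema down once as a `TypedDict` with `Literal` types; the names `AnalysisRequest` and `allowed_values` are hypothetical. The allowed values can then be read back at runtime with `typing.get_args`.

```python
from typing import Literal, TypedDict, get_args, get_type_hints


# Shape of the JSON object the LLM is asked to return (hypothetical type name)
class AnalysisRequest(TypedDict):
    unclear: Literal["yes", "no"]
    time_frame: Literal["day", "longer"]
    date_from: str  # YYYY-MM-DD
    date_to: str  # YYYY-MM-DD
    area: Literal["ch", "de", "it", "at", "np", "no", "dk"]
    analysis_type: Literal["spot", "intraday"]


def allowed_values(field: str) -> tuple:
    """Return the allowed values of a Literal field, or () for free-form fields like dates."""
    return get_args(get_type_hints(AnalysisRequest)[field])


print(allowed_values("area"))
print(allowed_values("date_from"))
```

The validation checks could then read their allowed lists from `allowed_values(...)`, so the type and the checks cannot drift apart.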

In [None]:
# functions I use for the Q&A and chat capabilities

def validate_analysis_request(analysis_request: dict) -> tuple[bool, str]:
    """
    Validate that the analysis request has the correct format and allowed values.

    Args:
        analysis_request: Dictionary with analysis request fields

    Returns:
        tuple: (is_valid: bool, error_message: str)
    """
    import re
    from datetime import datetime

    # The parsed LLM output can be any JSON value, so make sure it is an object
    if not isinstance(analysis_request, dict):
        return False, f"Expected a JSON object, got {type(analysis_request).__name__}"

    # Check required fields exist and are strings (the LLM may return null or numbers)
    required_fields = ["unclear", "time_frame", "date_from", "date_to", "area", "analysis_type"]
    for field in required_fields:
        if field not in analysis_request:
            return False, f"Missing required field: {field}"
        if not isinstance(analysis_request[field], str):
            return False, f"Invalid '{field}' value: expected a string, got {analysis_request[field]!r}"

    # Validate 'unclear' field
    if analysis_request["unclear"].lower() not in ["yes", "no"]:
        return False, f"Invalid 'unclear' value: must be 'yes' or 'no', got '{analysis_request['unclear']}'"

    # Validate 'time_frame' field
    if analysis_request["time_frame"].lower() not in ["day", "longer"]:
        return False, f"Invalid 'time_frame' value: must be 'day' or 'longer', got '{analysis_request['time_frame']}'"

    # Validate date format (YYYY-MM-DD); fullmatch rejects any trailing characters
    date_pattern = r'\d{4}-\d{2}-\d{2}'
    for date_field in ["date_from", "date_to"]:
        if not re.fullmatch(date_pattern, analysis_request[date_field]):
            return False, f"Invalid date format for '{date_field}': expected YYYY-MM-DD, got '{analysis_request[date_field]}'"

    # Validate 'area' field
    allowed_areas = ["ch", "de", "it", "at", "np", "no", "dk"]
    if analysis_request["area"].lower() not in allowed_areas:
        return False, f"Invalid 'area' value: must be one of {allowed_areas}, got '{analysis_request['area']}'"

    # Validate 'analysis_type' field
    allowed_types = ["spot", "intraday"]
    if analysis_request["analysis_type"].lower() not in allowed_types:
        return False, f"Invalid 'analysis_type' value: must be one of {allowed_types}, got '{analysis_request['analysis_type']}'"

    # Check date_from <= date_to (strptime also rejects impossible dates such as 2023-02-30)
    try:
        date_from_obj = datetime.strptime(analysis_request["date_from"], "%Y-%m-%d")
        date_to_obj = datetime.strptime(analysis_request["date_to"], "%Y-%m-%d")
        if date_from_obj > date_to_obj:
            return False, f"date_from ({analysis_request['date_from']}) must be <= date_to ({analysis_request['date_to']})"
    except ValueError as e:
        return False, f"Invalid date value: {e}"

    return True, ""

def get_analysis_request():
    """
    Ask the user for an analysis request and use the LLM to extract it into a
    structured dictionary, asking again until an acceptable request is obtained.

    Returns:
        dict: A dictionary containing the analysis request details.
    """

    analysis_request_accepted = False

    while not analysis_request_accepted:
        # ask the user to input the request
        print("Write your analysis request below with the day/time frame, area and analysis type (spot/intraday):")
        user_input = input()
        print(user_input)

        analysis_prompt = f"""
            You are an expert energy data analyst. A user has provided you with the following input:

            \"\"\"{user_input}\"\"\"

            Please extract the following information from the user's input:
            1. Whether the user is interested in a specific day or a longer time period.
            2. The specific day or time period mentioned, in the format YYYY-MM-DD. If a specific day is mentioned, set date_from and date_to to that day. If a longer time period is mentioned, set date_from and date_to to its start and end dates.
            3. The country or region (area) the user is interested in, from the following list: (ch, de, it, at, np, no, dk)
            4. The type of analysis the user wants to perform, from the following list: (spot, intraday)

            If any information is missing or unclear, set "unclear" to "yes"; otherwise set it to "no".

            Provide the extracted information in a JSON format with the following structure:
            {{
                "unclear": "<yes/no>",
                "time_frame": "<day/longer>",
                "date_from": "<YYYY-MM-DD>",
                "date_to": "<YYYY-MM-DD>",
                "area": "<country_code>",
                "analysis_type": "<spot/intraday>"
            }}
            """



        response = ollama.chat(
            model="llama3.1:8b",
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant, which only outputs JSON in the specified format."
                },
                {
                    "role": "user",
                    "content": analysis_prompt
                }
            ],
            options={
                "temperature": 0.2
            }
        )

        # Parse the JSON response; if it is not valid JSON, fall back to an empty
        # dict so that the validation below rejects it and the user is asked again
        try:
            analysis_request = json.loads(response["message"]["content"])
        except json.JSONDecodeError:
            analysis_request = {}
        
        # check the validity of the analysis request
        is_valid, error_message = validate_analysis_request(analysis_request)
        print("Validation result:", is_valid, error_message)
        # accept the request only if it is valid and the LLM did not flag it as unclear
        if is_valid and analysis_request["unclear"].lower() == "no":
            analysis_request_accepted = True
        else:
            # If not valid or unclear, inform the user and loop again
            #print(f"Analysis request not valid: {error_message}. Please try again.")
            print("Could not retrieve all the necessary information from your input.")

    return analysis_request
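`json.loads` raises `JSONDecodeError` whenever the model wraps its answer in a Markdown fence or adds a sentence around the object, even if the object itself is fine. A minimal sketch of a more tolerant parser, assuming the object we want starts at one of the `{` characters in the reply; the helper name `extract_json_object` is hypothetical. It relies on the standard library's `json.JSONDecoder.raw_decode`, which parses one JSON value starting at a given index and ignores whatever follows it.

```python
import json
from typing import Optional


def extract_json_object(text: str) -> Optional[dict]:
    """Return the first JSON object that parses in an LLM reply, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores trailing text
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, dict):
            return value
        # no valid object starts here, so try the next opening brace
        start = text.find("{", start + 1)
    return None


reply = 'Here is the result:\n```json\n{"area": "ch", "analysis_type": "spot"}\n```'
print(extract_json_object(reply))
```

Inside `get_analysis_request`, `extract_json_object(response["message"]["content"]) or {}` could take the place of the `json.loads` call; an empty dict is still rejected by `validate_analysis_request`.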

In [None]:
analysis_request = get_analysis_request()
print("Final analysis request:", analysis_request)

Write your analysis request below with the day/time frame, area and analysis type (spot/intraday):
GIve me the analysis for the 10th of January
Validation result: False Missing required field: unclear
Analysis request not valid: Missing required field: unclear. Please try again.
Could not retrieve all the necessary information from your input.
Write your analysis request below with the day/time frame, area and analysis type (spot/intraday):
Give me the spot analysis for the 10th of January in the swiss market
Validation result: True 
Final analysis request: {'unclear': 'no', 'time_frame': 'day', 'date_from': '2023-01-10', 'date_to': '2023-01-10', 'area': 'ch', 'analysis_type': 'spot'}
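In this run "the 10th of January" became `2023-01-10`, but the prompt never tells the model today's date, so the year is the model's guess. A minimal sketch of passing a reference date into the prompt; the helper `with_reference_date` and the "most recent past date" rule are assumptions for illustration, not part of the original setup.

```python
from datetime import date
from typing import Optional


def with_reference_date(prompt: str, today: Optional[date] = None) -> str:
    """Prepend a reference date so the LLM can resolve dates given without a year (hypothetical helper)."""
    if today is None:
        today = date.today()
    # the default-year rule below is an assumed convention; adjust it to the use case
    return (
        f"Today's date is {today.isoformat()}. "
        "If the user gives a date without a year, assume the most recent such date "
        "that is not in the future.\n\n" + prompt
    )


print(with_reference_date("Give me the spot analysis for the 10th of January", date(2025, 3, 1)))
```

In `get_analysis_request` this would wrap `analysis_prompt` before it is sent to `ollama.chat`; the model may still ignore the hint, so the resulting dates are still worth checking.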
