﻿import os
import logging
from dotenv import load_dotenv
import chainlit as cl
from azure.identity import DefaultAzureCredential
from azure.ai.projects import AIProjectClient

try:
    from azure.ai.projects.models import MessageRole
except ImportError:
    # Fallback for SDK versions where MessageRole is not exported from
    # azure.ai.projects.models (presumably moved to azure-ai-agents). Only an
    # import failure triggers the fallback; other errors propagate.
    from azure.ai.agents.models import MessageRole

from cosmos_utils import save_message, get_user_history

# Load environment variables (python-dotenv; by default from a local .env file)
# before any configuration is read below.
load_dotenv()

# Quiet the Azure SDK's HTTP logging policy: only WARNING and above are emitted.
logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
logger.setLevel(logging.WARNING)

# Azure AI Project configuration, read from the environment (may be None).
AIPROJECT_CONNECTION_STRING = os.getenv("AIPROJECT_CONNECTION_STRING")
AGENT_ID = os.getenv("AGENT_ID")

if not AIPROJECT_CONNECTION_STRING or not AGENT_ID:
    # NOTE(review): this only warns, but a missing connection string is passed
    # unchanged to from_connection_string() below, which presumably fails at
    # import time rather than at the first agent call -- confirm, and consider
    # failing fast with a clear configuration error instead.
    print("WARNING: AIPROJECT_CONNECTION_STRING or AGENT_ID not set. Agent calls will fail.")

# Single module-level client used by every chat handler in this module;
# authenticates through DefaultAzureCredential.
credential = DefaultAzureCredential()
project_client = AIProjectClient.from_connection_string(
    conn_str=AIPROJECT_CONNECTION_STRING,
    credential=credential
)

# OAuth Callback
@cl.oauth_callback
def oauth_callback(provider_id: str, token: str, raw_user_data: dict, default_user: cl.User):
    """Admit any user the OAuth provider has authenticated.

    No additional authorization is performed: ``token`` and ``raw_user_data``
    are ignored and the user object supplied by Chainlit is returned as-is,
    after printing a debug line with the provider id and user identifier.
    """
    identifier = default_user.identifier
    print("DEBUG OAUTH CALLBACK:", provider_id, identifier)
    # NOTE(review): every authenticated Azure AD user is let in -- add a
    # domain/group check here if access should be restricted.
    return default_user

# Chat start
@cl.on_chat_start
async def on_chat_start():
    """Initialise a chat session for the logged-in user.

    Stores the user's identifier (or ``"anonymous"``) in the session under
    ``"user_id"``, creates an agent thread if the session has none, and then
    replays saved history (``get_user_history(..., limit=10)``) as messages.

    The thread is created before history is replayed so that a failure while
    loading history cannot leave the session without a thread, which would
    make every later message fail. Blocking SDK/storage calls run through
    ``cl.make_async`` so they do not stall the event loop for other users.
    """
    user = cl.user_session.get("user")
    user_email = user.identifier if user else "anonymous"
    print(f"DEBUG: Logged-in user = {user_email}")
    cl.user_session.set("user_id", user_email)

    # New thread if not exists
    if not cl.user_session.get("thread_id"):
        thread = await cl.make_async(project_client.agents.create_thread)()
        cl.user_session.set("thread_id", thread.id)
        print(f"New Thread ID for {user_email}: {thread.id}")

    # Replay history on a best-effort basis: a storage failure should not
    # prevent the user from chatting.
    try:
        history = await cl.make_async(get_user_history)(user_email, limit=10)
    except Exception as e:
        print(f"WARNING: could not load history for {user_email}: {e}")
        history = []

    # NOTE(review): replayed history is only displayed; it is not posted to the
    # agent thread, so the agent itself has no context from earlier sessions.
    for msg in history:
        await cl.Message(author=msg["role"], content=msg["content"]).send()

# Agent turn helper (synchronous; run it off the event loop)
def _run_agent_turn(thread_id, user_id, text):
    """Post *text* to the agent thread, run the agent, and return its reply.

    Performs blocking Azure AI Agents calls (``create_and_process_run`` waits
    for the run to finish), so async callers should wrap it in
    ``cl.make_async``.

    Raises:
        Exception: if the run ends in any status other than ``"completed"``
            (message is ``run.last_error`` when available), or if no agent
            text message is found on the thread.
    """
    project_client.agents.create_message(
        thread_id=thread_id,
        role="user",
        content=text,
    )

    run = project_client.agents.create_and_process_run(
        thread_id=thread_id,
        agent_id=AGENT_ID
    )
    status = getattr(run, "status", None)
    print(f"Run finished with status: {status} (user={user_id})")

    # Any non-completed terminal status (failed, cancelled, expired, ...) is an
    # error; otherwise the lookup below would return the previous turn's reply.
    if status != "completed":
        last_error = getattr(run, "last_error", None)
        if last_error:
            print(f"Agent error: {last_error}")
            raise Exception(str(last_error))
        raise Exception(f"Agent run ended with status: {status}")

    # NOTE(review): this takes the last agent text message on the whole thread;
    # if a completed run produced no text, it could still be an older reply --
    # consider filtering messages by run id.
    messages = project_client.agents.list_messages(thread_id)
    last_msg = messages.get_last_text_message_by_role(MessageRole.AGENT)
    if not last_msg:
        raise Exception("No response from the model.")
    return last_msg.text.value


# On message
@cl.on_message
async def on_message(message: cl.Message):
    """Forward a user message to the agent and display the agent's reply.

    Saves the user message, shows a "thinking..." placeholder, runs one agent
    turn on the session's thread (creating the thread if the session has none),
    then saves the reply and writes it into the placeholder. Any error is
    shown to the user as ``"Error: <message>"``, replacing the placeholder
    when one has already been sent.
    """
    user_id = cl.user_session.get("user_id", "anonymous")
    thread_id = cl.user_session.get("thread_id")
    thinking = None  # the placeholder message, once it has been sent

    try:
        # Save user message (blocking storage call, offloaded).
        await cl.make_async(save_message)(user_id, "user", message.content)

        # Show thinking placeholder
        thinking = await cl.Message("thinking...", author="agent").send()

        # Recover if on_chat_start never stored a thread (e.g. it failed).
        if not thread_id:
            thread = await cl.make_async(project_client.agents.create_thread)()
            thread_id = thread.id
            cl.user_session.set("thread_id", thread_id)

        # Run the whole agent turn in a worker thread so the polling inside
        # create_and_process_run does not block other users' sessions.
        reply_text = await cl.make_async(_run_agent_turn)(
            thread_id, user_id, message.content
        )

        # Save and display
        await cl.make_async(save_message)(user_id, "agent", reply_text)
        thinking.content = reply_text
        await thinking.update()

    except Exception as e:
        error_text = f"Error: {str(e)}"
        if thinking is not None:
            # Replace the dangling "thinking..." placeholder with the error.
            thinking.content = error_text
            await thinking.update()
        else:
            await cl.Message(content=error_text).send()


if __name__ == "__main__":
    # Intentionally a no-op when run directly: the handlers above are driven by
    # the Chainlit runtime (presumably launched via `chainlit run <this file>`).
    ...
