In [None]:
# =============================
# Title: Training Intent Classifier
# =============================

# 1️⃣ Import Libraries
import os
import torch
import torch.nn as nn
import torch.optim as optim
from sentence_transformers import SentenceTransformer
from agents.intent_agent import IntentClassifier, IntentAgent, DEFAULT_INTENTS

# 2️⃣ Setup device and embedding model
embed_model_name = "all-MiniLM-L6-v2"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Sentence encoder used to turn queries into embeddings, placed on `device`.
embedder = SentenceTransformer(embed_model_name, device=device)

# 3️⃣ Define expanded training data
# Each intent has ~20 queries

# Training queries grouped by intent. Dict insertion order determines the
# order of `data`: fact, analysis, summary, visual (20 queries each).
_queries_by_intent = {
    "fact": (
        "What is the primary aim of this study?",
        "Who are the respondents in the primary research?",
        "Which frameworks guide the integration of AI in healthcare?",
        "What are the main research questions?",
        "What methodologies were used for data collection?",
        "Which hospitals were included in the study?",
        "What is the scope of Literature Review I?",
        "What is the scope of Literature Review II?",
        "What theoretical frameworks are used for hospital AI adoption?",
        "Which AI models are used for patient inflow prediction?",
        "Who oversees AI implementation in hospitals?",
        "Which countries were compared in the study?",
        "What are the inclusion criteria for respondents?",
        "Which departments are affected by AI resource allocation?",
        "What types of hospital data were used?",
        "Which predictive analytics methods were applied?",
        "What are the key variables in AI forecasting models?",
        "Which chronic diseases were analyzed for prediction?",
        "What kind of NLP models were referenced?",
        "What are the main challenges identified in LMIC adoption?",
    ),
    "analysis": (
        "Analyze how AI affects hospital workflow optimization.",
        "Evaluate the impact of AI-driven predictive analytics on patient outcomes.",
        "Compare traditional hospital inventory management with AI-based approaches.",
        "Examine the challenges of AI adoption in low- and middle-income countries.",
        "Assess how predictive analytics improve patient inflow prediction.",
        "Evaluate AI's effect on staff scheduling efficiency.",
        "Analyze how AI reduces resource wastage in hospitals.",
        "Compare reactive vs proactive hospital resource allocation.",
        "Examine ethical concerns in predictive analytics deployment.",
        "Evaluate human-AI collaboration in clinical decision-making.",
        "Analyze the effect of digital infrastructure on AI adoption.",
        "Compare LMIC and high-income country AI implementation.",
        "Evaluate AI's role in chronic disease management.",
        "Analyze workflow automation using AI in hospital operations.",
        "Examine the reliability of AI predictions in rural hospitals.",
        "Assess the impact of NLP on clinical data analysis.",
        "Compare AI forecasting accuracy across different hospital units.",
        "Analyze cost-benefit of AI-driven inventory management.",
        "Evaluate patient prioritization improvements due to AI.",
        "Assess policy and regulatory implications for AI adoption.",
    ),
    "summary": (
        "Summarize the methodology of the study.",
        "Provide a brief overview of Literature Review I.",
        "Give key findings about AI in preventive healthcare.",
        "Summarize the conclusions regarding human-AI collaboration.",
        "Provide a concise summary of AI applications in hospitals.",
        "Summarize challenges of AI adoption in LMICs.",
        "Provide an overview of predictive analytics frameworks.",
        "Summarize inventory management improvements using AI.",
        "Give key points of patient inflow prediction models.",
        "Summarize workflow optimization results with AI.",
        "Provide a summary of ethical concerns in AI adoption.",
        "Summarize findings on clinical decision support using AI.",
        "Summarize results of NLP-based clinical data extraction.",
        "Provide a summary of chronic disease predictive models.",
        "Summarize overall benefits of AI in hospitals.",
        "Provide key insights from hospital case studies.",
        "Summarize adoption challenges in Indian hospitals.",
        "Provide a summary of staff scheduling improvements.",
        "Summarize patient outcome enhancements using AI.",
        "Summarize main contributions of the study.",
    ),
    "visual": (
        "Show a diagram of the Predictive Analytics Framework.",
        "Provide a chart for patient inflow prediction.",
        "Display a table of AI applications in hospital inventory management.",
        "Illustrate workflow optimization in hospitals using AI.",
        "Show a graph of AI impact on patient outcomes.",
        "Provide a chart comparing traditional vs AI inventory management.",
        "Show visualization of staff scheduling improvements.",
        "Display diagram of human-AI collaboration.",
        "Provide graph of predictive analytics framework components.",
        "Show chart of chronic disease predictions using AI.",
        "Illustrate resource allocation using AI dashboards.",
        "Display visualization of NLP extraction results.",
        "Show AI adoption levels in hospitals.",
        "Provide visual summary of LMIC challenges.",
        "Show workflow automation diagram.",
        "Display chart of inventory stock optimization.",
        "Provide visual of patient prioritization improvements.",
        "Show graph of emergency department patient inflow.",
        "Display visualization of predictive model accuracy.",
        "Illustrate end-to-end AI hospital management system.",
    ),
}

# Flatten into (query, intent) pairs, keeping each intent's queries contiguous
# and in the order listed above.
data = [
    (query, intent)
    for intent, queries in _queries_by_intent.items()
    for query in queries
]

# Split the pairs into two parallel tuples: query strings and intent labels.
texts, labels = zip(*data)

# Map labels to indices
# Each intent's class index is its position in DEFAULT_INTENTS.
label_to_idx = dict(zip(DEFAULT_INTENTS, range(len(DEFAULT_INTENTS))))

# Integer class targets, one per training query, in the same order as `labels`.
# A label missing from DEFAULT_INTENTS raises KeyError here.
y = torch.tensor(
    list(map(label_to_idx.__getitem__, labels)),
    dtype=torch.long,
    device=device,
)

# 4️⃣ Compute embeddings
# encode(..., convert_to_tensor=True) already returns a torch.Tensor, so
# re-wrapping it in torch.tensor() only triggers PyTorch's copy-construct
# UserWarning. clone().detach() yields the same independent, graph-free copy,
# and .to() then pins the dtype and device explicitly.
X = (
    embedder.encode(list(texts), convert_to_tensor=True)
    .clone()
    .detach()
    .to(device=device, dtype=torch.float32)
)

# 5️⃣ Initialize model
# Classifier dimensions: the input width follows the embedding matrix
# (384 for all-MiniLM-L6-v2) and there is one output class per intent.
hidden_dim = 128
num_classes = len(DEFAULT_INTENTS)
input_dim = X.size(1)

model = IntentClassifier(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
)
model = model.to(device)

# 6️⃣ Training setup
epochs = 100
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 7️⃣ Training loop
for epoch in range(epochs):
    model.train()

    # Full-batch step: one forward pass over every sample, then one update.
    outputs = model(X)
    loss = criterion(outputs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report progress only on every 20th epoch.
    if (epoch + 1) % 20 != 0:
        continue

    # Accuracy is measured on the training batch itself, using the logits
    # from the forward pass made before this epoch's optimizer step.
    pred = outputs.argmax(dim=1)
    acc = pred.eq(y).float().mean().item()
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {loss.item():.4f}, Accuracy: {acc:.2f}")

# 8️⃣ Save trained model
# Smoke-test queries. Note that "Illustrate workflow optimization in hospitals
# using AI." and "What are the main research questions?" also appear verbatim
# in the training data, so they do not measure generalization.
test_queries = [
    "Can you summarize the financial report?",
    "Show the chart of yearly revenue",
    "Why did production cost increase?",
    "Where is the headquarters located?",
    "Illustrate workflow optimization in hospitals using AI.",
    "What are the main research questions?",
    "Provide a table of AI applications in inventory management."
]

# Persist the trained model before constructing the agent.
# NOTE(review): presumably IntentAgent() loads what save_trained_model wrote;
# confirm against agents.intent_agent.
IntentAgent.save_trained_model(model)

# ✅ 9️⃣ Test the trained classifier
intent_agent = IntentAgent()

# map() is lazy, so each query is still predicted right before it is printed.
for q, pred_intent in zip(test_queries, map(intent_agent.predict, test_queries)):
    print(f"Query: {q} -> Predicted Intent: {pred_intent}")


: 