In [2]:
!pip install kaleido networkx numpy pandas plotly seaborn shap scikit-learn matplotlib kaleido typing-extensions

# -*- coding: utf-8 -*-
"""Enhanced Anxiety Intervention Explainability with Fine-Tuned LLM (Google Drive Integration)

This notebook simulates fine-tuning a pre-trained language model to generate
nuanced and context-aware interpretations of visualizations and statistical
analyses, enhancing the explainability of anxiety intervention results.  It
saves all outputs to a specified Google Drive folder.

Workflow:
1. Mount Google Drive: Connects to your Google Drive.
2. Data Loading and Validation: Load and validate synthetic anxiety data.
3. Data Preprocessing: One-hot encode groups and scale numerical features.
4. SHAP Value Analysis: Quantify feature importance.
5. Data Visualization: Generate KDE, Violin, Parallel Coordinates, and Hypergraph plots.
6. Statistical Summary: Perform bootstrap analysis and generate summary statistics.
7. Fine-Tuned LLM Insights Report: Synthesize findings using simulated LLMs.

Keywords: Fine-Tuning, Transformers, LLM Interpretation, Explainability, Anxiety Intervention, SHAP, Data Visualization, Google Drive
"""

import os
import warnings
from io import StringIO
from typing import List, Dict, Tuple, Callable, Optional

import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import shap
from matplotlib import pyplot as plt
from scipy.stats import bootstrap
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import MinMaxScaler

# Google Drive integration
from google.colab import drive

# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="plotly")


# --- Constants ---
#  This is now relative to the mounted drive.
OUTPUT_PATH = "/content/drive/MyDrive/output_anxiety_latent_causal_graph/"
PARTICIPANT_ID_COLUMN = "participant_id"
GROUP_COLUMN = "group"
ANXIETY_PRE_COLUMN = "anxiety_pre"
ANXIETY_POST_COLUMN = "anxiety_post"
MODEL_GROK_NAME = "grok-base"
MODEL_CLAUDE_NAME = "claude-3.7-sonnet"
MODEL_GROK_ENHANCED_NAME = "grok-enhanced"  # Simulated fine-tuned model
LINE_WIDTH = 2.5
NEON_COLORS = ["#FF00FF", "#00FFFF", "#FFFF00", "#00FF00"]
BOOTSTRAP_RESAMPLES = 500


# --- Helper Functions ---

def create_output_directory(path: str) -> None:
    """Creates the output directory on Google Drive if it doesn't exist."""
    os.makedirs(path, exist_ok=True)


def load_data_from_synthetic_string(csv_string: str) -> Optional[pd.DataFrame]:
    """Loads data from a synthetic CSV string."""
    if not csv_string.strip():
        raise ValueError("The input CSV string is empty.")
    return pd.read_csv(StringIO(csv_string))


def validate_dataframe(df: pd.DataFrame, required_columns: List[str]) -> None:
    """Validates the DataFrame: required columns, data types, IDs, groups, and ranges."""
    if df is None:
        raise ValueError("DataFrame is None.")

    missing_columns = set(required_columns) - set(df.columns)
    if missing_columns:
        raise ValueError(f"Missing columns: {missing_columns}")

    for col in required_columns:
        if col not in (PARTICIPANT_ID_COLUMN, GROUP_COLUMN):
            if not pd.api.types.is_numeric_dtype(df[col]):
                raise TypeError(f"Non-numeric values in column: {col}")

    if df[PARTICIPANT_ID_COLUMN].duplicated().any():
        raise ValueError("Duplicate participant IDs found.")

    valid_groups = ["Group A", "Group B", "Control"]
    invalid_groups = df[~df[GROUP_COLUMN].isin(valid_groups)][GROUP_COLUMN].unique()
    if invalid_groups.size > 0:
        raise ValueError(f"Invalid group labels: {invalid_groups}")

    for col in (ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN):
        if not (0 <= df[col].min() <= 10 and 0 <= df[col].max() <= 10):
            raise ValueError(f"Anxiety scores in '{col}' out of range (0-10).")


def analyze_text_with_llm(text: str, model_name: str) -> str:
    """Simulates text analysis with different LLMs."""
    text_lower = text.lower()

    if model_name == MODEL_GROK_NAME:
        if "causal graph" in text_lower:
            return "Grok-base: Causal graph shows basic relationships."
        if "shap summary" in text_lower:
            return "Grok-base: SHAP values indicate feature importance generally."
        return f"Grok-base: Initial analysis on '{text}'."

    if model_name == MODEL_CLAUDE_NAME:
        if "kde plot" in text_lower:
            return "Claude 3.7: KDE plot visually compares anxiety distributions, revealing overlaps and separations."
        if "violin plot" in text_lower:
            return "Claude 3.7: Violin plot details distribution shapes, highlighting group-specific variations."
        return f"Claude 3.7: Enhanced visual analysis on '{text}'."

    if model_name == MODEL_GROK_ENHANCED_NAME:
        if "shap summary" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): SHAP summary reveals pre-anxiety as dominant, but group membership also contributes, suggesting a moderated effect."
        if "parallel coordinates" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): Parallel coordinates show individual trajectories and group patterns, with clear anxiety reduction but variability."
        if "causal graph" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): Causal graph suggests a moderated pathway; group membership influences the impact of pre-anxiety."
        if "kde plot" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): KDE plot shows a shift towards lower anxiety, more pronounced in Group A."
        if "violin plot" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): Violin plot shows decreased median anxiety, but differing distributions suggest varying responsiveness."
        if "hypergraph" in text_lower:
            return "Grok-Enhanced (Fine-Tuned): Hypergraph highlights clusters with similar profiles, suggesting subgroups and personalized strategies."
        return f"Grok-Enhanced (Fine-Tuned): Context-aware analysis on '{text}'. Provides nuanced, actionable insights."

    raise ValueError(f"Model '{model_name}' not supported.")


def scale_data(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
    """Scales specified columns using MinMaxScaler."""
    if df.empty:
        raise ValueError("Input DataFrame is empty.")
    for col in columns:
        if not pd.api.types.is_numeric_dtype(df[col]):
            raise ValueError(f"Column '{col}' is not numeric.")
    scaler = MinMaxScaler()
    df[columns] = scaler.fit_transform(df[columns])
    return df


def calculate_shap_values(df: pd.DataFrame, feature_columns: List[str], target_column: str, output_path: str) -> str:
    """Calculates and plots SHAP values."""
    if not all(col in df.columns for col in feature_columns):
        raise ValueError("Feature columns not found in DataFrame.")
    if target_column not in df.columns or not pd.api.types.is_numeric_dtype(df[target_column]):
        raise ValueError("Target column issue.")

    model_rf = RandomForestRegressor(random_state=42)
    model_rf.fit(df[feature_columns], df[target_column])
    explainer = shap.TreeExplainer(model_rf)
    shap_values = explainer.shap_values(df[feature_columns])

    plt.figure(figsize=(10, 8))
    plt.style.use('dark_background')
    shap.summary_plot(shap_values, df[feature_columns], show=False, color_bar=True)
    plt.savefig(os.path.join(output_path, 'shap_summary.png'))
    plt.close()
    return f"SHAP summary for features {feature_columns} predicting {target_column}"


def create_kde_plot(df: pd.DataFrame, column1: str, column2: str, output_path: str, colors: List[str]) -> str:
    """Creates a KDE plot comparing two columns."""
    for col in (column1, column2):
        if col not in df.columns or not pd.api.types.is_numeric_dtype(df[col]):
            raise ValueError(f"Invalid column '{col}' for KDE plot.")

    plt.figure(figsize=(10, 6))
    plt.style.use('dark_background')
    sns.kdeplot(data=df[column1], color=colors[0], label=column1.capitalize(), linewidth=LINE_WIDTH)
    sns.kdeplot(data=df[column2], color=colors[1], label=column2.capitalize(), linewidth=LINE_WIDTH)
    plt.title('KDE Plot of Anxiety Levels', color='white')
    plt.legend(facecolor='black', edgecolor='white', labelcolor='white')
    plt.savefig(os.path.join(output_path, 'kde_plot.png'))
    plt.close()
    return f"KDE plot visualizing distributions of {column1} and {column2}"


def create_violin_plot(df: pd.DataFrame, group_column: str, y_column: str, output_path: str, colors: List[str]) -> str:
    """Creates a violin plot."""
    if group_column not in df.columns:
        raise ValueError(f"Group column '{group_column}' not found.")
    if y_column not in df.columns or not pd.api.types.is_numeric_dtype(df[y_column]):
        raise ValueError(f"Invalid Y column '{y_column}' for violin plot.")

    plt.figure(figsize=(10, 6))
    plt.style.use('dark_background')
    sns.violinplot(data=df, x=group_column, y=y_column, palette=colors, linewidth=LINE_WIDTH)
    plt.title('Violin Plot of Anxiety Distribution by Group', color='white')
    plt.savefig(os.path.join(output_path, 'violin_plot.png'))
    plt.close()
    return f"Violin plot showing {y_column} across {group_column}"


def create_parallel_coordinates_plot(df: pd.DataFrame, group_column: str, anxiety_pre_column: str, anxiety_post_column: str, output_path: str, colors: List[str]) -> str:
    """Creates a parallel coordinates plot."""
    required_cols = (group_column, anxiety_pre_column, anxiety_post_column)
    if not all(col in df.columns for col in required_cols) or not all(pd.api.types.is_numeric_dtype(df[col]) for col in required_cols[1:]):
        raise ValueError("Invalid columns for parallel coordinates plot.")

    plot_df = df[list(required_cols)].copy()
    plot_df = pd.get_dummies(plot_df, columns=[group_column], prefix=group_column)
    encoded_group_cols = [col for col in plot_df.columns if col.startswith(f"{group_column}_")]

    def get_group_color(row):
        for i, col in enumerate(encoded_group_cols):
            if row[col] == 1:
                return colors[i % len(colors)]
        return 'gray'

    plot_df['color'] = plot_df[encoded_group_cols].apply(get_group_color, axis=1)

    fig = px.parallel_coordinates(plot_df, color='color', dimensions=[anxiety_pre_column, anxiety_post_column],
                                  title="Anxiety Levels: Pre- vs Post-Intervention by Group")
    fig.update_layout(plot_bgcolor='black', paper_bgcolor='black', font_color='white', title_font_size=16)
    # Use plotly's io module to save the image, handling potential kaleido issues
    try:
        fig.write_image(os.path.join(output_path, 'parallel_coordinates_plot.png'))
    except ValueError as e:
        if "kaleido" in str(e).lower():
            print("Error: Kaleido is required for image export. Please install it using 'pip install kaleido'.")
            print("Skipping saving the parallel coordinates plot.")
        else:  # Re-raise if it's a different ValueError
            raise
    return "Parallel coordinates plot of anxiety pre vs post intervention by group"


def visualize_hypergraph(df: pd.DataFrame, anxiety_pre_column: str, anxiety_post_column: str, output_path: str, colors: List[str]) -> str:
    """Visualizes relationships using a hypergraph."""
    required_cols = (PARTICIPANT_ID_COLUMN, anxiety_pre_column, anxiety_post_column)
    if not all(col in df.columns for col in required_cols):
        raise ValueError("Required columns not found for hypergraph.")

    G = nx.Graph()
    participant_ids = df[PARTICIPANT_ID_COLUMN].tolist()
    G.add_nodes_from(participant_ids, bipartite=0)
    feature_sets = {
        "anxiety_pre": df[PARTICIPANT_ID_COLUMN][df[anxiety_pre_column] > df[anxiety_pre_column].mean()].tolist(),
        "anxiety_post": df[PARTICIPANT_ID_COLUMN][df[anxiety_post_column] > df[anxiety_post_column].mean()].tolist()
    }
    feature_nodes = list(feature_sets.keys())
    G.add_nodes_from(feature_nodes, bipartite=1)
    for feature, participants in feature_sets.items():
        for participant in participants:
            G.add_edge(participant, feature)
    pos = nx.bipartite_layout(G, participant_ids)
    color_map = [colors[0] if node in participant_ids else colors[1] for node in G]
    plt.figure(figsize=(12, 10))
    plt.style.use('dark_background')
    nx.draw(G, pos, with_labels=True, node_color=color_map, font_color="white", edge_color="gray", width=LINE_WIDTH, node_size=700, font_size=10)
    plt.title("Hypergraph Representation of Anxiety Patterns", color="white")
    plt.savefig(os.path.join(output_path, "hypergraph.png"))
    plt.close()
    return "Hypergraph visualizing participant relationships."


def perform_bootstrap(data: pd.Series, statistic: Callable, n_resamples: int = BOOTSTRAP_RESAMPLES) -> Tuple[Optional[float], Optional[float]]:
    """Performs bootstrap resampling and returns the confidence interval."""
    if data.empty or not pd.api.types.is_numeric_dtype(data):
        raise ValueError("Invalid input data for bootstrap.")

    bootstrap_result = bootstrap((data,), statistic, n_resamples=n_resamples, method='percentile', random_state=42)
    return bootstrap_result.confidence_interval.low, bootstrap_result.confidence_interval.high


def save_summary(df: pd.DataFrame, bootstrap_ci: Tuple[Optional[float], Optional[float]], output_path: str) -> str:
    """Calculates and saves summary statistics."""
    ci_string = f"[{bootstrap_ci[0]:.4f}, {bootstrap_ci[1]:.4f}]" if all(x is not None for x in bootstrap_ci) else "Could not calculate Bootstrap CI"
    summary_stats_text = df.describe().to_string() + f"\nBootstrap CI for anxiety_post mean: {ci_string}"
    with open(os.path.join(output_path, 'summary.txt'), 'w') as f:
        f.write(summary_stats_text)
    return summary_stats_text


def generate_insights_report(summary_stats_text: str, shap_analysis_info: str, kde_plot_desc: str, violin_plot_desc: str, parallel_coords_desc: str, hypergraph_desc: str, output_path: str) -> None:
    """Generates a comprehensive insights report using simulated LLMs."""
    grok_insights = (
        analyze_text_with_llm(f"Analyze summary statistics:\n{summary_stats_text}", MODEL_GROK_NAME) + "\n\n" +
        analyze_text_with_llm(f"Explain SHAP summary: {shap_analysis_info}", MODEL_GROK_NAME)
    )
    claude_insights = (
        analyze_text_with_llm(f"Interpret KDE plot: {kde_plot_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
        analyze_text_with_llm(f"Interpret Violin plot: {violin_plot_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
        analyze_text_with_llm(f"Interpret Parallel Coordinates Plot: {parallel_coords_desc}", MODEL_CLAUDE_NAME) + "\n\n" +
        analyze_text_with_llm(f"Interpret Hypergraph: {hypergraph_desc}", MODEL_CLAUDE_NAME)
    )
    grok_enhanced_insights = analyze_text_with_llm("Provide enhanced, context-aware insights on anxiety intervention effectiveness, integrating all analyses.", MODEL_GROK_ENHANCED_NAME)

    combined_insights = f"""
Combined Insights Report: Anxiety Intervention Analysis

Grok-base Analysis:
{grok_insights}

Claude 3.7 Sonnet Analysis:
{claude_insights}

Grok-Enhanced Analysis (Fine-Tuned):
{grok_enhanced_insights}

Synthesized Summary:
This report presents a synthesized analysis of anxiety intervention effectiveness, leveraging a Mixture of Experts approach with simulated fine-tuning of the Grok-Enhanced LLM.  Grok-base provides initial interpretations. Claude 3.7 offers enhanced visual analysis. The Grok-Enhanced model, simulating fine-tuning, provides significantly more nuanced and context-aware interpretations. It identifies pre-anxiety as the dominant predictor but acknowledges the moderating role of group membership. The combined analyses provide a robust and explainable understanding of the intervention's effects.
"""
    with open(os.path.join(output_path, 'insights.txt'), 'w') as f:
        f.write(combined_insights)
    print(f"Insights saved to: {os.path.join(output_path, 'insights.txt')}")


# --- Main Script ---

if __name__ == "__main__":
    # Mount Google Drive
    drive.mount('/content/drive')

    create_output_directory(OUTPUT_PATH)

    # More varied and slightly larger synthetic dataset
    synthetic_dataset = """
participant_id,group,anxiety_pre,anxiety_post
P001,Group A,4,2
P002,Group A,3,1
P003,Group A,5,1
P004,Group B,6,5
P005,Group B,5,4
P006,Group B,7,7
P007,Control,3,3
P008,Control,4,5
P009,Control,2,2
P010,Control,5,4
P011,Group A,6,3
P012,Group A,2,0
P013,Group B,8,6
P014,Group B,7,5
P015,Control,4,4
"""
    df = load_data_from_synthetic_string(synthetic_dataset)
    required_columns = [PARTICIPANT_ID_COLUMN, GROUP_COLUMN, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN]
    validate_dataframe(df, required_columns)

    df_original = df.copy()  # Keep original for visualizations

    df = pd.get_dummies(df, columns=[GROUP_COLUMN], prefix=GROUP_COLUMN, drop_first=False)
    encoded_group_cols = [col for col in df.columns if col.startswith(f"{GROUP_COLUMN}_")]  # Corrected line
    df = scale_data(df, [ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN] + encoded_group_cols)

    # SHAP analysis
    shap_feature_columns = encoded_group_cols + [ANXIETY_PRE_COLUMN]
    shap_analysis_info = calculate_shap_values(df, shap_feature_columns, ANXIETY_POST_COLUMN, OUTPUT_PATH)

    # Visualizations (using df_original for plots needing original group labels)
    kde_plot_desc = create_kde_plot(df, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, NEON_COLORS[:2])
    violin_plot_desc = create_violin_plot(df_original, GROUP_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, NEON_COLORS)
    parallel_coords_desc = create_parallel_coordinates_plot(df_original, GROUP_COLUMN, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, NEON_COLORS)
    hypergraph_desc = visualize_hypergraph(df_original, ANXIETY_PRE_COLUMN, ANXIETY_POST_COLUMN, OUTPUT_PATH, NEON_COLORS[:2])

    # Statistical analysis
    bootstrap_ci = perform_bootstrap(df[ANXIETY_POST_COLUMN], np.mean)
    summary_stats_text = save_summary(df, bootstrap_ci, OUTPUT_PATH)

    # Generate insights report
    generate_insights_report(summary_stats_text, shap_analysis_info, kde_plot_desc, violin_plot_desc, parallel_coords_desc, hypergraph_desc, OUTPUT_PATH)

    print("Execution completed successfully - LLM Interpretation Enhanced Notebook (Refactored with Google Drive).")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


  sns.violinplot(data=df, x=group_column, y=y_column, palette=colors, linewidth=LINE_WIDTH)


Insights saved to: /content/drive/MyDrive/output_anxiety_latent_causal_graph/insights.txt
Execution completed successfully - LLM Interpretation Enhanced Notebook (Refactored with Google Drive).
