This notebook demonstrates how to build a Plotly dashboard, using an analysis of PISA educational indicators and outcomes as the running example (see https://www.kaggle.com/datasets/yummykaggle/pisa-school-level-indicators-and-outcomes).

In [None]:
!pip install plotly dash pandas numpy scipy

Collecting dash
  Downloading dash-3.2.0-py3-none-any.whl.metadata (10 kB)
Collecting retrying (from dash)
  Downloading retrying-1.4.2-py3-none-any.whl.metadata (5.5 kB)
Downloading dash-3.2.0-py3-none-any.whl (7.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading retrying-1.4.2-py3-none-any.whl (10 kB)
Installing collected packages: retrying, dash
Successfully installed dash-3.2.0 retrying-1.4.2


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.figure_factory as ff
import plotly.io as pio
from google.colab import auth
from google.cloud import bigquery
import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation/pandas warnings that may be useful;
# consider narrowing to specific categories once the noisy source is known.
warnings.filterwarnings('ignore')

In [None]:
# Identifiers of the GCP project, location and BigQuery dataset used below.
PROJECT_ID = "mgmt599-dn3-final-project"
REGION = "us-central1"
BQ_DATASET = "edu"

# Authenticate the Colab user, then open a BigQuery client bound to the project.
auth.authenticate_user()
client = bigquery.Client(project=PROJECT_ID)

### Data loading functions

In [None]:
def load_pisa_data(project_id, dataset, client, limit=None):
    """Load the PISA table and its codebook from BigQuery.

    Args:
        project_id: GCP project that owns the dataset.
        dataset: BigQuery dataset containing ``pisa_data`` and ``pisa_codebooks``.
        client: A ``google.cloud.bigquery.Client`` (anything exposing
            ``query(sql).to_dataframe()`` works).
        limit: Optional maximum number of PISA rows to fetch. ``None`` (or any
            other falsy value such as 0) fetches every row.

    Returns:
        Tuple ``(pisa_data, codebooks_data)`` of DataFrames. ``pisa_data`` only
        contains rows where ``math``, ``read`` and ``sci`` are all non-null;
        ``codebooks_data`` has the columns ``field_id`` and ``field_name``.

    Raises:
        ValueError: if ``limit`` is not interpretable as a non-negative integer.
    """
    print("📊 Loading PISA data...")

    # Only the optional LIMIT clause differs between a full and a sampled load.
    limit_clause = ""
    if limit:
        # Coerce to int so arbitrary text can never be spliced into the SQL.
        limit_value = int(limit)
        if limit_value < 0:
            raise ValueError(f"limit must be a non-negative integer, got {limit!r}")
        limit_clause = f"LIMIT {limit_value}"

    pisa_query = f"""
    SELECT *
    FROM `{project_id}.{dataset}.pisa_data`
    WHERE math IS NOT NULL AND read IS NOT NULL AND sci IS NOT NULL
    {limit_clause}
    """
    pisa_data = client.query(pisa_query).to_dataframe()

    # Load codebooks (field id -> human-readable description)
    codebooks_query = f"""
    SELECT field_id, field_name
    FROM `{project_id}.{dataset}.pisa_codebooks`
    """
    codebooks_data = client.query(codebooks_query).to_dataframe()

    print(f"✅ Loaded {len(pisa_data)} PISA records")
    print(f"✅ Loaded {len(codebooks_data)} codebook entries")

    return pisa_data, codebooks_data

def load_model_results(project_id, dataset, client, model_types=None, subjects=None):
    """Load the best hyperparameter-tuning trial of each BigQuery ML model.

    For every ``(model_type, subject)`` pair, the model
    ``{project_id}.{dataset}.{model_type}_{subject}`` is queried with
    ``ML.TRIAL_INFO``. The trial with the highest
    ``hparam_tuning_evaluation_metrics.r2_score`` is kept.

    Args:
        project_id: GCP project that owns the dataset.
        dataset: BigQuery dataset containing the models.
        client: A ``google.cloud.bigquery.Client`` (anything exposing
            ``query(sql).to_dataframe()`` works).
        model_types: Model name prefixes; defaults to
            ``['pisa_reg_lasso_model', 'pisa_rand_forest_model']``.
        subjects: Subject suffixes; defaults to ``['math', 'read', 'sci']``.

    Returns:
        Nested dict ``{model_type: {subject: DataFrame}}``. Each DataFrame either
        has one row (columns ``trial_id``, ``r2_score``, ``hyperparameters``,
        ``eval_loss``, ``training_loss``) or is empty when the model has no
        trials or could not be queried. Errors are printed rather than raised,
        so one missing model does not abort the whole load.
    """
    print("🤖 Loading model results...")

    if model_types is None:
        model_types = ['pisa_reg_lasso_model', 'pisa_rand_forest_model']
    if subjects is None:
        subjects = ['math', 'read', 'sci']

    model_results = {}

    for model_type in model_types:
        model_results[model_type] = {}

        for subject in subjects:
            model_name = f"{model_type}_{subject}"

            # Only the best trial is used, so let BigQuery return just that row.
            trial_query = f"""
            SELECT
                trial_id,
                hparam_tuning_evaluation_metrics.r2_score as r2_score,
                hyperparameters,
                eval_loss,
                training_loss
            FROM ML.TRIAL_INFO(MODEL `{project_id}.{dataset}.{model_name}`)
            ORDER BY hparam_tuning_evaluation_metrics.r2_score DESC
            LIMIT 1
            """

            try:
                trial_results = client.query(trial_query).to_dataframe()
                if not trial_results.empty:
                    best_trial = trial_results.iloc[0:1]  # keep as a 1-row DataFrame
                    model_results[model_type][subject] = best_trial
                    print(f"✅ Loaded best trial for {model_name} (R² = {best_trial['r2_score'].iloc[0]:.3f})")
                else:
                    model_results[model_type][subject] = pd.DataFrame()
                    print(f"⚠️ No trials found for {model_name}")
            except Exception as e:
                # Best effort: record an empty result and keep loading the others.
                print(f"❌ Error loading {model_name}: {e}")
                model_results[model_type][subject] = pd.DataFrame()

    return model_results

### Dashboard class

In [None]:
class PISADashboard:

    MISSING_DATA_COLOR = '#f87171'

    def __init__(self, project_id, dataset, client, pisa_data, codebooks_data, model_results=None, theme='plotly_white'):
        """
        Build a dashboard around data that has already been fetched.

        Args:
            project_id: GCP project hosting the BigQuery dataset.
            dataset: Name of the BigQuery dataset with the PISA tables/models.
            client: BigQuery client used for on-demand queries.
            pisa_data: DataFrame with the PISA records.
            codebooks_data: DataFrame mapping ``field_id`` to ``field_name``.
            model_results: Optional nested dict ``{model_type: {subject: DataFrame}}``
                of best-trial results; stored as an empty dict when omitted.
            theme: Name of the Plotly template to make the global default.
        """
        # Connection / configuration attributes.
        self.project_id, self.dataset, self.client = project_id, dataset, client
        self.theme = theme

        # Pre-loaded tables, keyed by a short name.
        self.data = dict(pisa=pisa_data, codebooks=codebooks_data)

        if model_results is None:
            model_results = {}
        self.model_results = model_results

        # Register the custom templates before one of them can become the default.
        self.setup_custom_templates()
        pio.templates.default = theme

        record_count = len(self.data['pisa'])
        print(f"🎯 Dashboard initialized with {record_count} PISA records")
        if self.model_results:
            print(f"🤖 Dashboard has model results for {len(self.model_results)} model types")

    def setup_custom_templates(self):
        """Register the 'academic', 'professional' and 'dark_modern' Plotly templates.

        Each template defines only layout defaults: fonts, background colours,
        colorway, axis styling, title placement and, for two of them, legend
        styling.
        """

        def axis_style(line_width, line_color, grid_color, **extra):
            # Every custom template draws mirrored axis lines plus a grid.
            return dict(
                showline=True, linewidth=line_width, linecolor=line_color,
                mirror=True, showgrid=True, gridcolor=grid_color,
                **extra
            )

        def centered_title(size, color, family):
            # Titles are horizontally centred in every custom template.
            return dict(font=dict(size=size, color=color, family=family), x=0.5, xanchor='center')

        layouts = {}

        # Light, report-style template with a horizontal legend above the plot.
        layouts["academic"] = go.Layout(
            font=dict(family="Arial, sans-serif", size=12, color="#2c3e50"),
            plot_bgcolor='#f8f9fa',
            paper_bgcolor='white',
            colorway=['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6', '#1abc9c'],
            xaxis=axis_style(1, '#bdc3c7', '#ecf0f1'),
            yaxis=axis_style(1, '#bdc3c7', '#ecf0f1'),
            title=centered_title(20, '#2c3e50', 'Arial Bold'),
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        )

        # Heavier axis lines and explicitly styled axis titles.
        layouts["professional"] = go.Layout(
            font=dict(family="Segoe UI, sans-serif", size=11, color="#1f2937"),
            plot_bgcolor='white',
            paper_bgcolor='#fafafa',
            colorway=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b'],
            xaxis=axis_style(2, '#374151', '#e5e7eb', title=dict(font=dict(size=12, color='#374151'))),
            yaxis=axis_style(2, '#374151', '#e5e7eb', title=dict(font=dict(size=12, color='#374151'))),
            title=centered_title(18, '#1f2937', 'Segoe UI Bold'),
        )

        # Dark backgrounds with light text.
        layouts["dark_modern"] = go.Layout(
            font=dict(family="Inter, sans-serif", size=12, color="#e5e7eb"),
            plot_bgcolor='#1f2937',
            paper_bgcolor='#111827',
            colorway=['#60a5fa', '#f87171', '#34d399', '#fbbf24', '#a78bfa', '#fb7185'],
            xaxis=axis_style(1, '#4b5563', '#374151'),
            yaxis=axis_style(1, '#4b5563', '#374151'),
            title=centered_title(20, '#f9fafb', 'Inter Bold'),
            legend=dict(font=dict(color='#e5e7eb')),
        )

        for template_name, layout in layouts.items():
            pio.templates[template_name] = go.layout.Template(layout=layout)

    def set_theme(self, theme):
        """Change the dashboard theme.

        Updates both the global Plotly default template and ``self.theme``.
        The figure builders pass ``self.theme`` explicitly as ``template=``.

        The global default is assigned first: if Plotly rejects ``theme`` (for
        example an unregistered template name), the exception propagates and
        ``self.theme`` keeps its previous value instead of becoming inconsistent
        with the global default.
        """
        pio.templates.default = theme
        self.theme = theme

    def add_custom_title(self, fig, title_text, title_color=None, bg_color=None):
        """Add a custom styled title to any figure based on the current theme."""

        title_color = title_color or "#ffe43c"
        bg_color = bg_color or "#2c3e50"
        border_color = "#34495e"

        # Add the text annotation on top of the rectangle
        fig.add_annotation(
            text=f"<b>{title_text}</b>",
            xref="paper", yref="paper",
            x=0, y=1.1,  # Align to the left
            xanchor="left", yanchor="middle",
            showarrow=False,
            font=dict(color=title_color, size=20, family="Arial Black"),
            # Remove bgcolor and border since we're using the shape for background
            bgcolor=None,
            bordercolor=None
        )

        # Adjust top margin for title
        fig.update_layout(margin=dict(t=80, b=40, l=40, r=40))

    def get_feature_importance(self, model_type, subject):
        """Get feature importance for a specific model."""
        if 'lasso' in model_type:
            # For linear regression, use ML.WEIGHTS
            query = f"""
            WITH best_trial AS (
              SELECT trial_id
              FROM ML.TRIAL_INFO(MODEL `{self.project_id}.{self.dataset}.{model_type}_{subject}`)
              ORDER BY hparam_tuning_evaluation_metrics.r2_score DESC
              LIMIT 1
            )
            SELECT
              w.processed_input as feature,
              w.weight,
              ABS(w.weight) as abs_weight
            FROM ML.WEIGHTS(MODEL `{self.project_id}.{self.dataset}.{model_type}_{subject}`) w
            JOIN best_trial b ON w.trial_id = b.trial_id
            WHERE w.processed_input IS NOT NULL
            ORDER BY ABS(w.weight) DESC
            LIMIT 20
            """
        else:
            # For random forest, use ML.FEATURE_IMPORTANCE
            query = f"""
            WITH best_trial AS (
              SELECT trial_id
              FROM ML.TRIAL_INFO(MODEL `{self.project_id}.{self.dataset}.{model_type}_{subject}`)
              ORDER BY hparam_tuning_evaluation_metrics.r2_score DESC
              LIMIT 1
            )
            SELECT
              fi.feature,
              fi.importance_weight as weight,
              fi.importance_gain
            FROM ML.FEATURE_IMPORTANCE(MODEL `{self.project_id}.{self.dataset}.{model_type}_{subject}`) fi
            JOIN best_trial b ON fi.trial_id = b.trial_id
            WHERE fi.feature IS NOT NULL
            ORDER BY fi.importance_weight DESC
            LIMIT 20
            """

        try:
            return self.client.query(query).to_dataframe()
        except Exception as e:
            print(f"Error getting feature importance: {e}")
            return pd.DataFrame()

    def create_data_overview(self):
        """Build a 2x2 overview figure of the PISA dataset.

        Panels:
            1. Overlaid histograms of the ``math``, ``read`` and ``sci`` scores.
            2. Correlation heatmap between the three scores (pandas ``corr()``).
            3. The (up to) 10 columns with the most missing values, largest at
               the top, with codebook descriptions in the hover text.
            4. Record counts for the 10 most frequent countries, shown only if
               a ``CNT`` column exists.

        Returns:
            plotly.graph_objects.Figure
        """
        pisa = self.data['pisa']
        subjects = ['math', 'read', 'sci']

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Educational Outcomes\' Distributions', 'Educational Outcomes\' Correlations',
                            'Missing Data Analysis', 'Sample Countries'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]],
            vertical_spacing=0.12,
            horizontal_spacing=0.1
        )

        # 1. Score distributions (overlaid via barmode below, hence the opacity)
        for subject in subjects:
            fig.add_trace(
                go.Histogram(
                    x=pisa[subject],
                    name=f'{subject.upper()} Scores',
                    opacity=0.7,
                    nbinsx=30
                ),
                row=1, col=1
            )

        # 2. Correlation heatmap
        corr_data = pisa[subjects].corr()
        fig.add_trace(
            go.Heatmap(
                z=corr_data.values,
                x=['Math', 'Read', 'Science'],
                y=['Math', 'Read', 'Science'],
                colorscale='RdYlBu_r',
                text=corr_data.round(3).values,
                texttemplate="%{text}",
                textfont={"size": 12},
                showscale=False
            ),
            row=1, col=2
        )

        # 3. Missing data analysis: the columns with the MOST missing values.
        # (A plain head(10) would just take the first 10 columns in table order.)
        top_missing = (
            pisa.isnull().sum()
            .sort_values(ascending=False, kind='stable')
            .head(10)
            .iloc[::-1]  # horizontal bars are drawn bottom-up; put the largest on top
        )

        # Build the field_id -> description lookup once (first entry wins).
        codebooks = self.data['codebooks']
        field_names = codebooks.drop_duplicates('field_id').set_index('field_id')['field_name']

        y_labels = list(top_missing.index)  # keep field_id as the y-axis label
        hover_texts = [
            f"{missing_count}, {field_id} {field_names.get(field_id, 'No description available')}"
            for field_id, missing_count in top_missing.items()
        ]

        fig.add_trace(
            go.Bar(
                x=top_missing.values,
                y=y_labels,
                orientation='h',
                name='Missing Values',
                width=0.6,
                marker_color=self.MISSING_DATA_COLOR,
                hovertemplate='%{hovertext}<extra></extra>',
                hovertext=hover_texts
            ),
            row=2, col=1
        )

        # Show every field_id label and let Plotly grow the margin to fit them.
        fig.update_yaxes(
            tickmode='linear',
            automargin=True,
            row=2, col=1
        )

        # 4. Sample by country (if CNT column exists)
        if 'CNT' in pisa.columns:
            country_counts = pisa['CNT'].value_counts().head(10)
            fig.add_trace(
                go.Bar(
                    x=country_counts.index,
                    y=country_counts.values,
                    name='Sample Size by Country',
                    marker_color='#00ddff'
                ),
                row=2, col=2
            )

        # Apply template and add custom title. barmode='overlay' makes the three
        # semi-transparent histograms overlap instead of being grouped side by side.
        fig.update_layout(template=self.theme, height=700, showlegend=True, barmode='overlay')
        self.add_custom_title(fig, "PISA Dataset Overview")

        return fig

    def create_field_explorer(self, selected_field=None):
        """Create the field-explorer figure for one column, plus dropdown options.

        Args:
            selected_field: Column of ``self.data['pisa']`` to analyse; defaults
                to the first column.

        Returns:
            Tuple ``(fig, field_options)``. ``field_options`` is a list of
            ``{'label': ..., 'value': field}`` dicts, one per column. Labels
            include the codebook description truncated to 120 characters. When
            the PISA frame has no columns, ``fig`` is an empty figure; the
            return shape is the same in every case.
        """
        pisa = self.data['pisa']
        available_fields = list(pisa.columns)

        # field_id -> description lookup, built once (first codebook entry wins).
        codebooks = self.data['codebooks']
        descriptions = codebooks.drop_duplicates('field_id').set_index('field_id')['field_name']

        # Dropdown options: "FIELD - description" when a description exists.
        field_options = []
        for field in available_fields:
            description = descriptions.get(field)
            if description is None or pd.isna(description):
                display_text = field
            else:
                description = str(description)
                if len(description) > 120:
                    description = description[:120] + "..."
                display_text = f"{field} - {description}"
            field_options.append({'label': display_text, 'value': field})

        # Default to the first field if none was selected.
        if selected_field is None and available_fields:
            selected_field = available_fields[0]

        if selected_field is None:
            # No columns at all: keep the (fig, options) return shape consistent.
            return go.Figure(), field_options

        field_data = pisa[selected_field]
        analysis = self._analyze_field(selected_field, field_data)

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                f'Distribution of {selected_field}',
                'Missing Values Analysis',
                'ML Model Usage',
                'Field Statistics'
            ),
            specs=[[{"secondary_y": False}, {"type": "domain"}],
                   [{"secondary_y": False}, {"secondary_y": False}]],
            vertical_spacing=0.15,
            horizontal_spacing=0.1
        )

        # 1. Distribution: histogram for numeric data, top-20 bar chart otherwise.
        if analysis['is_numeric']:
            fig.add_trace(
                go.Histogram(
                    x=field_data.dropna(),
                    nbinsx=30,
                    name='Distribution',
                    opacity=0.7
                ),
                row=1, col=1
            )
        else:
            value_counts = field_data.value_counts().head(20)
            fig.add_trace(
                go.Bar(
                    x=value_counts.index.astype(str),
                    y=value_counts.values,
                    name='Frequency',
                    text=value_counts.values,
                    textposition='auto'
                ),
                row=1, col=1
            )

        # 2. Valid vs. missing values donut chart.
        total_rows = len(field_data)
        missing_count = field_data.isnull().sum()
        valid_count = total_rows - missing_count

        fig.add_trace(
            go.Pie(
                labels=['Valid Values', 'Missing Values'],
                values=[valid_count, missing_count],
                hole=0.4,
                name='Missing Analysis'
            ),
            row=1, col=2
        )

        # 3. Importance of this field in each ML model (may issue BigQuery queries).
        model_usage = self._get_field_model_usage(selected_field)
        if model_usage:
            models = list(model_usage.keys())
            importance_scores = [model_usage[model].get('importance', 0) for model in models]

            fig.add_trace(
                go.Bar(
                    x=models,
                    y=importance_scores,
                    name='Feature Importance',
                    text=[f"{score:.3f}" for score in importance_scores],
                    textposition='auto'
                ),
                row=2, col=1
            )
        else:
            fig.add_annotation(
                text="Field not used in ML models",
                xref="x domain", yref="y domain",
                x=0.5, y=0.5,
                xanchor="center", yanchor="middle",
                showarrow=False,
                font=dict(size=14),
                row=2, col=1
            )

        # 4. Field statistics rendered as text; the panel's axes are hidden
        # so no grid or ticks are drawn behind the text.
        stats_text = self._format_field_statistics(analysis)
        fig.add_annotation(
            text=stats_text,
            xref="x domain", yref="y domain",
            x=0.05, y=0.95,
            xanchor="left", yanchor="top",
            showarrow=False,
            font=dict(size=11, family="monospace"),
            align="left",
            row=2, col=2
        )
        fig.update_xaxes(visible=False, row=2, col=2)
        fig.update_yaxes(visible=False, row=2, col=2)

        fig.update_layout(
            template=self.theme,
            height=800,
            showlegend=True
        )

        self.add_custom_title(fig, f"Field Explorer: {selected_field}")

        return fig, field_options

    def _analyze_field(self, field_name, field_data):
        """Analyze a single field and return comprehensive statistics."""

        analysis = {
            'field_name': field_name,
            'total_count': len(field_data),
            'missing_count': field_data.isnull().sum(),
            'missing_percentage': (field_data.isnull().sum() / len(field_data)) * 100,
            'non_missing_count': field_data.notna().sum()
        }

        # Check if numeric
        analysis['is_numeric'] = pd.api.types.is_numeric_dtype(field_data)

        if analysis['is_numeric']:
            # Numeric field analysis
            clean_data = field_data.dropna()
            analysis.update({
                'data_type': 'Numeric',
                'min_value': clean_data.min(),
                'max_value': clean_data.max(),
                'mean': clean_data.mean(),
                'median': clean_data.median(),
                'std': clean_data.std(),
                'unique_count': clean_data.nunique(),
                'range': clean_data.max() - clean_data.min()
            })
        else:
            # Categorical field analysis
            clean_data = field_data.dropna()
            value_counts = clean_data.value_counts()
            analysis.update({
                'data_type': 'Categorical',
                'unique_count': clean_data.nunique(),
                'most_common': value_counts.index[0] if len(value_counts) > 0 else None,
                'most_common_count': value_counts.iloc[0] if len(value_counts) > 0 else 0,
                'most_common_percentage': (value_counts.iloc[0] / len(clean_data)) * 100 if len(value_counts) > 0 else 0,
                'unique_values': value_counts.head(10).to_dict()  # Top 10 categories
            })

        return analysis

    def _get_field_model_usage(self, field_name):
        """Check if field is used in ML models and get its importance."""

        model_usage = {}
        subjects = ['math', 'read', 'sci']
        model_types = ['pisa_reg_lasso_model', 'pisa_rand_forest_model']

        for model_type in model_types:
            for subject in subjects:
                # Get feature importance for this model
                feature_imp = self.get_feature_importance(model_type, subject)

                if not feature_imp.empty:
                    # Check if our field is in the features
                    field_importance = feature_imp[feature_imp['feature'] == field_name]

                    if not field_importance.empty:
                        model_key = f"{model_type.replace('pisa_', '').replace('_model', '')}_{subject}"
                        importance_value = field_importance['weight'].iloc[0] if 'weight' in field_importance.columns else field_importance['abs_weight'].iloc[0]

                        model_usage[model_key] = {
                            'importance': abs(float(importance_value)),
                            'rank': len(feature_imp[feature_imp['weight'] > importance_value]) + 1 if 'weight' in feature_imp.columns else len(feature_imp[feature_imp['abs_weight'] > importance_value]) + 1,
                            'model_type': model_type,
                            'subject': subject
                        }

        return model_usage

    def _format_field_statistics(self, field_info):
        """Format field statistics as HTML text."""

        stats_lines = [
            f"<b>Field Statistics</b>",
            f"━━━━━━━━━━━━━━━━━━━━",
            f"<b>Field Name:</b> {field_info['field_name']}",
            f"<b>Data Type:</b> {field_info['data_type']}",
            f"<b>Total Records:</b> {field_info['total_count']:,}",
            f"<b>Valid Values:</b> {field_info['non_missing_count']:,}",
            f"<b>Missing Values:</b> {field_info['missing_count']:,} ({field_info['missing_percentage']:.1f}%)",
            f"<b>Unique Values:</b> {field_info['unique_count']:,}",
            ""
        ]

        if field_info['is_numeric']:
            stats_lines.extend([
                f"<b>Numeric Statistics:</b>",
                f"• Min: {field_info['min_value']:.2f}",
                f"• Max: {field_info['max_value']:.2f}",
                f"• Mean: {field_info['mean']:.2f}",
                f"• Median: {field_info['median']:.2f}",
                f"• Std Dev: {field_info['std']:.2f}",
                f"• Range: {field_info['range']:.2f}"
            ])
        else:
            stats_lines.extend([
                f"<b>Categorical Statistics:</b>",
                f"• Most Common: {field_info['most_common']}",
                f"• Frequency: {field_info['most_common_count']:,} ({field_info['most_common_percentage']:.1f}%)",
                "",
                f"<b>Top Categories:</b>"
            ])

            for category, count in list(field_info['unique_values'].items())[:5]:
                percentage = (count / field_info['non_missing_count']) * 100
                stats_lines.append(f"• {str(category)[:20]}: {count:,} ({percentage:.1f}%)")

        return "<br>".join(stats_lines)

    def create_model_comparison(self):
        """Create a 2x2 model-performance comparison figure.

        Uses the best trial stored in ``self.model_results[model_type][subject]``
        for each combination of the two model types and three subjects. Missing
        model types or subjects, and empty results, are skipped. The figure
        therefore also renders when the dashboard was built without model
        results (``self.model_results == {}``).

        Panels: R² per subject and model, the best model per subject, training
        vs. evaluation loss (coloured by R²), and a summary table.

        Returns:
            plotly.graph_objects.Figure
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('R² Comparison', 'Best Trial Performance',
                            'Training vs Validation Loss', 'Model Summary'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"type": "table"}]],
            vertical_spacing=0.12,
            horizontal_spacing=0.1
        )

        # Collect one record per available (model_type, subject) best trial.
        model_comparison_data = []
        subjects = ['math', 'read', 'sci']
        model_types = ['pisa_reg_lasso_model', 'pisa_rand_forest_model']

        for model_type in model_types:
            # .get(): model_results may be {} (the constructor default) or partial.
            results_by_subject = self.model_results.get(model_type, {})
            for subject in subjects:
                trials = results_by_subject.get(subject)
                if trials is None or trials.empty:
                    continue
                best_trial = trials.iloc[0]
                model_comparison_data.append({
                    'model_type': model_type.replace('pisa_', '').replace('_model', ''),
                    'subject': subject.upper(),
                    'r2_score': best_trial['r2_score'],
                    'eval_loss': best_trial['eval_loss'],
                    'training_loss': best_trial['training_loss']
                })

        comparison_df = pd.DataFrame(model_comparison_data)

        if not comparison_df.empty:
            # 1. R² Comparison: one bar series per model family.
            for model_type in comparison_df['model_type'].unique():
                model_data = comparison_df[comparison_df['model_type'] == model_type]

                if 'reg' in model_type.lower() or 'lasso' in model_type.lower():
                    bar_color = '#60a5fa'  # Blue for linear regression
                else:
                    bar_color = '#34d399'  # Green for random forest

                fig.add_trace(
                    go.Bar(
                        x=model_data['subject'],
                        y=model_data['r2_score'],
                        name=f'{model_type.replace("_", " ").title()}',
                        text=model_data['r2_score'].round(3),
                        textposition='auto',
                        marker_color=bar_color
                    ),
                    row=1, col=1
                )

            # 2. Best model per subject. idxmax cannot handle an all-NaN group,
            # so only models that actually report an R² are ranked.
            scored = comparison_df.dropna(subset=['r2_score'])
            if not scored.empty:
                best_performance = scored.loc[scored.groupby('subject')['r2_score'].idxmax()]
                fig.add_trace(
                    go.Scatter(
                        x=best_performance['subject'],
                        y=best_performance['r2_score'],
                        mode='markers+text',
                        marker=dict(size=20, line=dict(width=2)),
                        text=best_performance['model_type'],
                        textposition='top center',
                        name='Best Model',
                        textfont=dict(size=12),
                        marker_color='#00f737'
                    ),
                    row=1, col=2
                )

            # 3. Training vs Validation Loss, coloured by R².
            fig.add_trace(
                go.Scatter(
                    x=comparison_df['training_loss'],
                    y=comparison_df['eval_loss'],
                    mode='markers+text',
                    text=comparison_df['subject'],
                    marker=dict(
                        size=15,
                        color=comparison_df['r2_score'],
                        colorscale='Viridis',
                        showscale=True,
                        colorbar=dict(
                            title="R² Score",
                            x=0.45,
                            len=0.5,
                            y=0.25,
                            thickness=15
                        )
                    ),
                    name='Models',
                    textposition='top center',
                    textfont=dict(size=10)
                ),
                row=2, col=1
            )

            # Diagonal reference line where training loss equals evaluation loss.
            max_loss = max(comparison_df['training_loss'].max(), comparison_df['eval_loss'].max())
            fig.add_trace(
                go.Scatter(
                    x=[0, max_loss],
                    y=[0, max_loss],
                    mode='lines',
                    line=dict(dash='dash', width=2),
                    name='Perfect Fit',
                    showlegend=False
                ),
                row=2, col=1
            )

            # 4. Model summary table (rows built first, then transposed to columns).
            table_headers = ['Model', 'Subject', 'R² (Explained Variance)', 'Val Loss']
            table_cells = []

            for _, record in comparison_df.iterrows():
                table_cells.append([
                    record['model_type'].replace('_', ' ').title(),
                    record['subject'].upper(),
                    f"{record['r2_score']:.3f}",
                    f"{record['eval_loss']:.1f}"
                ])

            # Plotly tables take a list of columns.
            table_data = list(map(list, zip(*table_cells))) if table_cells else [[], [], [], []]

            fig.add_trace(
                go.Table(
                    header=dict(
                        values=table_headers,
                        fill_color='#40466e',
                        font=dict(color='white', size=12),
                        align='center',
                        height=30
                    ),
                    cells=dict(
                        values=table_data,
                        fill_color='#f1f1f2',
                        font=dict(color='#506784', size=11),
                        align='center',
                        height=25
                    ),
                    domain=dict(x=[0, 1], y=[0, 0.85])
                ),
                row=2, col=2
            )

        # Apply template and add custom title
        fig.update_layout(template=self.theme, height=700, showlegend=True)
        self.add_custom_title(fig, "Model Performance Comparison")

        return fig

    def create_feature_importance_dashboard(self):
        """Create a 3x2 grid showing the top-10 features of every subject/model pair.

        Rows are subjects (math, read, science). Columns are the Lasso linear
        regression and the random forest model. Bars show the ``weight`` column
        returned by ``get_feature_importance``: the signed coefficient for
        Lasso, the importance weight for the random forest. A panel whose
        feature-importance lookup returned no rows gets a short note instead of
        staying blank.

        Returns:
            plotly.graph_objects.Figure
        """
        fig = make_subplots(
            rows=3, cols=2,
            subplot_titles=('Linear Regression - Math', 'Random Forest - Math',
                            'Linear Regression - Read', 'Random Forest - Read',
                            'Linear Regression - Science', 'Random Forest - Science'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": False}, {"secondary_y": False}]],
            vertical_spacing=0.08,
            horizontal_spacing=0.1
        )

        subjects = ['math', 'read', 'sci']
        model_types = ['pisa_reg_lasso_model', 'pisa_rand_forest_model']

        for row, subject in enumerate(subjects, start=1):
            for col, model_type in enumerate(model_types, start=1):
                feature_imp = self.get_feature_importance(model_type, subject)

                if feature_imp.empty:
                    fig.add_annotation(
                        text="No feature importance available",
                        xref="x domain", yref="y domain",
                        x=0.5, y=0.5,
                        xanchor="center", yanchor="middle",
                        showarrow=False,
                        row=row, col=col
                    )
                    continue

                # Top 10 features, reversed so the most important bar is drawn on top.
                top_features = feature_imp.head(10).iloc[::-1]
                value_col = 'weight' if 'weight' in top_features.columns else 'abs_weight'
                values = top_features[value_col]

                fig.add_trace(
                    go.Bar(
                        y=top_features['feature'],
                        x=values,
                        orientation='h',
                        name=f'{model_type.split("_")[1]} - {subject}',
                        text=values.round(3),
                        textposition='auto',
                        cliponaxis=False
                    ),
                    row=row, col=col
                )

        # Apply template and add custom title
        fig.update_layout(template=self.theme, height=1400, showlegend=False)
        self.add_custom_title(fig, "Feature Importance Comparison Across Models")
        # add_custom_title appends its annotation last at y=1.1; for this tall
        # figure, move it down to y=1.05.
        fig.layout.annotations[-1].y = 1.05
        # Wider left margin for long feature names (overrides add_custom_title's margins).
        fig.update_layout(margin=dict(t=120, b=60, l=200, r=60))

        return fig

    def create_country_explorer(self):
        """Create an interactive world choropleth of average PISA scores per country.

        Rows of ``self.data['pisa']`` are grouped by the country name their
        ``CNT`` code maps to. National and sub-national samples of the same
        country (e.g. ``ESP`` and ``QES``) therefore become one map entry.
        Without this grouping the choropleth would receive duplicate locations,
        which are drawn over each other so only one sample's data is visible.
        Codes without a mapping are kept and labelled by their raw code.

        Returns:
            plotly Figure. If the ``CNT`` column is missing, or no row has a
            country code, the figure contains only an explanatory annotation.
        """
        pisa = self.data['pisa']

        # Check if CNT column exists
        if 'CNT' not in pisa.columns:
            print("⚠️ Country (CNT) column not found in data")
            return go.Figure().add_annotation(
                text="Country data not available",
                xref="paper", yref="paper",
                x=0.5, y=0.5, showarrow=False,
                font=dict(size=16)
            )

        # PISA CNT code -> country name understood by plotly's 'country names' mode.
        # 'Q..' codes are sub-national/regional samples and map to their country.
        country_mapping = {
            'ALB': 'Albania',
            'QAZ': 'Azerbaijan',
            'AUS': 'Australia',
            'BRA': 'Brazil',
            'BGR': 'Bulgaria',
            'KHM': 'Cambodia',
            'CAN': 'Canada',
            'DOM': 'Dominican Republic',
            'SLV': 'El Salvador',
            'FRA': 'France',
            'GEO': 'Georgia',
            'DEU': 'Germany',
            'GTM': 'Guatemala',
            'HUN': 'Hungary',
            'ISR': 'Israel',
            'ITA': 'Italy',
            'KSV': 'Kosovo',
            'JAM': 'Jamaica',
            'LVA': 'Latvia',
            'MAR': 'Morocco',
            'NLD': 'Netherlands',
            'PAN': 'Panama',
            'PRY': 'Paraguay',
            'PRT': 'Portugal',
            'SAU': 'Saudi Arabia',
            'SVK': 'Slovakia',
            'VNM': 'Vietnam',
            'ESP': 'Spain',
            'MKD': 'North Macedonia',
            'GBR': 'United Kingdom',
            'AUT': 'Austria',
            'BEL': 'Belgium',
            'CHL': 'Chile',
            'COL': 'Colombia',
            'CZE': 'Czech Republic',
            'DNK': 'Denmark',
            'GRC': 'Greece',
            'IDN': 'Indonesia',
            'IRL': 'Ireland',
            'JOR': 'Jordan',
            'LBN': 'Lebanon',
            'LUX': 'Luxembourg',
            'MDA': 'Moldova',
            'NOR': 'Norway',
            'SWE': 'Sweden',
            'ARE': 'United Arab Emirates',
            'UKR': 'Ukraine',
            'USA': 'United States',
            'QCH': 'China',
            'CRI': 'Costa Rica',
            'DZA': 'Algeria',
            'EST': 'Estonia',
            'HRV': 'Croatia',
            'MEX': 'Mexico',
            'PER': 'Peru',
            'POL': 'Poland',
            'QES': 'Spain',
            'RUS': 'Russia',
            'SVN': 'Slovenia',
            'THA': 'Thailand',
            'TUN': 'Tunisia',
            'TUR': 'Turkey',
            'URY': 'Uruguay',
            'ARG': 'Argentina',
            'TAP': 'Chinese Taipei',
            'FIN': 'Finland',
            'HKG': 'Hong Kong',
            'ISL': 'Iceland',
            'JPN': 'Japan',
            'KOR': 'South Korea',
            'LTU': 'Lithuania',
            'MLT': 'Malta',
            'NZL': 'New Zealand',
            'ROU': 'Romania',
            'CHE': 'Switzerland',
            'SRB': 'Serbia',
            # CABA (Buenos Aires) regional sample; Qatar is QAT.
            # NOTE(review): confirm against the PISA codebook.
            'QAR': 'Argentina',
            'QAT': 'Qatar',
            'SGP': 'Singapore',
            'TTO': 'Trinidad and Tobago',
            'KAZ': 'Kazakhstan',
            'MNE': 'Montenegro',
            'PSE': 'Palestine',
            'MYS': 'Malaysia',
            'MNG': 'Mongolia',
            'PHL': 'Philippines',
            'UZB': 'Uzbekistan',
            # "Ukrainian regions" sample; Uruguay is URY.
            # NOTE(review): confirm against the PISA codebook.
            'QUR': 'Ukraine',
            'BIH': 'Bosnia and Herzegovina',
            'BRN': 'Brunei',
            'BLR': 'Belarus',
            'QCI': 'China',
            'QMR': 'Russia',  # Moscow Region sample (Morocco is MAR)
            'QRT': 'Russia',
            'MAC': 'Macao'
        }

        subjects = ('math', 'read', 'sci')
        subject_labels = {'math': 'Mathematics', 'read': 'Reading', 'sci': 'Science'}
        coded = pisa[pisa['CNT'].notna()]

        # Warn about unmapped codes; their rows are kept and labelled by the raw code.
        for code in coded['CNT'].unique():
            if code not in country_mapping:
                print(f"⚠️ No mapping found for country code: {code}")

        has_school_ids = 'CNTSCHID' in coded.columns
        if not has_school_ids:
            print("⚠️ No school ID column (CNTSCHID) found; school counts unavailable")

        # Positional array (not an index-aligned Series) so grouping ignores the frame's index.
        country_names = coded['CNT'].map(lambda code: country_mapping.get(code, code)).to_numpy()

        country_stats = []
        for country_name, country_data in coded.groupby(country_names, sort=True):
            stats = {
                'country_code': ', '.join(sorted(map(str, country_data['CNT'].unique()))),
                'country_name': country_name,
                'total_students': len(country_data),
                'total_schools': country_data['CNTSCHID'].nunique() if has_school_ids else 'Unknown',
            }
            for subject in subjects:
                # Empty series reduce to NaN (count 0), so a country without scores
                # is left uncoloured instead of being plotted as 0.
                scores = country_data[subject].dropna()
                stats.update({
                    f'{subject}_mean': scores.mean(),
                    f'{subject}_median': scores.median(),
                    f'{subject}_min': scores.min(),
                    f'{subject}_max': scores.max(),
                    f'{subject}_std': scores.std(),
                    f'{subject}_count': len(scores)
                })
            country_stats.append(stats)

        country_df = pd.DataFrame(country_stats)

        if country_df.empty:
            return go.Figure().add_annotation(
                text="No country data available",
                xref="paper", yref="paper",
                x=0.5, y=0.5, showarrow=False
            )

        print(f"Processed {len(country_df)} countries")

        def build_hover(row):
            """Assemble the HTML hover label for one aggregated country row."""
            lines = [
                f"<b>{row['country_name']} ({row['country_code']})</b>",
                "<b>Sample Info:</b>",
                f"• Students Tested: {row['total_students']:,}",
                f"• Schools: {row['total_schools']}",
            ]
            for subject in subjects:
                lines += [
                    "",
                    f"<b>{subject_labels[subject]}:</b>",
                    f"• Average: {row[f'{subject}_mean']:.1f}",
                    f"• Median: {row[f'{subject}_median']:.1f}",
                    f"• Range: {row[f'{subject}_min']:.0f} - {row[f'{subject}_max']:.0f}",
                    f"• Std Dev: {row[f'{subject}_std']:.1f}",
                ]
            return "<br>".join(lines)

        hover_texts = [build_hover(row) for row in country_df.to_dict('records')]

        # Overall average across subjects (NaN subjects are skipped) drives the colour.
        country_df['overall_avg'] = country_df[[f'{s}_mean' for s in subjects]].mean(axis=1)

        fig = go.Figure()
        fig.add_trace(
            go.Choropleth(
                locations=country_df['country_name'],
                z=country_df['overall_avg'],
                locationmode='country names',
                colorscale='RdYlGn',  # Red-Yellow-Green: Red for low, Green for high
                reversescale=False,   # Don't reverse: Red=low, Green=high
                colorbar=dict(
                    # title.font replaces the 'titlefont' property removed in plotly.py 6
                    title=dict(text="Overall Average Score", font=dict(size=14)),
                    tickfont=dict(size=12)
                ),
                hovertemplate='%{hovertext}<extra></extra>',
                hovertext=hover_texts,
                marker=dict(
                    line=dict(
                        color='rgb(180,180,180)',
                        width=0.5
                    )
                )
            )
        )

        # Update layout for better appearance
        fig.update_layout(
            template=self.theme,
            height=700,
            plot_bgcolor='rgb(0, 0, 0)',
            geo=dict(
                showframe=False,
                showcoastlines=True,
                showland=True,
                landcolor='rgb(243, 243, 243)',
                coastlinecolor='rgb(204, 204, 204)',
                projection_type='natural earth',
                bgcolor='rgb(173, 216, 230)'
            )
        )

        # Add custom title
        self.add_custom_title(fig, "PISA Global Performance Explorer")

        return fig


    def create_interactive_explorer(self, max_features=20):
        """Create a scatter of the three subject scores against one numeric feature.

        Candidate features are the numeric columns of ``self.data['pisa']``
        other than ``math``/``read``/``sci``, capped at ``max_features``
        (default 20) to keep the option list short. The figure plots the first
        candidate. The full list is returned so callers can offer the others.

        Returns:
            ``(fig, feature_options)``. When there is no numeric feature column,
            the figure carries only an explanatory annotation and the list is empty.
        """
        targets = ['math', 'read', 'sci']
        pisa = self.data['pisa']
        numeric_columns = pisa.select_dtypes(include=[np.number]).columns.tolist()
        feature_options = [col for col in numeric_columns if col not in targets][:max_features]

        fig = go.Figure()
        if feature_options:
            x_feature = feature_options[0]
            for subject in targets:
                # WebGL scatter: the PISA frame can hold tens of thousands of rows
                # per trace, which SVG markers render slowly.
                fig.add_trace(
                    go.Scattergl(
                        x=pisa[x_feature],
                        y=pisa[subject],
                        mode='markers',
                        name=f'{subject.upper()} Scores',
                        opacity=0.6,
                        marker=dict(size=5)
                    )
                )
        else:
            # Without a feature there is nothing to put on the x axis.
            fig.add_annotation(
                text="No numeric feature columns available",
                xref="paper", yref="paper",
                x=0.5, y=0.5, showarrow=False,
                font=dict(size=16)
            )

        # Apply template and add custom title
        fig.update_layout(
            template=self.theme,
            xaxis_title=feature_options[0] if feature_options else "Feature",
            yaxis_title="Scores",
            height=600
        )
        self.add_custom_title(fig, "Interactive Data Explorer")

        return fig, feature_options

    def create_prediction_analysis(self, model_prefix='pisa_rand_forest_model', sample_size=1000):
        """Create actual-vs-predicted scatter plots, one subplot per subject.

        For each subject, this runs BigQuery ML ``ML.PREDICT`` with model
        ``{model_prefix}_{subject}`` over up to ``sample_size`` rows of
        ``pisa_data`` that have a non-null score. It then plots actual (x)
        against predicted (y), with a dashed y = x reference line.

        NOTE(review): the rows come from the same ``pisa_data`` table the models
        were presumably trained on, so this is likely an in-sample view. Also,
        ``LIMIT`` without ``ORDER BY`` returns an arbitrary subset.

        Args:
            model_prefix: BigQuery model name prefix (default: the random forest models).
            sample_size: maximum number of rows scored per subject (default 1000).

        Returns:
            plotly Figure. A subject whose query or plotting fails is left empty
            and the error is printed.
        """
        subjects = ['math', 'read', 'sci']
        fig = make_subplots(
            rows=1, cols=len(subjects),
            subplot_titles=('Math Predictions', 'Reading Predictions', 'Science Predictions')
        )

        # int() keeps the interpolated LIMIT a plain number.
        row_limit = int(sample_size)
        # Show the reference line in the legend once, on the first subject actually plotted.
        reference_in_legend = False

        for col, subject in enumerate(subjects, start=1):
            pred_query = f"""
            SELECT
                {subject},
                predicted_{subject} as predicted
            FROM ML.PREDICT(
                MODEL `{self.project_id}.{self.dataset}.{model_prefix}_{subject}`,
                (SELECT * FROM `{self.project_id}.{self.dataset}.pisa_data`
                 WHERE {subject} IS NOT NULL LIMIT {row_limit})
            )
            """

            # Best effort: one failing model should not prevent the other panels.
            try:
                pred_results = self.client.query(pred_query).to_dataframe()
                if pred_results.empty:
                    continue

                actual = pred_results[subject]
                predicted = pred_results['predicted']

                fig.add_trace(
                    go.Scatter(
                        x=actual,
                        y=predicted,
                        mode='markers',
                        name=subject.upper(),
                        opacity=0.6,
                        marker=dict(size=5)
                    ),
                    row=1, col=col
                )

                # Add perfect prediction line spanning both value ranges
                low = min(actual.min(), predicted.min())
                high = max(actual.max(), predicted.max())
                fig.add_trace(
                    go.Scatter(
                        x=[low, high],
                        y=[low, high],
                        mode='lines',
                        line=dict(dash='dash', width=2),
                        name='Perfect Prediction',
                        showlegend=not reference_in_legend
                    ),
                    row=1, col=col
                )
                reference_in_legend = True

                fig.update_xaxes(title_text='Actual', row=1, col=col)
                fig.update_yaxes(title_text='Predicted', row=1, col=col)

            except Exception as e:
                print(f"Could not load predictions for {subject}: {e}")

        # Apply template and add custom title
        fig.update_layout(template=self.theme, height=500, showlegend=True)
        self.add_custom_title(fig, "Actual vs Predicted Scores")

        return fig


def create_field_explorer_with_dropdown(dashboard_instance, max_fields=30):
    """Display a field dropdown that re-renders the field explorer on each selection.

    The options are the first ``max_fields`` columns of
    ``dashboard_instance.data['pisa']``. Each is labelled with its codebook
    ``field_name``, truncated to 60 characters, when one exists. Every
    selection, and the initial one, calls
    ``dashboard_instance.create_field_explorer(field)`` and draws the figure
    into an ``ipywidgets.Output`` area. Errors are printed there rather than
    raised.

    Args:
        dashboard_instance: object exposing ``data`` and ``create_field_explorer``.
        max_fields: number of leading columns offered (default 30, for responsiveness).

    Returns:
        ``(dropdown, output)`` widgets, or ``(None, None)`` when the data has no columns.
    """
    import ipywidgets as widgets
    from IPython.display import display, clear_output
    import plotly.offline as pyo

    # Enable offline mode for better Colab compatibility
    pyo.init_notebook_mode(connected=True)

    codebooks = dashboard_instance.data['codebooks']
    fields = list(dashboard_instance.data['pisa'].columns)[:max_fields]
    if not fields:
        print("⚠️ No fields available to explore")
        return None, None

    # (label, value) pairs; the label adds the codebook description when one exists.
    dropdown_options = []
    for field in fields:
        matches = codebooks.loc[codebooks['field_id'] == field, 'field_name']
        if matches.empty:
            display_text = field
        else:
            # str() guards against non-string (e.g. NaN) codebook entries.
            field_name = str(matches.iloc[0])
            truncated_name = field_name[:60] + "..." if len(field_name) > 60 else field_name
            display_text = f"{field} - {truncated_name}"
        dropdown_options.append((display_text, field))

    dropdown = widgets.Dropdown(
        options=dropdown_options,
        value=fields[0],
        description='Field:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='600px')
    )
    output = widgets.Output()

    def render(field, heading="🔍 Exploring"):
        """Redraw `output` with the explorer figure for `field`, printing any error."""
        with output:
            clear_output(wait=True)
            print(f"{heading}: {field}")
            try:
                fig, _ = dashboard_instance.create_field_explorer(field)
                # Force display using iplot instead of show()
                pyo.iplot(fig)
            except Exception as e:
                print(f"Error: {e}")

    # Re-render whenever the selected value changes.
    dropdown.observe(lambda change: render(change['new']), names='value')

    # Display widgets first
    display(widgets.VBox([dropdown, output]))

    # The initial plot goes through the same error handling as later selections.
    render(dropdown.value, heading="🔍 Initial field")

    return dropdown, output


# Alternative: Simple function to explore specific fields
def explore_field(self, field_name, n_listed=10):
    """Show the explorer figure for one field, then list leading fields for reference.

    This is a free function whose first parameter is ``self``, so that
    ``add_interactive_methods`` can bind it onto a dashboard instance.

    Args:
        self: object exposing ``create_field_explorer`` and ``data``
            (``data['pisa']`` and ``data['codebooks']`` DataFrames).
        field_name: column to explore.
        n_listed: how many leading columns to list afterwards (default 10).
    """
    fig, _ = self.create_field_explorer(field_name)
    fig.show()

    # Also show available fields for reference
    codebooks = self.data['codebooks']
    print(f"\n📋 Available fields (showing first {n_listed}):")
    for i, field in enumerate(list(self.data['pisa'].columns)[:n_listed], start=1):
        matches = codebooks.loc[codebooks['field_id'] == field, 'field_name']
        # str() guards against non-string (e.g. NaN) codebook entries.
        description = str(matches.iloc[0]) if not matches.empty else "No description"
        # Only mark the description as truncated when it actually was.
        if len(description) > 80:
            description = description[:80] + "..."
        print(f"{i:2d}. {field} - {description}")

    print("\n💡 Usage: dashboard.explore_field('FIELD_NAME')")

def add_interactive_methods(dashboard_instance):
    """Attach the module-level interactive helpers to one dashboard object.

    After this call, ``dashboard_instance.create_field_explorer_with_dropdown()``
    and ``dashboard_instance.explore_field(name)`` can be called like methods,
    with the instance passed as the helper's first argument.
    """
    bindings = {
        'create_field_explorer_with_dropdown': create_field_explorer_with_dropdown,
        'explore_field': explore_field,
    }
    for attr_name, func in bindings.items():
        # A function's __get__ yields a method bound to the given object.
        setattr(dashboard_instance, attr_name, func.__get__(dashboard_instance))

### Load data

In [None]:
# Optional row cap for quick iteration (e.g. 5000); None loads every row that
# has math, read and sci scores.
PISA_ROW_LIMIT = None
pisa_data, codebooks_data = load_pisa_data(PROJECT_ID, BQ_DATASET, client, limit=PISA_ROW_LIMIT)
model_results = load_model_results(PROJECT_ID, BQ_DATASET, client)

📊 Loading PISA data...
✅ Loaded 54080 PISA records
✅ Loaded 75 codebook entries
🤖 Loading model results...
✅ Loaded best trial for pisa_reg_lasso_model_math (R² = 0.319)
✅ Loaded best trial for pisa_reg_lasso_model_read (R² = 0.251)
✅ Loaded best trial for pisa_reg_lasso_model_sci (R² = 0.386)
✅ Loaded best trial for pisa_rand_forest_model_math (R² = 0.440)
✅ Loaded best trial for pisa_rand_forest_model_read (R² = 0.384)
✅ Loaded best trial for pisa_rand_forest_model_sci (R² = 0.396)


### Instantiate dashboard

In [None]:
# Build the dashboard from the data and model results loaded above.
dashboard = PISADashboard(
    PROJECT_ID,
    BQ_DATASET,
    client,
    pisa_data,
    codebooks_data,
    model_results,
    theme='dark_modern',
)

🎯 Dashboard initialized with 54080 PISA records
🤖 Dashboard has model results for 2 model types


### 1 - Data Overview

In [None]:
print("📊 Creating data overview...")
overview_fig = dashboard.create_data_overview()
overview_fig.show()

📊 Creating data overview...


### 2 - Country Explorer

In [None]:
print("🔍 Creating country explorer...")
explorer_fig = dashboard.create_country_explorer()
explorer_fig.show()

🔍 Creating country explorer...
Processed 95 countries


### 3 - Model Comparison

In [None]:
print("🤖 Creating model comparison...")
model_fig = dashboard.create_model_comparison()
model_fig.show()

🤖 Creating model comparison...


### 4 - Feature Importance

In [None]:
print("📈 Creating feature importance analysis...")
feature_fig = dashboard.create_feature_importance_dashboard()
feature_fig.show()

📈 Creating feature importance analysis...


### 5 - Field Explorer

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Number of leading columns offered in the dropdown (kept small for responsiveness).
MAX_FIELDS = 30

# Get available fields
available_fields = list(dashboard.data['pisa'].columns)
codebooks = dashboard.data['codebooks']

# Create dropdown options with descriptions: (label, field id) pairs
dropdown_options = []
for field in available_fields[:MAX_FIELDS]:
    field_info = codebooks[codebooks['field_id'] == field]
    if not field_info.empty:
        # str() guards against non-string (e.g. NaN) codebook entries
        field_name = str(field_info['field_name'].iloc[0])
        truncated_name = field_name[:60] + "..." if len(field_name) > 60 else field_name
        display_text = f"{field} - {truncated_name}"
    else:
        display_text = field
    dropdown_options.append((display_text, field))

offered_fields = [value for _, value in dropdown_options]

# Selected field id, read by the next cell. Despite the name, it can be any field,
# not only a country code. It defaults to 'CNT' but falls back to the first offered
# field, because a Dropdown rejects a value that is not among its options.
CNTRY_CODE = 'CNT' if 'CNT' in offered_fields else (offered_fields[0] if offered_fields else None)

# Create dropdown
field_dropdown = widgets.Dropdown(
    options=dropdown_options,
    value=CNTRY_CODE,
    description='Select Field:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px')
)

def on_dropdown_change(change):
    """Store the newly selected field id in the module-level CNTRY_CODE."""
    global CNTRY_CODE
    CNTRY_CODE = change['new']
    print(f"✅ Selected field: {CNTRY_CODE}")

# Connect the dropdown to update the variable
field_dropdown.observe(on_dropdown_change, names='value')

# Display the dropdown
display(field_dropdown)
print(f"Current selection: {CNTRY_CODE}")
print("👆 Select a field from the dropdown above, then run the next cell to explore it")

Dropdown(description='Select Field:', layout=Layout(width='600px'), options=(('CNT - Country code 3-character'…

Current selection: CNT
👆 Select a field from the dropdown above, then run the next cell to explore it
✅ Selected field: SC018Q01TA01


In [None]:
pio.renderers.default = "colab"
fig, _ = dashboard.create_field_explorer(CNTRY_CODE)
fig.show()

In [None]:
# Display summary statistics
print("\n" + "="*60)
print("DASHBOARD SUMMARY")
print("="*60)
print(f"📊 Dataset: {len(dashboard.data['pisa'])} PISA records")
print(f"📚 Codebook: {len(dashboard.data['codebooks'])} field definitions")
print(f"🤖 Models analyzed: Linear Regression + Random Forest")
print(f"🎯 Subjects: Math, Reading, Science")
print(f"🎨 Theme: {dashboard.theme}")
print("="*60)


DASHBOARD SUMMARY
📊 Dataset: 54080 PISA records
📚 Codebook: 75 field definitions
🤖 Models analyzed: Linear Regression + Random Forest
🎯 Subjects: Math, Reading, Science
🎨 Theme: dark_modern
