# In [None]:
# =============================================================================
# PSYCHOLOGICAL MECHANISMS MODEL
# Uninformative priors, no control baseline bias
# =============================================================================

import pymc as pm
import arviz as az
import numpy as np
import pandas as pd

def create_mechanisms_model(df, concept_names, mechanisms=None):
    """
    Build a PyMC model with psychological mechanisms as intermediaries.

    Structure:
    concept → mechanisms → {cheating, performance, experience}
    concept → {cheating, performance, experience} (direct)

    Every concept (including control) gets its own effect parameters; no
    concept is pinned to zero.

    Args:
        df: DataFrame with one row per observation. Columns read here:
            'concept_idx' (integer position used to index the per-concept
            parameters), 'cheating_behavior' (integer category 0, 1 or 2,
            used both as the categorical outcome and as an index into the
            size-3 parameter axes), 'performance', 'experience', and one
            column per entry of `mechanisms` that is present.
        concept_names: Sequence of concept labels; only its length is used.
        mechanisms: Names of mechanism columns. Defaults to the built-in list
            below. Names without a matching column in `df` are skipped in
            every likelihood term (a warning is emitted), but still occupy a
            slot in the mechanism parameter arrays.

    Returns:
        The constructed `pm.Model` (not yet sampled).
    """
    import warnings

    if mechanisms is None:
        mechanisms = [
            'PME_on_honest_task_completion', 'PME_on_task_performance', 'PME_on_task_experience',
            'autonomy_need_satisfaction', 'autonomy_need_frustration',
            'competence_need_satisfaction', 'competence_need_frustration',
            'relatedness_need_satisfaction', 'relatedness_need_frustration',
            'performance_accomplishments', 'vicarious_experience', 'verbal_persuasion', 'emotional_arousal',
            'injunctive_norms', 'descriptive_norms', 'reference_group_identification', 'social_sanctions',
            'cognitive_discomfort', 'moral_disengagement', 'perceived_ability'
        ]

    n_concepts = len(concept_names)
    n_mechanisms = len(mechanisms)

    # These index arrays are used by every likelihood term; extract them once.
    concept_idx = df['concept_idx'].values
    cheat_idx = df['cheating_behavior'].values

    # (slot in the parameter arrays, name, observed values) for each mechanism
    # that actually has a data column.
    observed_mechs = [(i, mech, df[mech].values)
                      for i, mech in enumerate(mechanisms)
                      if mech in df.columns]

    missing = [mech for mech in mechanisms if mech not in df.columns]
    if missing:
        # Their coefficients stay in the model informed only by the priors,
        # so surface the mismatch instead of skipping silently.
        warnings.warn(
            f"Mechanism columns not found in df and skipped in all likelihoods: {missing}",
            stacklevel=2,
        )

    with pm.Model() as mechanisms_model:

        # =============================================================================
        # CONCEPT → MECHANISMS PATHWAY
        # =============================================================================

        # Hierarchical priors (one hyper-mean / hyper-scale per mechanism)
        mu_mech_baseline = pm.Normal('mu_mech_baseline', mu=4, sigma=2, shape=n_mechanisms)
        sigma_mech_baseline = pm.HalfNormal('sigma_mech_baseline', sigma=1, shape=n_mechanisms)

        mu_concept_to_mech = pm.Normal('mu_concept_to_mech', mu=0, sigma=0.5, shape=n_mechanisms)
        sigma_concept_to_mech = pm.HalfNormal('sigma_concept_to_mech', sigma=0.5, shape=n_mechanisms)

        # Mechanism baselines
        mechanism_baselines = pm.Normal('mechanism_baselines',
                                        mu=mu_mech_baseline,
                                        sigma=sigma_mech_baseline,
                                        shape=n_mechanisms)

        # Concept effects on mechanisms; the per-mechanism hyper-parameters
        # broadcast across the concept axis.
        concept_to_mechanisms = pm.Normal('concept_to_mechanisms',
                                          mu=mu_concept_to_mech,
                                          sigma=sigma_concept_to_mech,
                                          shape=(n_concepts, n_mechanisms))

        # Likelihood for each observed mechanism, truncated to [1, 7]
        for i, mech, values in observed_mechs:
            mu_mech = mechanism_baselines[i] + concept_to_mechanisms[concept_idx, i]

            sigma_mech = pm.HalfNormal(f'sigma_{mech}', sigma=1)
            pm.TruncatedNormal(f'{mech}_observed',
                               mu=mu_mech,
                               sigma=sigma_mech,
                               lower=1, upper=7,
                               observed=values)

        # =============================================================================
        # MECHANISMS → OUTCOMES PATHWAY
        # =============================================================================

        # Mechanism coefficients; the size-3 axis is indexed by cheating_behavior
        mechanism_to_partial = pm.Normal('mechanism_to_partial', mu=0, sigma=1,
                                         shape=n_mechanisms)
        mechanism_to_full = pm.Normal('mechanism_to_full', mu=0, sigma=1,
                                      shape=n_mechanisms)
        mechanism_to_perf = pm.Normal('mechanism_to_perf', mu=0, sigma=2,
                                      shape=(n_mechanisms, 3))
        mechanism_to_exp = pm.Normal('mechanism_to_exp', mu=0, sigma=0.5,
                                     shape=(n_mechanisms, 3))

        # =============================================================================
        # DIRECT CONCEPT → OUTCOMES (RESIDUAL EFFECTS)
        # =============================================================================

        # Hierarchical priors for direct effects
        mu_direct_partial = pm.Normal('mu_direct_partial', mu=0, sigma=1)
        sigma_direct_partial = pm.HalfNormal('sigma_direct_partial', sigma=1)

        mu_direct_full = pm.Normal('mu_direct_full', mu=0, sigma=1)
        sigma_direct_full = pm.HalfNormal('sigma_direct_full', sigma=1)

        mu_direct_perf = pm.Normal('mu_direct_perf', mu=0, sigma=2, shape=3)
        sigma_direct_perf = pm.HalfNormal('sigma_direct_perf', sigma=2, shape=3)

        mu_direct_exp = pm.Normal('mu_direct_exp', mu=0, sigma=0.5, shape=3)
        sigma_direct_exp = pm.HalfNormal('sigma_direct_exp', sigma=0.5, shape=3)

        # Direct concept effects
        direct_concept_partial = pm.Normal('direct_concept_partial',
                                           mu=mu_direct_partial,
                                           sigma=sigma_direct_partial,
                                           shape=n_concepts)
        direct_concept_full = pm.Normal('direct_concept_full',
                                        mu=mu_direct_full,
                                        sigma=sigma_direct_full,
                                        shape=n_concepts)
        direct_concept_perf = pm.Normal('direct_concept_perf',
                                        mu=mu_direct_perf,
                                        sigma=sigma_direct_perf,
                                        shape=(n_concepts, 3))
        direct_concept_exp = pm.Normal('direct_concept_exp',
                                       mu=mu_direct_exp,
                                       sigma=sigma_direct_exp,
                                       shape=(n_concepts, 3))

        # =============================================================================
        # OUTCOME MODELS
        # =============================================================================

        # Intercepts of the two non-reference cheating categories
        beta_partial_intercept = pm.Normal('beta_partial_intercept', mu=0, sigma=2)
        beta_full_intercept = pm.Normal('beta_full_intercept', mu=0, sigma=2)

        # Mechanism contributions to all four linear predictors. The observed
        # mechanism values (not the latent predictions) act as the regressors.
        mechanism_contrib_partial = 0
        mechanism_contrib_full = 0
        mechanism_contrib_perf = 0
        mechanism_contrib_exp = 0
        for i, _, values in observed_mechs:
            mechanism_contrib_partial += mechanism_to_partial[i] * values
            mechanism_contrib_full += mechanism_to_full[i] * values
            mechanism_contrib_perf += mechanism_to_perf[i, cheat_idx] * values
            mechanism_contrib_exp += mechanism_to_exp[i, cheat_idx] * values

        # Total effects on cheating (logits relative to category 0)
        eta_partial = (beta_partial_intercept +
                       mechanism_contrib_partial +
                       direct_concept_partial[concept_idx])
        eta_full = (beta_full_intercept +
                    mechanism_contrib_full +
                    direct_concept_full[concept_idx])

        # Category 0 is the reference, so its logit is fixed at zero.
        # zeros_like copies shape and dtype straight from eta_partial.
        logits = pm.math.stack([pm.math.zeros_like(eta_partial), eta_partial, eta_full], axis=1)
        cheating_probs = pm.math.softmax(logits, axis=1)

        pm.Categorical('cheating_observed', p=cheating_probs,
                       observed=cheat_idx)

        # Performance model: baseline and effects depend on the cheating category
        mu_perf_baseline = pm.Normal('mu_perf_baseline', mu=20, sigma=10, shape=3)

        mu_perf = (mu_perf_baseline[cheat_idx] +
                   mechanism_contrib_perf +
                   direct_concept_perf[concept_idx, cheat_idx])

        sigma_perf = pm.HalfNormal('sigma_perf', sigma=5)

        pm.TruncatedNormal('performance_observed',
                           mu=mu_perf, sigma=sigma_perf,
                           lower=0, upper=100,
                           observed=df['performance'].values)

        # Experience model: same structure, truncated to [1, 7]
        mu_exp_baseline = pm.Normal('mu_exp_baseline', mu=4, sigma=2, shape=3)

        mu_exp = (mu_exp_baseline[cheat_idx] +
                  mechanism_contrib_exp +
                  direct_concept_exp[concept_idx, cheat_idx])

        sigma_exp = pm.HalfNormal('sigma_exp', sigma=1)

        pm.TruncatedNormal('experience_observed',
                           mu=mu_exp, sigma=sigma_exp,
                           lower=1, upper=7,
                           observed=df['experience'].values)

    return mechanisms_model

def _summarize_effect_draws(total, direct, indirect):
    """Summarise posterior draws of one effect decomposition (mean and 95% HDI per pathway)."""
    # Take the per-draw ratio only where the total is not (numerically) zero.
    # Adding a small constant to the denominator does not protect against
    # division by zero: it merely moves the singularity to total == -epsilon.
    usable = np.abs(total) > 1e-10
    if usable.any():
        prop_mediated = (indirect[usable] / total[usable]).mean()
    else:
        prop_mediated = np.nan

    return {
        'total_mean': total.mean(),
        'total_hdi': az.hdi(total, hdi_prob=0.95),
        'direct_mean': direct.mean(),
        'direct_hdi': az.hdi(direct, hdi_prob=0.95),
        'indirect_mean': indirect.mean(),
        'indirect_hdi': az.hdi(indirect, hdi_prob=0.95),
        'prop_mediated': prop_mediated
    }


def calculate_pathway_effects(trace, concept_names, mechanisms, control_idx=0):
    """Calculate total, direct, and indirect effects relative to control.

    For each non-control concept the indirect effect is the sum over
    mechanisms of (concept→mechanism effect minus control's) times the
    mechanism→outcome coefficient; the direct effect is the difference of the
    direct concept coefficients; total = indirect + direct. Everything is
    computed per posterior draw on the scale of the model coefficients.

    Args:
        trace: Object whose `.posterior[name].values` holds the sampled
            parameters named below (e.g. the result of `pm.sample`).
        concept_names: Concept labels, in the order used for the concept axis
            of the parameters.
        mechanisms: Mechanism names, in the order used for the mechanism axis.
        control_idx: Position of the control concept in `concept_names`.
            Negative values count from the end.

    Returns:
        dict keyed by concept label (control excluded). Each value has keys
        'cheating_partial' and 'cheating_full' (summary dicts) and
        'performance' and 'experience' (dicts of summary dicts keyed
        'group_0', 'group_1', 'group_2'). A summary dict holds
        '<pathway>_mean' and '<pathway>_hdi' for total/direct/indirect plus
        'prop_mediated', the mean per-draw ratio indirect/total (NaN if every
        draw has a zero total). NOTE(review): this ratio can be dominated by
        draws whose total is close to zero — interpret with care.
    """

    n_concepts = len(concept_names)
    n_mechanisms = len(mechanisms)

    # Without this, a negative index would never match the loop counter and
    # the control concept would be reported against itself.
    if control_idx < 0:
        control_idx += n_concepts

    # Extract parameters, collapsing (chain, draw) into one leading draws axis
    posterior = trace.posterior
    concept_to_mech = posterior['concept_to_mechanisms'].values.reshape(-1, n_concepts, n_mechanisms)
    mech_to_partial = posterior['mechanism_to_partial'].values.reshape(-1, n_mechanisms)
    mech_to_full = posterior['mechanism_to_full'].values.reshape(-1, n_mechanisms)
    mech_to_perf = posterior['mechanism_to_perf'].values.reshape(-1, n_mechanisms, 3)
    mech_to_exp = posterior['mechanism_to_exp'].values.reshape(-1, n_mechanisms, 3)

    direct_partial = posterior['direct_concept_partial'].values.reshape(-1, n_concepts)
    direct_full = posterior['direct_concept_full'].values.reshape(-1, n_concepts)
    direct_perf = posterior['direct_concept_perf'].values.reshape(-1, n_concepts, 3)
    direct_exp = posterior['direct_concept_exp'].values.reshape(-1, n_concepts, 3)

    results = {}

    for i, concept in enumerate(concept_names):
        if i == control_idx:
            continue

        results[concept] = {}

        # Shift in every mechanism caused by this concept relative to control,
        # shape (draws, n_mechanisms); shared by all outcomes below.
        mech_shift = concept_to_mech[:, i] - concept_to_mech[:, control_idx]

        # Cheating effects (vs control)
        for outcome, mech_effect, direct_effect in [('partial', mech_to_partial, direct_partial),
                                                    ('full', mech_to_full, direct_full)]:
            # Indirect: concept → mechanisms → outcome
            indirect = np.sum(mech_shift * mech_effect, axis=1)

            # Direct: concept → outcome
            direct = direct_effect[:, i] - direct_effect[:, control_idx]

            results[concept][f'cheating_{outcome}'] = _summarize_effect_draws(
                indirect + direct, direct, indirect)

        # Performance and experience effects by cheating group
        for outcome, mech_effect, direct_effect in [('performance', mech_to_perf, direct_perf),
                                                    ('experience', mech_to_exp, direct_exp)]:
            results[concept][outcome] = {}

            for cheat_group in [0, 1, 2]:
                # Indirect
                indirect = np.sum(mech_shift * mech_effect[:, :, cheat_group], axis=1)

                # Direct
                direct = (direct_effect[:, i, cheat_group] -
                          direct_effect[:, control_idx, cheat_group])

                results[concept][outcome][f'group_{cheat_group}'] = _summarize_effect_draws(
                    indirect + direct, direct, indirect)

    return results

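# Usage:
# df must contain integer columns 'concept_idx' (each row's position in concept_names)
# and 'cheating_behavior' (0 = none, 1 = partial, 2 = full), plus 'performance',
# 'experience', and one column per mechanism name.
# mechanisms = ['need_satisfaction', 'need_frustration', 'self_efficacy', 'norm_perception', 'cognitive_discomfort']
# concept_names = df['concept'].unique()
# model = create_mechanisms_model(df, concept_names, mechanisms)
# with model:
#     trace = pm.sample(2000, tune=1000, chains=4, target_accept=0.9)
#
# control_idx (default 0) is the position of the control concept in concept_names.
# pathway_results = calculate_pathway_effects(trace, concept_names, mechanisms)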

def print_pathway_results(results):
    """Print the cheating-effect decomposition (total / direct / indirect) per concept.

    Two tables are written to stdout, one for the 'cheating_partial' entry
    and one for the 'cheating_full' entry of each concept in `results`.
    Concepts lacking an entry are left out of that table.

    Args:
        results: Mapping of concept label to a dict that may contain
            'cheating_partial' and/or 'cheating_full'. Each of those holds
            '<pathway>_mean', an indexable '<pathway>_hdi' (lower bound at
            position 0, upper bound at position 1) for total/direct/indirect,
            and 'prop_mediated'.
    """
    divider = "-" * 90
    column_header = f"{'Concept':<20} {'Total':<20} {'Direct':<20} {'Indirect':<20} {'% Mediated':<10}"

    def interval_text(summary, pathway):
        # "<mean> [<lower>,<upper>]" with explicit signs and two decimals
        mean = summary[pathway + '_mean']
        bounds = summary[pathway + '_hdi']
        return f"{mean:+.2f} [{bounds[0]:+.2f},{bounds[1]:+.2f}]"

    print("PATHWAY ANALYSIS: CONCEPT vs CONTROL EFFECTS")
    print("=" * 90)
    print("Total = Direct + Indirect (via mechanisms)")

    # (table title, key of the summary inside each concept's dict)
    tables = (
        ("\nCHEATING - PARTIAL vs NON:", 'cheating_partial'),
        ("\nCHEATING - FULL vs NON:", 'cheating_full'),
    )

    for title, summary_key in tables:
        print(title)
        print(column_header)
        print(divider)

        for concept, effects in results.items():
            if summary_key not in effects:
                continue

            summary = effects[summary_key]
            total_cell = interval_text(summary, 'total')
            direct_cell = interval_text(summary, 'direct')
            indirect_cell = interval_text(summary, 'indirect')
            mediated_cell = f"{summary['prop_mediated']*100:.0f}%"

            print(f"{concept:<20} {total_cell:<20} {direct_cell:<20} {indirect_cell:<20} {mediated_cell:<10}")

def get_mechanism_importance(trace, mechanisms):
    """Summarise each mechanism's coefficient on the two cheating outcomes.

    Args:
        trace: Object whose `.posterior['mechanism_to_partial']` and
            `.posterior['mechanism_to_full']` expose sampled values via
            `.values`, with the mechanism axis last.
        mechanisms: Mechanism names, in the order of that last axis.

    Returns:
        dict mapping each mechanism name to 'partial_mean', 'partial_hdi',
        'full_mean' and 'full_hdi' (posterior mean and 95% HDI of its
        coefficient in `mechanism_to_partial` / `mechanism_to_full`).
    """
    n_mech = len(mechanisms)
    posterior = trace.posterior

    # Collapse all leading sample axes into a single draws axis
    partial_draws = posterior['mechanism_to_partial'].values.reshape(-1, n_mech)
    full_draws = posterior['mechanism_to_full'].values.reshape(-1, n_mech)

    return {
        name: {
            'partial_mean': partial_draws[:, col].mean(),
            'partial_hdi': az.hdi(partial_draws[:, col], hdi_prob=0.95),
            'full_mean': full_draws[:, col].mean(),
            'full_hdi': az.hdi(full_draws[:, col], hdi_prob=0.95)
        }
        for col, name in enumerate(mechanisms)
    }