# Bayesian Causal Analysis on Titanic Dataset

This notebook demonstrates rudimentary causal analysis using conditional probabilities on categorical data.
We'll explore various causal questions like:
- Does passenger sex affect survival probability?
- Does passenger class influence survival?
- How do age groups affect survival rates?
- What's the combined effect of multiple factors?

## Key Concept: Conditional Probability as Causal Insight

While correlation ≠ causation, conditional probabilities like P(Survival|Sex) can suggest causal patterns when:
1. We have domain knowledge (e.g., "women and children first" policy)
2. We can control for confounders
3. We examine consistency across subgroups
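
Why points 2 and 3 matter can be seen in a small sketch. The counts below are hypothetical (they are not Titanic data), and the group and stratum names are placeholders. They are chosen so that the pooled conditional probability points one way while every subgroup points the other way (Simpson's paradox).

```python
# Hypothetical (successes, total) counts per group and stratum; not Titanic data.
toy_counts = {
    ("X", "stratum_A"): (9, 10),
    ("Y", "stratum_A"): (80, 100),
    ("X", "stratum_B"): (30, 100),
    ("Y", "stratum_B"): (2, 10),
}

def pooled_rate(pairs):
    """Proportion of successes after pooling a list of (successes, total) pairs."""
    successes = sum(s for s, _ in pairs)
    total = sum(n for _, n in pairs)
    return successes / total

# Marginal conditional probabilities P(outcome | group), ignoring the stratum
pooled = {
    g: pooled_rate([v for (grp, _), v in toy_counts.items() if grp == g])
    for g in ("X", "Y")
}

# Stratified conditional probabilities P(outcome | group, stratum)
stratified = {key: s / n for key, (s, n) in toy_counts.items()}

print("Pooled:", {g: round(p, 3) for g, p in pooled.items()})
for stratum in ("stratum_A", "stratum_B"):
    print(stratum, {g: stratified[(g, stratum)] for g in ("X", "Y")})
```

Group X looks worse when pooled but better inside each stratum, because X is concentrated in the low-rate stratum. This is the reason the later sections stratify by class before reading the sex effect.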

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Load and Prepare Data

In [None]:
# Load Titanic dataset
url = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
df = pd.read_csv(url)

print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nFirst few rows:")
df.head()

In [None]:
# Data preparation
# Create age groups
df['AgeGroup'] = pd.cut(df['Age'], 
                        bins=[0, 12, 18, 35, 60, 100], 
                        labels=['Child', 'Teen', 'Adult', 'Middle-aged', 'Senior'])

# Create fare categories
df['FareCategory'] = pd.qcut(df['Fare'], 
                              q=4, 
                              labels=['Low', 'Medium', 'High', 'Very High'],
                              duplicates='drop')

# Clean data
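# Assign the result back; calling fillna(..., inplace=True) on a column
# selection is deprecated and may not modify df in recent pandas versions
df['Embarked'] = df['Embarked'].fillna(df['Embarked'].mode()[0])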

print("Data preprocessing complete!")
print(f"\nMissing values:\n{df[['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']].isnull().sum()}")

## 2. Foundational Functions for Causal Analysis
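
For reference, the confidence interval computed in `calculate_conditional_probability` below is the Wilson score interval. Written out with $\hat{p}$ for the observed proportion, $n$ for the group size, and $z = 1.96$ for 95% confidence, the code's `centre` and `margin` terms correspond to:

$$
\text{CI} = \frac{1}{1 + z^2/n}\left(\hat{p} + \frac{z^2}{2n} \;\pm\; z\sqrt{\frac{\hat{p}(1-\hat{p})}{n} + \frac{z^2}{4n^2}}\right)
$$

Unlike the simpler normal-approximation interval $\hat{p} \pm z\sqrt{\hat{p}(1-\hat{p})/n}$, this interval stays inside $[0, 1]$ and behaves sensibly for small groups or proportions near 0 or 1.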

In [None]:
def calculate_conditional_probability(df, outcome_col, condition_col, outcome_value=1):
    """
    Calculate P(Outcome | Condition) for each category in condition_col
    
    Args:
        df: DataFrame
        outcome_col: Target variable (e.g., 'Survived')
        condition_col: Conditioning variable (e.g., 'Sex')
        outcome_value: Value of outcome to calculate probability for (default: 1)
    
    Returns:
        DataFrame with probabilities and counts
    """
    # Remove rows with missing values
    clean_df = df[[outcome_col, condition_col]].dropna()
    
    results = []
    for category in clean_df[condition_col].unique():
        subset = clean_df[clean_df[condition_col] == category]
        total = len(subset)
        successes = (subset[outcome_col] == outcome_value).sum()
        probability = successes / total if total > 0 else 0
        
        # Calculate 95% confidence interval (Wilson score interval)
        if total > 0:
            z = 1.96  # 95% confidence
            p_hat = probability
            denominator = 1 + z**2/total
            centre = (p_hat + z**2/(2*total)) / denominator
            margin = z * np.sqrt(p_hat*(1-p_hat)/total + z**2/(4*total**2)) / denominator
            ci_lower = centre - margin
            ci_upper = centre + margin
        else:
            ci_lower = ci_upper = 0
        
        results.append({
            'Category': category,
            'Probability': probability,
            'Count': successes,
            'Total': total,
            'CI_Lower': ci_lower,
            'CI_Upper': ci_upper
        })
    
    return pd.DataFrame(results).sort_values('Probability', ascending=False)


def calculate_joint_probability(df, outcome_col, condition_cols, outcome_value=1):
    """
    Calculate P(Outcome | Condition1, Condition2, ...)
    
    Args:
        df: DataFrame
        outcome_col: Target variable
        condition_cols: List of conditioning variables
        outcome_value: Value of outcome to calculate probability for
    
    Returns:
        DataFrame with joint probabilities
    """
    # Remove rows with missing values
    all_cols = [outcome_col] + condition_cols
    clean_df = df[all_cols].dropna()
    
    # Group by all condition columns
    grouped = clean_df.groupby(condition_cols)
    
    results = []
    for name, group in grouped:
        total = len(group)
        successes = (group[outcome_col] == outcome_value).sum()
        probability = successes / total if total > 0 else 0
        
        # Create category label
        if isinstance(name, tuple):
            category = ' & '.join([f"{col}={val}" for col, val in zip(condition_cols, name)])
        else:
            category = f"{condition_cols[0]}={name}"
        
        results.append({
            'Category': category,
            'Probability': probability,
            'Count': successes,
            'Total': total
        })
    
    return pd.DataFrame(results).sort_values('Probability', ascending=False)


def visualize_conditional_prob(prob_df, title, xlabel):
    """
    Visualize conditional probabilities with confidence intervals
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    
    categories = prob_df['Category'].astype(str)
    probabilities = prob_df['Probability']
    
    # Create bar plot
    bars = ax.bar(range(len(categories)), probabilities, alpha=0.7, color='steelblue')
    
    # Add error bars if confidence intervals exist
    if 'CI_Lower' in prob_df.columns:
        errors = [probabilities - prob_df['CI_Lower'], 
                 prob_df['CI_Upper'] - probabilities]
        ax.errorbar(range(len(categories)), probabilities, 
                   yerr=errors, fmt='none', ecolor='black', 
                   capsize=5, alpha=0.5)
    
    # Add value labels on bars
    for i, (bar, prob, count, total) in enumerate(zip(bars, probabilities, 
                                                       prob_df['Count'], prob_df['Total'])):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{prob:.1%}\n({count}/{total})',
                ha='center', va='bottom', fontsize=9)
    
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('P(Survival)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=45, ha='right')
    ax.set_ylim(0, 1)
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("Helper functions defined successfully!")

## 3. Exploratory Causal Question 1: Does Sex Affect Survival?

**Causal Question**: P(Survival | Sex)

This is the classic "Women and Children First" hypothesis.
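
Once P(Survival | Sex) is known for both groups, the gap between them can be summarized in several ways. The sketch below shows three common measures for a 2x2 table. The function name `two_by_two_measures` and the counts are hypothetical and are not taken from the dataset.

```python
def two_by_two_measures(a, b, c, d):
    """Effect measures for a 2x2 table.

    a, b: events / non-events in group 1
    c, d: events / non-events in group 2
    """
    risk1 = a / (a + b)
    risk2 = c / (c + d)
    return {
        "risk_difference": risk1 - risk2,   # absolute gap in probability
        "relative_risk": risk1 / risk2,     # ratio of probabilities
        "odds_ratio": (a * d) / (b * c),    # ratio of odds
    }

# Hypothetical counts for illustration only
measures = two_by_two_measures(a=30, b=10, c=10, d=30)
for name, value in measures.items():
    print(f"{name}: {value:.2f}")
```

The relative risk is the quantity printed in the next cell. The odds ratio is always further from 1 than the relative risk when the outcome is common, so the two should not be read interchangeably.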

In [None]:
# Calculate P(Survival | Sex)
sex_survival = calculate_conditional_probability(df, 'Survived', 'Sex')
print("P(Survival | Sex):\n")
print(sex_survival.to_string(index=False))
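# Look up each rate by label rather than by row position, so the ratio
# does not depend on how the result happens to be sorted
rates = sex_survival.set_index('Category')['Probability']
print(f"\nRelative Risk (Female vs Male): {rates['female'] / rates['male']:.2f}x")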

In [None]:
# Visualize
visualize_conditional_prob(sex_survival, 
                          'Survival Probability by Sex',
                          'Sex')

In [None]:
# Statistical test: Chi-square test of independence
contingency_table = pd.crosstab(df['Sex'], df['Survived'])
chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)

print("\nContingency Table:")
print(contingency_table)
print(f"\nChi-square statistic: {chi2:.4f}")
print(f"P-value: {p_value:.4e}")
print(f"Degrees of freedom: {dof}")
print(f"\nConclusion: Sex {'IS' if p_value < 0.05 else 'IS NOT'} significantly associated with survival (α=0.05)")

## 4. Exploratory Causal Question 2: Does Passenger Class Affect Survival?

**Causal Question**: P(Survival | Pclass)

Does socioeconomic status (proxied by ticket class) influence survival?
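
The chi-square test used in this notebook compares observed counts with the counts expected if the two variables were independent. The sketch below computes the statistic by hand for a hypothetical 3x2 table (not Titanic counts) and adds Cramér's V as a rough effect size. The helper name `chi_square_statistic` is illustrative.

```python
import math

def chi_square_statistic(table):
    """Pearson chi-square statistic and degrees of freedom for an r x c table."""
    row_totals = [sum(row) for row in table]
    col_totals = [sum(col) for col in zip(*table)]
    grand_total = sum(row_totals)
    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            # Expected count under independence of rows and columns
            expected = row_totals[i] * col_totals[j] / grand_total
            chi2 += (observed - expected) ** 2 / expected
    dof = (len(table) - 1) * (len(table[0]) - 1)
    return chi2, dof

# Hypothetical table: rows = three classes, columns = (died, survived)
toy_table = [[20, 30], [25, 25], [40, 10]]
chi2_toy, dof_toy = chi_square_statistic(toy_table)

# Cramér's V rescales chi-square to the range 0..1
n_toy = sum(sum(row) for row in toy_table)
cramers_v = math.sqrt(chi2_toy / (n_toy * min(len(toy_table) - 1, len(toy_table[0]) - 1)))

print(f"chi2 = {chi2_toy:.3f}, dof = {dof_toy}, Cramér's V = {cramers_v:.3f}")
```

For 2x2 tables such as Sex × Survived, `scipy.stats.chi2_contingency` applies a continuity correction by default, so its statistic will be slightly smaller than this uncorrected hand calculation.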

In [None]:
# Calculate P(Survival | Pclass)
class_survival = calculate_conditional_probability(df, 'Survived', 'Pclass')
print("P(Survival | Passenger Class):\n")
print(class_survival.to_string(index=False))

In [None]:
# Visualize
visualize_conditional_prob(class_survival,
                          'Survival Probability by Passenger Class',
                          'Passenger Class')

In [None]:
# Statistical test
contingency_table = pd.crosstab(df['Pclass'], df['Survived'])
chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)

print("\nContingency Table:")
print(contingency_table)
print(f"\nChi-square statistic: {chi2:.4f}")
print(f"P-value: {p_value:.4e}")
print(f"\nConclusion: Passenger class {'IS' if p_value < 0.05 else 'IS NOT'} significantly associated with survival (α=0.05)")

## 5. Exploratory Causal Question 3: Does Age Group Affect Survival?

**Causal Question**: P(Survival | AgeGroup)

Testing the "children first" component of the evacuation policy.
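
Where the age boundaries fall affects who counts as a child. The sketch below reuses the bin edges and labels from the data preparation step on a few hand-picked ages (illustrative values, not rows from the dataset). It shows that `pd.cut` intervals are closed on the right by default and that missing ages stay missing.

```python
import pandas as pd

# Hand-picked ages around the bin edges, plus one missing value
ages = pd.Series([0.42, 12, 12.5, 18, 18.5, None])

groups = pd.cut(ages,
                bins=[0, 12, 18, 35, 60, 100],
                labels=['Child', 'Teen', 'Adult', 'Middle-aged', 'Senior'])

print(pd.DataFrame({'Age': ages, 'AgeGroup': groups}))
```

A 12-year-old lands in `Child` and an 18-year-old in `Teen`, because each bin includes its upper edge. Passengers with unknown age receive no group and are dropped by `calculate_conditional_probability`.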

In [None]:
# Calculate P(Survival | AgeGroup)
age_survival = calculate_conditional_probability(df, 'Survived', 'AgeGroup')
print("P(Survival | Age Group):\n")
print(age_survival.to_string(index=False))

In [None]:
# Visualize
visualize_conditional_prob(age_survival,
                          'Survival Probability by Age Group',
                          'Age Group')

## 6. Exploratory Causal Question 4: Does Port of Embarkation Affect Survival?

**Causal Question**: P(Survival | Embarked)

Could where passengers boarded affect their survival? (Might be a proxy for wealth or cabin location)

In [None]:
# Calculate P(Survival | Embarked)
embarked_survival = calculate_conditional_probability(df, 'Survived', 'Embarked')
print("P(Survival | Port of Embarkation):\n")
print(embarked_survival.to_string(index=False))
print("\nPorts: C=Cherbourg, Q=Queenstown, S=Southampton")

In [None]:
# Visualize
visualize_conditional_prob(embarked_survival,
                          'Survival Probability by Port of Embarkation',
                          'Port')

## 7. Exploratory Causal Question 5: Does Family Size Affect Survival?

**Causal Question**: P(Survival | FamilySize)

Did traveling with family help or hinder survival?

In [None]:
# Create family size variable
df['FamilySize'] = df['SibSp'] + df['Parch'] + 1  # +1 for the passenger themselves
df['FamilyCategory'] = pd.cut(df['FamilySize'],
                              bins=[0, 1, 3, 6, 20],
                              labels=['Alone', 'Small (2-3)', 'Medium (4-6)', 'Large (7+)'])

# Calculate P(Survival | FamilySize)
family_survival = calculate_conditional_probability(df, 'Survived', 'FamilyCategory')
print("P(Survival | Family Size):\n")
print(family_survival.to_string(index=False))

In [None]:
# Visualize
visualize_conditional_prob(family_survival,
                          'Survival Probability by Family Size',
                          'Family Size Category')

## 8. Joint Causal Analysis: Controlling for Confounders

**Causal Question**: P(Survival | Sex, Pclass)

Does the sex effect hold across all passenger classes?
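
One way to turn stratified rates into a single adjusted comparison is direct standardization: weight each group's stratum-specific rate by a common stratum distribution. The sketch below uses hypothetical counts (not Titanic data) with placeholder group and stratum names. It is an illustration of the idea rather than part of the analysis in this notebook.

```python
# Hypothetical (survived, total) counts per group and stratum; not Titanic data.
std_counts = {
    "F": {"class_1": (45, 50), "class_3": (25, 50)},
    "M": {"class_1": (12, 30), "class_3": (14, 70)},
}
strata = ["class_1", "class_3"]

# Standard population: each stratum's share of the combined sample
stratum_totals = {s: sum(std_counts[g][s][1] for g in std_counts) for s in strata}
grand_total = sum(stratum_totals.values())
weights = {s: stratum_totals[s] / grand_total for s in strata}

def crude_rate(group):
    """Unadjusted survival rate, pooling all strata."""
    survived = sum(std_counts[group][s][0] for s in strata)
    total = sum(std_counts[group][s][1] for s in strata)
    return survived / total

def standardized_rate(group):
    """Stratum-specific rates averaged with the common weights."""
    return sum(weights[s] * std_counts[group][s][0] / std_counts[group][s][1]
               for s in strata)

crude_diff = crude_rate("F") - crude_rate("M")
adjusted_diff = standardized_rate("F") - standardized_rate("M")
print(f"Crude difference:    {crude_diff:.2f}")
print(f"Adjusted difference: {adjusted_diff:.2f}")
```

In this toy example the adjusted gap is smaller than the crude gap, because part of the crude gap came from the two groups being distributed differently across classes.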

In [None]:
# Calculate joint probability
sex_class_survival = calculate_joint_probability(df, 'Survived', ['Sex', 'Pclass'])
print("P(Survival | Sex, Passenger Class):\n")
print(sex_class_survival.to_string(index=False))

In [None]:
# Create a more detailed visualization
fig, ax = plt.subplots(figsize=(12, 6))

# Prepare data for grouped bar chart
clean_df = df[['Survived', 'Sex', 'Pclass']].dropna()
pivot_data = clean_df.groupby(['Pclass', 'Sex'])['Survived'].agg(['mean', 'count']).reset_index()

# Create grouped bar chart
x = np.arange(3)  # 3 classes
width = 0.35

male_data = pivot_data[pivot_data['Sex'] == 'male'].sort_values('Pclass')
female_data = pivot_data[pivot_data['Sex'] == 'female'].sort_values('Pclass')

bars1 = ax.bar(x - width/2, male_data['mean'], width, label='Male', alpha=0.7, color='steelblue')
bars2 = ax.bar(x + width/2, female_data['mean'], width, label='Female', alpha=0.7, color='coral')

# Add value labels
for bars, data in [(bars1, male_data), (bars2, female_data)]:
    for bar, prob, count in zip(bars, data['mean'], data['count']):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{prob:.1%}\n(n={count})',
                ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Passenger Class', fontsize=12)
ax.set_ylabel('P(Survival)', fontsize=12)
ax.set_title('Survival Probability by Sex and Passenger Class', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(['1st Class', '2nd Class', '3rd Class'])
ax.legend()
ax.set_ylim(0, 1)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey Insight: The 'women first' effect holds across all classes,")
print("but class still matters - even among women, 1st class had better survival.")

## 9. Three-way Interaction: Sex × Class × Age

**Causal Question**: P(Survival | Sex, Pclass, AgeGroup)

Full stratification to understand the combined effects.

In [None]:
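# Focus on children vs adults. Keep missing ages as NaN so those rows are
# dropped by calculate_joint_probability instead of being counted as adults
df['IsChild'] = (df['Age'] < 18).map({True: 'Child', False: 'Adult'}).where(df['Age'].notna())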

# Calculate three-way joint probability
triple_survival = calculate_joint_probability(df, 'Survived', ['Sex', 'Pclass', 'IsChild'])
print("P(Survival | Sex, Passenger Class, Child/Adult):\n")
print(triple_survival.to_string(index=False))

In [None]:
# Visualize as heatmap
clean_df = df[['Survived', 'Sex', 'Pclass', 'IsChild']].dropna()
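# Leave empty combinations as NaN (blank in the heatmap) rather than
# filling them with 0, which would read as a 0% survival rate
pivot = clean_df.groupby(['Pclass', 'Sex', 'IsChild'])['Survived'].mean().unstack()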

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for idx, sex in enumerate(['female', 'male']):
    data = pivot.xs(sex, level='Sex')
    sns.heatmap(data, annot=True, fmt='.2%', cmap='RdYlGn', 
                vmin=0, vmax=1, ax=axes[idx], cbar_kws={'label': 'Survival Rate'})
    axes[idx].set_title(f'{sex.capitalize()} Passengers', fontsize=12, fontweight='bold')
    axes[idx].set_xlabel('Age Category', fontsize=10)
    axes[idx].set_ylabel('Passenger Class', fontsize=10)

plt.tight_layout()
plt.show()

print("\nQuestions to check against the table and heatmaps above:")
print("1. Did female children have the highest survival in every class?")
print("2. Which group had the lowest survival: adult men in 2nd or in 3rd class?")
print("3. Did class matter even for children, with 1st class children surviving more?")

## 10. Bayesian Interpretation with Priors

Let's take a Bayesian approach: what if we had prior beliefs about survival rates?
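
The function in the next cell relies on Beta-Binomial conjugacy. With $\theta$ as the survival probability, $k$ survivors out of $n$ passengers, and prior parameters $\alpha, \beta$, the update is:

$$
\theta \sim \text{Beta}(\alpha, \beta), \qquad k \mid \theta \sim \text{Binomial}(n, \theta)
\;\;\Longrightarrow\;\;
\theta \mid k \sim \text{Beta}(\alpha + k,\; \beta + n - k)
$$

$$
\mathbb{E}[\theta \mid k] = \frac{\alpha + k}{\alpha + \beta + n}
$$

The prior therefore acts like $\alpha$ pseudo-successes and $\beta$ pseudo-failures added to the observed counts, and its influence shrinks as $n$ grows.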

In [None]:
from scipy.stats import beta

def bayesian_probability_estimate(successes, trials, prior_alpha=1, prior_beta=1):
    """
    Bayesian estimate of probability using Beta-Binomial conjugate prior
    
    Args:
        successes: Number of positive outcomes
        trials: Total number of trials
        prior_alpha: Prior alpha parameter (pseudo-successes)
        prior_beta: Prior beta parameter (pseudo-failures)
    
    Returns:
        Dictionary with posterior statistics
    """
    # Posterior parameters
    posterior_alpha = prior_alpha + successes
    posterior_beta = prior_beta + (trials - successes)
    
    # Posterior statistics
    posterior_mean = posterior_alpha / (posterior_alpha + posterior_beta)
    posterior_mode = (posterior_alpha - 1) / (posterior_alpha + posterior_beta - 2) if (posterior_alpha > 1 and posterior_beta > 1) else posterior_mean
    
    # 95% credible interval
    ci_lower = beta.ppf(0.025, posterior_alpha, posterior_beta)
    ci_upper = beta.ppf(0.975, posterior_alpha, posterior_beta)
    
    return {
        'posterior_mean': posterior_mean,
        'posterior_mode': posterior_mode,
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'posterior_alpha': posterior_alpha,
        'posterior_beta': posterior_beta
    }

# Example: Bayesian estimate for female survival
female_data = df[df['Sex'] == 'female']
female_survived = female_data['Survived'].sum()
female_total = len(female_data)

# Uninformative prior (alpha=1, beta=1 = uniform)
uniform_prior = bayesian_probability_estimate(female_survived, female_total, 1, 1)

# Informative prior: we believe survival rate is around 50% (alpha=5, beta=5)
informative_prior = bayesian_probability_estimate(female_survived, female_total, 5, 5)

print("Bayesian Estimates for Female Survival:\n")
print(f"Data: {female_survived}/{female_total} survived = {female_survived/female_total:.1%}")
print(f"\nUniform Prior (α=1, β=1):")
print(f"  Posterior Mean: {uniform_prior['posterior_mean']:.1%}")
print(f"  95% Credible Interval: [{uniform_prior['ci_lower']:.1%}, {uniform_prior['ci_upper']:.1%}]")
print(f"\nInformative Prior (α=5, β=5, centered at 50%):")
print(f"  Posterior Mean: {informative_prior['posterior_mean']:.1%}")
print(f"  95% Credible Interval: [{informative_prior['ci_lower']:.1%}, {informative_prior['ci_upper']:.1%}]")

In [None]:
# Visualize prior and posterior distributions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

x = np.linspace(0, 1, 1000)

# Plot 1: Uniform prior
prior_uniform = beta.pdf(x, 1, 1)
posterior_uniform = beta.pdf(x, uniform_prior['posterior_alpha'], uniform_prior['posterior_beta'])

axes[0].plot(x, prior_uniform, 'b--', label='Prior (Uniform)', linewidth=2)
axes[0].plot(x, posterior_uniform, 'r-', label='Posterior', linewidth=2)
axes[0].axvline(female_survived/female_total, color='g', linestyle=':', 
                label=f'MLE = {female_survived/female_total:.1%}', linewidth=2)
axes[0].fill_between(x, 0, posterior_uniform, alpha=0.2, color='red')
axes[0].set_xlabel('Survival Probability', fontsize=12)
axes[0].set_ylabel('Density', fontsize=12)
axes[0].set_title('Uniform Prior → Posterior', fontsize=12, fontweight='bold')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Plot 2: Informative prior
prior_informative = beta.pdf(x, 5, 5)
posterior_informative = beta.pdf(x, informative_prior['posterior_alpha'], informative_prior['posterior_beta'])

axes[1].plot(x, prior_informative, 'b--', label='Prior (centered at 50%)', linewidth=2)
axes[1].plot(x, posterior_informative, 'r-', label='Posterior', linewidth=2)
axes[1].axvline(female_survived/female_total, color='g', linestyle=':', 
                label=f'MLE = {female_survived/female_total:.1%}', linewidth=2)
axes[1].fill_between(x, 0, posterior_informative, alpha=0.2, color='red')
axes[1].set_xlabel('Survival Probability', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].set_title('Informative Prior → Posterior', fontsize=12, fontweight='bold')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\nInterpretation:")
print(f"- With a large sample (n={female_total}), even an informative prior gets overwhelmed")
print("- The posterior is data-driven and concentrates around the observed rate")
print("- Bayesian approach gives us full uncertainty quantification via credible intervals")

## 11. Comparative Summary: All Causal Factors

In [None]:
# Create comprehensive comparison
factors_comparison = []

# Add all single-factor analyses
for factor, name in [('Sex', 'Sex'), ('Pclass', 'Class'), 
                     ('AgeGroup', 'Age Group'), ('Embarked', 'Port'),
                     ('FamilyCategory', 'Family Size')]:
    if factor in df.columns:
        prob_df = calculate_conditional_probability(df, 'Survived', factor)
        max_prob = prob_df['Probability'].max()
        min_prob = prob_df['Probability'].min()
        effect_size = max_prob - min_prob
        
        factors_comparison.append({
            'Factor': name,
            'Max Survival': max_prob,
            'Min Survival': min_prob,
            'Effect Size': effect_size,
            'Categories': len(prob_df)
        })

comparison_df = pd.DataFrame(factors_comparison).sort_values('Effect Size', ascending=False)
print("Comparative Effect Sizes:\n")
print(comparison_df.to_string(index=False))

In [None]:
# Visualize comparative effect sizes
fig, ax = plt.subplots(figsize=(10, 6))

y_pos = np.arange(len(comparison_df))
bars = ax.barh(y_pos, comparison_df['Effect Size'], alpha=0.7, color='steelblue')

# Add value labels
for i, (bar, effect) in enumerate(zip(bars, comparison_df['Effect Size'])):
    ax.text(effect, bar.get_y() + bar.get_height()/2,
            f'{effect:.1%}',
            va='center', ha='left', fontsize=10, fontweight='bold')

ax.set_yticks(y_pos)
ax.set_yticklabels(comparison_df['Factor'])
ax.set_xlabel('Effect Size (Max - Min Survival Probability)', fontsize=12)
ax.set_title('Comparative Causal Effect Sizes on Titanic Survival', 
             fontsize=14, fontweight='bold')
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey Findings:")
print(f"1. Strongest effect: {comparison_df.iloc[0]['Factor']} ({comparison_df.iloc[0]['Effect Size']:.1%} difference)")
print(f"2. Second strongest: {comparison_df.iloc[1]['Factor']} ({comparison_df.iloc[1]['Effect Size']:.1%} difference)")
print(f"3. Weakest effect: {comparison_df.iloc[-1]['Factor']} ({comparison_df.iloc[-1]['Effect Size']:.1%} difference)")

## 12. Summary and Causal Insights

### Key Causal Findings:

1. **Sex shows the strongest effect** - Being female dramatically increased survival probability, consistent with the "women and children first" policy

2. **Passenger class matters** - Even after stratifying by sex, higher-class passengers had better survival rates (plausibly through proximity to lifeboats and priority access)

3. **Age shows a moderate effect** - Children had higher survival, but the effect is weaker than that of sex or class

4. **Family size shows a non-linear relationship** - Small families did best; travelling alone or in a large family was associated with lower survival

5. **Port of embarkation is likely confounded** - It probably proxies for wealth/class rather than having a direct causal impact

### Causal vs Correlational Thinking:

- **Correlation**: Port C had higher survival than Port S
- **Causal thinking**: Port C passengers were wealthier (more 1st class) → wealth caused better survival, not the port itself

### Bayesian Perspective Benefits:

1. **Uncertainty quantification**: Not just point estimates but full probability distributions
2. **Prior knowledge integration**: Can incorporate domain expertise
3. **Natural handling of small samples**: Priors prevent extreme estimates with limited data
4. **Credible intervals**: Direct probability statements ("95% probability the true rate is in this range")
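
Point 3 can be illustrated with a short sketch that compares the raw proportion with the Beta-Binomial posterior mean. The counts are hypothetical (not Titanic data), and `posterior_mean` is an illustrative helper, separate from `bayesian_probability_estimate` above.

```python
def posterior_mean(successes, trials, prior_alpha=1.0, prior_beta=1.0):
    """Posterior mean of a Beta(prior_alpha, prior_beta) prior after binomial data."""
    return (prior_alpha + successes) / (prior_alpha + prior_beta + trials)

# Hypothetical samples: a tiny group and a large group
samples = {"small (3/3)": (3, 3), "large (75/100)": (75, 100)}

shrinkage = {}
for label, (k, n) in samples.items():
    shrinkage[label] = {
        "raw": k / n,
        "uniform_prior": posterior_mean(k, n, 1, 1),
        "beta_5_5_prior": posterior_mean(k, n, 5, 5),
    }
    print(label, {name: round(v, 3) for name, v in shrinkage[label].items()})
```

For the tiny group the raw estimate is an extreme 100%, while both priors pull it back toward 50%. For the large group all three estimates are close, which matches the behaviour seen in section 10.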

### Next Steps for Deeper Analysis:

- Implement logistic regression for multivariate causal modeling
- Use propensity score matching to better isolate causal effects
- Build Bayesian hierarchical models for group-level effects
- Explore causal graphs (DAGs) to formalize assumptions

In [None]:
# Final summary statistics
print("="*60)
print("TITANIC CAUSAL ANALYSIS SUMMARY")
print("="*60)
print(f"\nTotal passengers analyzed: {len(df)}")
print(f"Overall survival rate: {df['Survived'].mean():.1%}")
print(f"\nStrongest causal factors (by effect size):")
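# Number rows by rank; the DataFrame index still holds the pre-sort positions
for rank, (_, row) in enumerate(comparison_df.head(3).iterrows(), start=1):
    print(f"  {rank}. {row['Factor']}: {row['Effect Size']:.1%} effect")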
print("\nThe chi-square tests for sex and passenger class both showed p < 0.001,")
print("indicating significant associations between those factors and survival.")
print("\nRemember: Association ≠ Causation")
print("These conditional probabilities suggest causal relationships,")
print("but true causality requires additional domain knowledge and")
print("careful consideration of confounding variables.")
print("="*60)