# Part 1: Deep Learning for Corporate Finance - Results Visualization

This notebook loads trained model checkpoints and generates all figures for the report.

**Prerequisites:** Run `01_part1_training.ipynb` first to generate checkpoints.

**Output:** All figures are saved to `results/latest/figures/` (relative to the project root; `latest` points to the most recent training run).

In [None]:
# =============================================================================
# 1.1 Imports and Configuration
# =============================================================================

import sys
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import glob

# Add project root to path
sys.path.insert(0, os.path.abspath(".."))

# Utility imports
from src.utils.checkpointing import load_training_result
from src.utils.analysis import (
    get_steady_state_policy, 
    evaluate_policy, 
    compute_frictionless_policy,
    print_training_summary
)
from src.utils.plotting import (
    plot_basic_loss_curves,
    plot_risky_loss_curves,
    plot_policy_comparison_panels,
    plot_scenario_comparison_panels,
    plot_policy_panels,
    plot_3d_policy_slice,
    plot_3d_policy_panels
)
from src.economy.parameters import EconomicParams, ShockParams

print("All imports successful!")

In [None]:
# =============================================================================
# 1.2 Load Configuration from Latest Run
# =============================================================================
# Every path is derived from results/latest, which the training notebook is
# expected to point at its most recent run (see the symlink check below).

RESULTS_BASE = os.path.join("..", "results")
RESULTS_DIR = os.path.join(RESULTS_BASE, "latest")
FIGURES_DIR, CHECKPOINTS_DIR = (
    os.path.join(RESULTS_DIR, sub) for sub in ("figures", "checkpoints")
)

# Stop early with an actionable hint rather than failing later on a missing file.
if not os.path.exists(RESULTS_DIR):
    raise FileNotFoundError(
        f"Results directory not found: {RESULTS_DIR}\n"
        "Please run 01_part1_training.ipynb first to generate checkpoints."
    )

# If `latest` is a symlink, also report which run directory it resolves to.
source_desc = RESULTS_DIR
if os.path.islink(RESULTS_DIR):
    actual_run = os.readlink(RESULTS_DIR)
    source_desc = f"{RESULTS_DIR} -> {actual_run}"
print(f"Loading results from: {source_desc}")

# Figures live inside the run directory; create the folder if it is missing.
os.makedirs(FIGURES_DIR, exist_ok=True)
print(f"Figures will be saved to: {FIGURES_DIR}")

In [None]:
# =============================================================================
# 1.3 Load Scenario Metadata
# =============================================================================
# scenarios.json: which scenarios and training methods exist in this run.
# bounds.json:    state-space bounds (k, log z, b) passed to every policy
#                 evaluation below.

scenarios_path = os.path.join(CHECKPOINTS_DIR, 'scenarios.json')
with open(scenarios_path, 'r') as f:
    scenarios_meta = json.load(f)

basic_scenarios, risky_scenarios, methods = (
    scenarios_meta[key] for key in ('basic_scenarios', 'risky_scenarios', 'methods')
)

print("Scenarios loaded:")
for label, value in (("Basic", basic_scenarios), ("Risky", risky_scenarios), ("Methods", methods)):
    print(f"  {label}: {value}")

bounds_path = os.path.join(CHECKPOINTS_DIR, 'bounds.json')
with open(bounds_path, 'r') as f:
    bounds_loaded = json.load(f)

# JSON stores each bound as a list; the analysis helpers receive tuples.
k_bounds, logz_bounds, b_bounds = (
    tuple(bounds_loaded[key]) for key in ('k', 'log_z', 'b')
)

print(f"\nBounds loaded: k={k_bounds}, log_z={logz_bounds}, b={b_bounds}")

In [None]:
# =============================================================================
# 1.4 Load All Training Results
# =============================================================================
# results_basic[scenario][method] and results_risky[scenario] hold whatever
# load_training_result returns for the matching checkpoint directory.

print("=" * 70)
print("LOADING BASIC MODEL RESULTS")
print("=" * 70)

results_basic = {}
for scenario_name in basic_scenarios:
    results_basic[scenario_name] = per_method = {}
    print(f"\nScenario: {scenario_name}")

    for method_name in methods:
        per_method[method_name] = load_training_result(
            os.path.join(CHECKPOINTS_DIR, "basic", scenario_name, method_name),
            verbose=False,
        )
        print(f"  Loaded {method_name}")

print("\n" + "=" * 70)
print("LOADING RISKY DEBT RESULTS")
print("=" * 70)

# The risky model has a single (BR) checkpoint per scenario, with no method level.
results_risky = {}
for scenario_name in risky_scenarios:
    results_risky[scenario_name] = load_training_result(
        os.path.join(CHECKPOINTS_DIR, "risky", scenario_name),
        verbose=False,
    )
    print(f"  Loaded {scenario_name}")

n_basic_loaded = len(basic_scenarios) * len(methods)
print(f"\nLoaded {n_basic_loaded} basic + {len(risky_scenarios)} risky results.")

## 2. Basic Model Analysis

In [None]:
# =============================================================================
# 3.4.1 Per-Scenario Loss Curves (6-panel format)
# =============================================================================
# plot_basic_loss_curves is imported in the imports cell (1.1); it writes the
# figure to `save_path` itself.

for scenario_name in basic_scenarios:
    runs = results_basic[scenario_name]
    filename = f"basic_loss_curves_{scenario_name}.png"

    fig = plot_basic_loss_curves(
        result_lr=runs['lr'],
        result_er=runs['er'],
        result_br=runs['br'],
        scenario_name=scenario_name,
        save_path=os.path.join(FIGURES_DIR, filename),
    )
    plt.show()
    print(f"Saved: {filename}")


In [None]:
# =============================================================================
# 3.4.2 Cross-Scenario Summary Table
# =============================================================================
# print_training_summary is imported in the imports cell (1.1).

# Clean summary table using utility function
print_training_summary(results_basic, model_type="basic")


def _final_loss(history, method_name):
    """Return ``(final_value, label)`` summarizing one method's training history.

    LR and ER report their own objective. BR reports the relative MSE when it
    was logged and otherwise falls back to the raw critic loss. The fallback is
    labelled as such, so the table never claims a RelMSE it does not contain.

    Args:
        history: Mapping of metric name -> per-step sequence
            (the ``result['history']`` of a loaded training result).
        method_name: Method key from scenarios.json. Anything other than
            'lr'/'er' is treated as BR, matching the original dispatch.

    Returns:
        Tuple of (last recorded value, human-readable loss label).

    Raises:
        KeyError: If no candidate series exists with at least one entry.
    """
    if method_name == 'lr':
        candidates = [('loss_LR', 'LR Loss')]
    elif method_name == 'er':
        candidates = [('loss_ER', 'ER Loss')]
    else:  # br
        candidates = [('rel_mse', 'BR RelMSE'), ('loss_critic', 'BR Critic Loss')]

    # Look candidates up lazily: an eager `.get(k, history[other])` would raise
    # KeyError on the fallback key even when the preferred key is present.
    for key, label in candidates:
        series = history.get(key)
        # len() instead of truthiness, in case a series is stored as an array.
        if series is not None and len(series) > 0:
            return series[-1], label
    raise KeyError(
        f"No non-empty loss series among {[key for key, _ in candidates]} "
        f"for method '{method_name}'"
    )


# Build detailed summary with pandas for additional analysis
summary_data = []
for scenario_name in basic_scenarios:
    for method_name in methods:
        history = results_basic[scenario_name][method_name]['history']
        final_loss, loss_name = _final_loss(history, method_name)
        summary_data.append({
            'Scenario': scenario_name,
            'Method': method_name.upper(),
            'Final Loss': final_loss,
            'Loss Type': loss_name,
        })

summary_df = pd.DataFrame(summary_data)

# Pivot table for easier comparison (see summary_df['Loss Type'] for units)
pivot_df = summary_df.pivot(index='Scenario', columns='Method', values='Final Loss')

print("\nPivot Table (Final Losses):")
print(pivot_df.to_string())


In [None]:
# =============================================================================
# 3.4.3 Per-Scenario Policy Comparison (LR vs ER vs BR)
# =============================================================================
# One figure per scenario overlays the three training methods. Every method's
# policy is evaluated with fixed_k_val set to BR's steady-state k*, so the
# curves share a common reference point. Helpers come from the imports cell (1.1).

METHOD_LABELS = {
    'lr': 'Lifetime Rewards',
    'er': 'Euler Residuals',
    'br': 'Bellman Residuals',
}

for scenario_name in basic_scenarios:
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Scenario: {scenario_name}")
    print(banner)

    scenario_results = results_basic[scenario_name]

    # Rebuild the exact parameters this scenario was trained with.
    params_path = os.path.join(CHECKPOINTS_DIR, "basic", scenario_name, 'params.json')
    with open(params_path, 'r') as f:
        saved_params = json.load(f)
    scenario_params = EconomicParams(**saved_params['params'])
    scenario_shock_params = ShockParams(**saved_params['shock_params'])

    print(f"Loaded params: cost_convex={scenario_params.cost_convex}, cost_fixed={scenario_params.cost_fixed}")

    bounds_kw = dict(k_bounds=k_bounds, logz_bounds=logz_bounds)

    steady_states = {
        method: get_steady_state_policy(scenario_results[method], **bounds_kw)
        for method in METHOD_LABELS
    }
    print("Steady State Capital (k*):")
    for method, ss in steady_states.items():
        print(f"  {method.upper()}: {ss['k_star_val']:.4f}")

    k_reference = steady_states['br']['k_star_val']
    grids = [
        evaluate_policy(scenario_results[method], fixed_k_val=k_reference, **bounds_kw)
        for method in METHOD_LABELS
    ]

    # The frictionless benchmark uses this scenario's own parameters.
    fig = plot_policy_comparison_panels(
        grids,
        labels=list(METHOD_LABELS.values()),
        suptitle=f"Policy Comparison: {scenario_name}",
        frictionless_benchmark={'params': scenario_params, 'shock_params': scenario_shock_params},
    )
    plt.tight_layout()
    plt.savefig(
        os.path.join(FIGURES_DIR, f"basic_policy_comparison_{scenario_name}.png"),
        dpi=150,
        bbox_inches='tight',
    )
    plt.show()

In [None]:
# =============================================================================
# 3.4.4 Cross-Scenario Policy Comparison (I/k and b'/k Transformations)
# =============================================================================
# Compare transformed policies across scenarios for each training method.
# - Investment rate: I/k = k'/k - (1-δ)
# - Leverage ratio: b'/k (risky model only)
# Helpers come from the imports cell (1.1).

# plot_scenario_comparison_panels receives a single delta for all scenarios.
# Verify that every scenario was trained with the same depreciation rate instead
# of assuming it; a mismatch would silently distort I/k for some scenarios.
scenario_deltas = {}
for scenario_name in basic_scenarios:
    params_path = os.path.join(CHECKPOINTS_DIR, "basic", scenario_name, 'params.json')
    with open(params_path, 'r') as f:
        scenario_deltas[scenario_name] = json.load(f)['params']['delta']

if not scenario_deltas:
    raise ValueError("No basic scenarios listed in scenarios.json; nothing to compare.")
if len(set(scenario_deltas.values())) > 1:
    raise ValueError(
        f"Basic scenarios use different depreciation rates {scenario_deltas}; "
        "a single delta cannot transform all of them to I/k."
    )

first_scenario = basic_scenarios[0]
delta = scenario_deltas[first_scenario]

print(f"Depreciation rate (delta) loaded from {first_scenario} params: {delta}")
print("\nGenerating cross-scenario comparison plots for each method...")

for method_name in methods:
    print(f"\n{'=' * 60}")
    print(f"Method: {method_name.upper()}")
    print(f"{'=' * 60}")

    # Collect eval_datas for all scenarios
    scenario_eval_datas = {}

    for scenario_name in basic_scenarios:
        result = results_basic[scenario_name][method_name]

        # Find steady state for this scenario/method combination
        ss = get_steady_state_policy(result, k_bounds=k_bounds, logz_bounds=logz_bounds)

        # Evaluate policy at its own steady state
        grid = evaluate_policy(
            result,
            k_bounds=k_bounds,
            logz_bounds=logz_bounds,
            fixed_k_val=ss['k_star_val'],
        )

        scenario_eval_datas[scenario_name] = grid
        print(f"  {scenario_name}: k* = {ss['k_star_val']:.4f}")

    # Create cross-scenario comparison plot
    fig = plot_scenario_comparison_panels(
        scenario_eval_datas,
        delta=delta,
        suptitle=f"{method_name.upper()} Method: Cross-Scenario Policy Comparison",
    )

    # Save figure
    save_path = os.path.join(FIGURES_DIR, f"basic_scenario_comparison_{method_name}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"  Saved: {os.path.basename(save_path)}")

print(f"\n{'=' * 60}")
print("Cross-scenario comparison plots complete!")
print(f"{'=' * 60}")


In [None]:
# =============================================================================
# 3.4.5 Cross-Scenario BR Policy Comparison (3D Surfaces)
# =============================================================================
# One 3D panel per basic scenario, showing the BR policy k'(k, z). Each panel
# is evaluated at that scenario's own BR steady-state k*. Helpers come from the
# imports cell (1.1).

n_scenarios = len(basic_scenarios)
fig = plt.figure(figsize=(6 * n_scenarios, 5))

for panel_no, scenario_name in enumerate(basic_scenarios, start=1):
    br_result = results_basic[scenario_name]['br']

    ss = get_steady_state_policy(br_result, k_bounds=k_bounds, logz_bounds=logz_bounds)
    grid = evaluate_policy(
        br_result,
        k_bounds=k_bounds,
        logz_bounds=logz_bounds,
        fixed_k_val=ss['k_star_val'],
    )

    ax = fig.add_subplot(1, n_scenarios, panel_no, projection='3d')
    plot_3d_policy_slice(
        grid,
        y_var='k_next',
        ax=ax,
        title=f"{scenario_name}\nk'(k, z)",
        show_colorbar=False,
        show_contour_projection=True,
    )

fig.suptitle("Cross-Scenario Comparison: BR Investment Policies", fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(FIGURES_DIR, "basic_3d_cross_scenario.png"), dpi=150, bbox_inches='tight')
plt.show()

## 3. Risky Debt Model Analysis

In [None]:
# =============================================================================
# 3.5.1 Risky Debt BR: Loss Curves
# =============================================================================
# plot_risky_loss_curves is imported in the imports cell (1.1).

# Use the scenario named 'baseline' when it exists, otherwise the first risky
# scenario listed. The following cells reuse `risky_scenario` and
# `result_risky_br`, so this choice drives the whole risky-debt section.
if not risky_scenarios:
    raise ValueError("No risky scenarios listed in scenarios.json; nothing to plot.")
risky_scenario = 'baseline' if 'baseline' in risky_scenarios else risky_scenarios[0]
result_risky_br = results_risky[risky_scenario]

fig, summary = plot_risky_loss_curves(
    result_risky=result_risky_br,
    scenario_name=risky_scenario,
    save_path=os.path.join(FIGURES_DIR, "risky_loss_curves.png")
)
plt.show()

# Print summary statistics
print(f"\nFinal Training Metrics ({risky_scenario}):")
print(f"  Relative MSE:     {summary['rel_mse']:.6f}")
print(f"  Actor Loss:       {summary['loss_actor']:.4f}")
print(f"  Price Loss:       {summary['loss_price']:.6f}")
print(f"  Value Scale:      {summary['mean_value_scale']:.2f}")
print("Saved: risky_loss_curves.png")


In [None]:
# =============================================================================
# 3.5.2 Risky Debt Policy Visualization (2D Panels)
# =============================================================================
# Locate the risky-debt BR steady state (k*, b*), then evaluate the policy on a
# grid with k and b fixed at those values. Helpers come from the imports cell (1.1).

state_bounds = dict(k_bounds=k_bounds, logz_bounds=logz_bounds, b_bounds=b_bounds)

ss_risky = get_steady_state_policy(result_risky_br, **state_bounds)
k_star, b_star = ss_risky['k_star_val'], ss_risky['b_star_val']

print(f"Risky Debt Steady State ({risky_scenario}):")
print(f"  k* = {k_star:.4f}")
print(f"  b* = {b_star:.4f}")

grid_risky = evaluate_policy(
    result_risky_br,
    fixed_k_val=k_star,
    fixed_b_val=b_star,
    **state_bounds,
)

# Multi-panel 2D view of the risky-model policy
fig = plot_policy_panels(grid_risky, suptitle=f"Risky Debt BR Policy ({risky_scenario})")
plt.savefig(os.path.join(FIGURES_DIR, "risky_br_policy.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# =============================================================================
# 3.5.3 Risky Debt: 3D Policy Surface Visualization
# =============================================================================
# 3D surfaces for the same policy grid (grid_risky) evaluated in the previous
# cell. plot_3d_policy_panels is imported in the imports cell (1.1).

risky_3d_title = f"Risky Debt Model: Policy Surfaces ({risky_scenario})"
fig = plot_3d_policy_panels(
    grid_risky,
    suptitle=risky_3d_title,
    show_contour_projection=True,
)
plt.savefig(os.path.join(FIGURES_DIR, "risky_3d_policies.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# =============================================================================
# 4. Summary: All Generated Figures
# =============================================================================
# Lists every PNG in FIGURES_DIR, including any left over from earlier runs of
# this notebook against the same results directory.

banner = "=" * 70
print(banner)
print("GENERATED FIGURES")
print(banner)

figures = sorted(glob.glob(os.path.join(FIGURES_DIR, "*.png")))
print(f"\nTotal figures: {len(figures)}")
print(f"Directory: {FIGURES_DIR}\n")

listing = "\n".join(f"  - {os.path.basename(path)}" for path in figures)
if listing:
    print(listing)