# Krusell and Smith (1998) with Deep Equilibrium Nets (PyTorch Version)

This notebook solves the classic heterogeneous agent model from Krusell and Smith (1998) using the **Deep Equilibrium Net (DEQN)** method (Azinovic, Gaegauf, and Scheidegger, 2022).

**Framework Adaptation:**
This is a port of the original implementation by Jan Žemlička. While the original used Google JAX, this version is rewritten using **PyTorch**. 

**Methodology:**
1. **Calibration (SSJ):** We use the `sequence-jacobian` toolbox (a NumPy/SciPy/Numba-based package) to compute the steady state and calibrate the grids. This is the most efficient way to initialize the model.
2. **Deep Learning (PyTorch):** We convert the steady-state objects to PyTorch tensors and train a PyTorch neural network to compute a global solution of the model's dynamics.

**Hardware:**
A GPU is highly recommended for the distribution transport steps.

In [None]:
# Install dependencies (IPython `!` shell escapes; run inside Jupyter/Colab).
# sequence-jacobian is needed for the steady-state initialization and torch for the solution.
# NOTE(review): none of the visible cells import jax, torchvision or torchaudio; the jax/jaxlib
# install looks like a leftover from the JAX original — confirm before removing it.
!pip install sequence-jacobian jax jaxlib
# Wheels from the PyTorch CUDA 11.8 index; adjust the index URL to the local CUDA setup.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib scipy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import sequence_jacobian as ssj
import time
import copy

# Pick the compute device once; tensors created later in the notebook are placed on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Default floating-point precision for newly created tensors.
# float32 is the usual deep-learning choice and is much faster on consumer GPUs;
# float64 is more common in economics codes but is usually not needed for DEQN training.
torch.set_default_dtype(torch.float32)

## Part 1: Initialization via Sequence-Jacobian
We use the standard SSJ toolbox to solve for the steady state. This ensures our grids and parameters are economically valid before we start training the neural net.

In [None]:
''' Define the SSJ Blocks (Standard Economics Logic) '''

# 1. Household Block
def household_init(a_grid, e_grid, r, w, eis):
    """Initial guess for the marginal value of assets Va on the (nE, nA) grid.

    The guess assumes every household consumes 10% of its cash on hand and
    sets Va = (1 + r) * c**(-1/eis) at that consumption level.
    """
    gross_r = 1 + r
    assets_row = a_grid[np.newaxis, :]        # (1, nA)
    income_col = w * e_grid[:, np.newaxis]    # (nE, 1)
    cash_on_hand = gross_r * assets_row + income_col
    c_guess = 0.1 * cash_on_hand
    Va = gross_r * c_guess ** (-1 / eis)
    return Va

def make_grid(rho_e, sd_e, nE, amin, amax, nA):
    """Build the household grids from calibration parameters (attached as an SSJ hetinput).

    e_grid: nE productivity levels from SSJ's Rouwenhorst discretization with
    persistence rho_e and `sigma=sd_e`; Pi: the matching (nE, nE) transition
    matrix; pi_e: the second output of markov_rouwenhorst (presumably the
    stationary distribution of the chain — confirm against SSJ docs);
    a_grid: nA asset nodes between amin and amax.
    NOTE(review): SSJ appears to take block outputs from the names on the final
    return line, so keep that line as plain names without trailing comments.
    """
    e_grid, pi_e, Pi = ssj.grids.markov_rouwenhorst(rho=rho_e, sigma=sd_e, N=nE)
    a_grid = ssj.grids.agrid(amin=amin, amax=amax, n=nA)
    return e_grid, Pi, a_grid, pi_e

@ssj.het(exogenous='Pi', policy='a', backward='Va', backward_init=household_init)
def household(Va_p, a_grid, e_grid, r, w, beta, eis):
    """One backward step of the household problem via the endogenous gridpoint method.

    Va_p: next-period marginal value of assets on the (nE, nA) grid —
    presumably already the expectation over e' that SSJ forms from the
    exogenous 'Pi' (the `_p` convention); confirm against SSJ docs.
    Outputs, each on the (nE, nA) grid: Va (today's marginal value of assets),
    a (savings policy, floored at a_grid[0]) and c (consumption policy).
    NOTE(review): SSJ appears to take output names from the final return line,
    so keep that line free of trailing comments.
    """
    # Euler equation: marginal utility today for each savings choice a' on a_grid
    uc_nextgrid = beta * Va_p
    # Invert CRRA marginal utility u'(c) = c**(-1/eis) to get the implied consumption
    c_nextgrid = uc_nextgrid ** (-eis)
    # Cash on hand at each grid point; rows index e, columns index a
    coh = (1 + r) * a_grid[np.newaxis, :] + w * e_grid[:, np.newaxis]
    # c + a' is the cash on hand at which a' is optimal; interpolate a' at the actual cash on hand
    a = ssj.interpolate.interpolate_y(c_nextgrid + a_grid, coh, a_grid)
    # Borrowing limit a' >= a_grid[0]; called for its effect on `a` (presumably in place)
    ssj.misc.setmin(a, a_grid[0])
    c = coh - a
    # Marginal value of assets today: (1 + r) * u'(c)
    Va = (1 + r) * c ** (-1 / eis)
    return Va, a, c

# Attach make_grid so the grids and Pi are built inside the block from the calibration
# parameters (rho_e, sd_e, nE, amin, amax, nA) instead of being passed in directly.
household = household.add_hetinputs([make_grid])

# 2. Firm and Market Clearing
@ssj.simple
def firm(K, L, Z, alpha, delta):
    """Cobb-Douglas firm: factor prices and output.

    K(-1) is SSJ's one-period lag of K, so production at t uses capital chosen at t-1.
    Outputs: r (return on capital net of depreciation), w (wage) and Y (output).
    """
    # Marginal product of capital minus depreciation
    r = alpha * Z * (K(-1) / L) ** (alpha-1) - delta
    # Marginal product of labor
    w = (1 - alpha) * Z * (K(-1) / L) ** alpha
    Y = Z * K(-1) ** alpha * L ** (1 - alpha)
    return r, w, Y

@ssj.simple
def mkt_clearing(K, A, Y, C, delta):
    """Market-clearing residuals; both are zero in equilibrium.

    asset_mkt: household assets A minus capital K.
    goods_mkt: output minus consumption minus depreciation of capital.
    """
    asset_mkt = A - K
    goods_mkt = Y - C - delta * K
    return asset_mkt, goods_mkt

# 3. Solve Steady State
# Combine the household, firm and market-clearing blocks into a single SSJ model.
ks_model = ssj.create_model([household, firm, mkt_clearing], name='Krusell-Smith')

# eis: elasticity of intertemporal substitution (the PyTorch model uses gamma = 1/eis);
# rho_e/sd_e: idiosyncratic productivity process; L: labor, fixed at 1; nE/nA: grid sizes;
# amin/amax: asset-grid bounds (amin is also the borrowing limit); Z: TFP; beta: discount factor.
calibration = {'eis': 0.5, 'delta': 0.025, 'alpha': 0.36, 'rho_e': 0.966, 'sd_e': 0.5, 'L': 1.0,
               'nE': 3, 'nA': 100, 'amin': 0, 'amax': 50, 'Z': 1.0, 'beta': 0.98}
# Solve for aggregate capital K (initial guess 30) such that household assets equal K.
unknowns_ss = {'K': 30.}
targets_ss = {'asset_mkt': 0.}

print("Solving for Steady State...")
ks_steady = ks_model.solve_steady_state(calibration, unknowns_ss, targets_ss, solver='hybr')
print(f"Steady State Capital (K): {ks_steady['K']:.4f}")

In [None]:
# Extract grids and transition matrices as NumPy arrays.
# They are converted to PyTorch tensors when KrusellSmithModel is constructed.
ss_a_grid = ks_steady.internals['household']['a_grid']
ss_e_grid = ks_steady.internals['household']['e_grid']
ss_Pi_e = ks_steady.internals['household']['Pi']  # Transition matrix for the idiosyncratic shock, Pi[e, e']
ss_dist = ks_steady.internals['household']['D'].T  # Stationary distribution, transposed so assets come first (nA, nE)
ss_K = ks_steady['K']

# Define TFP Process (Aggregate Shock)
rho_tfp = 0.9
sigma_tfp = 0.007
n_tfp = 5
# SSJ's Markov discretizers return the grid already in LEVELS (proportional to exp(log Z) and
# normalized to mean one). This is the same way markov_rouwenhorst's e_grid is used directly in
# make_grid. Exponentiating again would put Z around e ≈ 2.72 instead of around 1, so no np.exp here.
# NOTE(review): SSJ's `sigma` is the unconditional sd of log Z; if 0.007 is meant as the
# innovation sd, pass sigma_tfp / np.sqrt(1 - rho_tfp**2) instead.
tfp_grid, pi_tfp, Pi_tfp = ssj.grids.markov_tauchen(rho=rho_tfp, sigma=sigma_tfp, N=n_tfp)

## Part 2: PyTorch Implementation (DEQN)

Here we define the Neural Network and the Physics of the model (Transition Dynamics) using PyTorch.

In [None]:
class KS_Network(nn.Module):
    """Feed-forward policy network for the DEQN solver.

    Maps an (N, n_input) feature matrix, e.g. [idiosyncratic state, aggregate
    state], to (N, n_out) unbounded outputs. The caller squashes the output
    with a sigmoid to obtain a savings rate. The hidden part has three
    Mish-activated layers of width n_hidden; Mish gives smooth gradients.
    """

    def __init__(self, n_input, n_hidden=128, n_out=1):
        super().__init__()
        widths = [n_input, n_hidden, n_hidden, n_hidden]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Mish())
        layers.append(nn.Linear(n_hidden, n_out))
        self.net = nn.Sequential(*layers)

        # Xavier-uniform weights and zero biases for every affine layer.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        # Raw logit of the savings rate.
        return self.net(x)

In [None]:
class KrusellSmithModel:
    """Model primitives ("physics") of the Krusell-Smith economy in PyTorch.

    Holds the calibration, the asset/productivity grids and the two Markov
    chains (idiosyncratic Pi_e, aggregate Pi_tfp) as float32 tensors on
    `device`. Provides factor prices, marginal utility, evaluation of the
    network's savings policy on grid points, and the one-period transition of
    the cross-sectional histogram.
    """

    def __init__(self, calibration, steady_state_data, tfp_data, device):
        """
        Args:
            calibration: dict with at least 'beta', 'eis', 'alpha', 'delta', 'amin', 'amax'.
            steady_state_data: dict with 'a_grid' (nA,), 'e_grid' (nE,) and 'Pi_e' (nE, nE),
                where Pi_e[e, e'] is the probability of moving from e to e'.
            tfp_data: dict with 'grid' (nTFP,) TFP levels and 'Pi' (nTFP, nTFP) transition matrix.
            device: torch device on which all tensors are allocated.
        """
        self.device = device

        # Unpack parameters
        self.beta = calibration['beta']
        self.gamma = 1 / calibration['eis']  # CRRA coefficient
        self.alpha = calibration['alpha']
        self.delta = calibration['delta']
        self.amin = calibration['amin']
        self.amax = calibration['amax']

        # Convert grids to tensors
        self.a_grid = torch.tensor(steady_state_data['a_grid'], dtype=torch.float32, device=device)
        self.e_grid = torch.tensor(steady_state_data['e_grid'], dtype=torch.float32, device=device)
        self.Pi_e = torch.tensor(steady_state_data['Pi_e'], dtype=torch.float32, device=device)

        # TFP process
        self.tfp_grid = torch.tensor(tfp_data['grid'], dtype=torch.float32, device=device)
        self.Pi_tfp = torch.tensor(tfp_data['Pi'], dtype=torch.float32, device=device)

        # Dimensions
        self.nA = len(self.a_grid)
        self.nE = len(self.e_grid)
        self.nTFP = len(self.tfp_grid)

    def get_prices(self, K, Z):
        """Return (r, w) from Cobb-Douglas production with labor normalized to one.

        r = alpha * Z * K**(alpha - 1) - delta,  w = (1 - alpha) * Z * K**alpha.
        """
        r = self.alpha * Z * K.pow(self.alpha - 1) - self.delta
        w = (1 - self.alpha) * Z * K.pow(self.alpha)
        return r, w

    def utility_prime(self, c):
        """CRRA marginal utility (c + 1e-7)**(-gamma); the shift guards against c == 0."""
        return torch.pow(c + 1e-7, -self.gamma)

    def utility_prime_inv(self, u_prime):
        """Inverse of CRRA marginal utility: c = (u' + 1e-7)**(-1/gamma)."""
        return torch.pow(u_prime + 1e-7, -1.0 / self.gamma)

    def fisher_burmeister(self, a, b):
        """Fischer-Burmeister NCP function a + b - sqrt(a^2 + b^2).

        It is zero exactly when a >= 0, b >= 0 and a * b == 0, so it encodes
        Kuhn-Tucker complementarity conditions as one equation. The 1e-7 keeps
        the gradient finite at a = b = 0.
        """
        return a + b - torch.sqrt(a**2 + b**2 + 1e-7)

    def compute_savings(self, network, agg_state, individual_state_indices):
        """Evaluate the network's savings/consumption policy at asset-grid points.

        Args:
            network: callable mapping the (Batch, 2 + agg_state.shape[1]) input
                [normalized a, e, *agg_state] to (Batch, 1) savings-rate logits.
            agg_state: (Batch, >=2) aggregate state; column 0 is the TFP level Z
                and column 1 is aggregate capital K.
            individual_state_indices: tuple (idx_a, idx_e) of (Batch,) integer
                indices into a_grid and e_grid.

        Returns:
            (sav, con), each of shape (Batch,). Savings are
            amin + s * (coh - amin) with s = sigmoid(logit) in (0, 1), so a'
            lies strictly between amin and cash on hand.
        """
        idx_a, idx_e = individual_state_indices
        a_val = self.a_grid[idx_a]
        e_val = self.e_grid[idx_e]

        # Normalize assets to [0, 1] on the grid to help training.
        norm_a = (a_val - self.amin) / (self.amax - self.amin)
        nn_input = torch.cat([norm_a.unsqueeze(-1), e_val.unsqueeze(-1), agg_state], dim=1)

        # squeeze(-1) rather than squeeze(): keeps shape (Batch,) even when Batch == 1.
        savings_rate = torch.sigmoid(network(nn_input)).squeeze(-1)

        Z = agg_state[:, 0]
        K = agg_state[:, 1]
        r, w = self.get_prices(K, Z)
        coh = (1 + r) * a_val + w * e_val

        sav = self.amin + savings_rate * (coh - self.amin)
        con = coh - sav
        return sav, con

    def transport_distribution(self, sav_policy, dist_mass):
        """Push the histogram one period forward, (a, e) -> (a', e').

        Each cell's mass is split between the two asset nodes that bracket its
        savings choice, using linear weights that sum to one. The mass is then
        redistributed across e' with Pi_e. Weights are clamped to [0, 1]:
        savings outside [a_grid[0], a_grid[-1]] put all mass on the nearest
        endpoint instead of being extrapolated, which would create negative mass.

        Args:
            sav_policy: (Batch, nA, nE) savings choice a' at every grid cell.
            dist_mass: (Batch, nA, nE) current mass at every grid cell.

        Returns:
            (Batch, nA, nE) next-period distribution.
        """
        # Bracketing nodes a_grid[idx_lower] <= a' <= a_grid[idx_upper] (for a' inside the grid).
        idx_upper = torch.searchsorted(self.a_grid, sav_policy.contiguous())
        idx_upper = idx_upper.clamp(1, self.nA - 1)
        idx_lower = idx_upper - 1

        a_lower = self.a_grid[idx_lower]
        a_upper = self.a_grid[idx_upper]
        weight_upper = ((sav_policy - a_lower) / (a_upper - a_lower + 1e-8)).clamp(0.0, 1.0)
        weight_lower = 1.0 - weight_upper

        # Step 1: move mass along the asset axis, keeping today's e:
        # moved[b, j, e] accumulates mass[b, i, e] * weight[b, i, e] over cells i mapped to node j.
        mass_lower = dist_mass * weight_lower
        mass_upper = dist_mass * weight_upper
        moved = torch.zeros_like(mass_lower)
        moved.scatter_add_(1, idx_lower, mass_lower)
        moved.scatter_add_(1, idx_upper, mass_upper)

        # Step 2: idiosyncratic transition, next[b, a', e'] = sum_e moved[b, a', e] * Pi_e[e, e'].
        return moved @ self.Pi_e.to(moved.dtype)

## Part 3: The Training Loop (Cloud Simulation)

We simulate a "cloud" of $N$ parallel economies. 
1. **Simulate:** Given current states, calculate next states (Aggregate K and TFP).
2. **Loss:** Calculate Euler equation errors for random agents inside these economies.

In [None]:
def _aggregate_capital(dist, a_grid):
    """Aggregate capital of each economy in a cloud.

    dist: (Batch, nA, nE) histogram over (assets, productivity); a_grid: (nA,).
    Returns a (Batch,) tensor K = sum_{a, e} dist[:, a, e] * a_grid[a].
    """
    return (dist.sum(dim=2) * a_grid).sum(dim=1)


def _consumption_at(net, model_physics, a_vals, e_vals, Z, K):
    """Consumption implied by `net` at continuous assets and aggregate state (Z, K).

    All four state arguments are 1-D tensors of equal length. The network input
    is [normalized a, e, Z, K], the same layout compute_savings uses, but a_vals
    need not lie on the asset grid.

    Returns:
        (consumption, r): consumption and the interest rate at (K, Z).
    """
    norm_a = (a_vals - model_physics.amin) / (model_physics.amax - model_physics.amin)
    nn_input = torch.stack([norm_a, e_vals, Z, K], dim=1)
    sav_rate = torch.sigmoid(net(nn_input)).squeeze(-1)
    r, w = model_physics.get_prices(K, Z)
    coh = (1 + r) * a_vals + w * e_vals
    sav = model_physics.amin + sav_rate * (coh - model_physics.amin)
    return coh - sav, r


def train_model(model_physics, n_epochs=2000, batch_size=256, n_agents_sample=64, initial_dist=None):
    """Train a DEQN savings policy on a cloud of simulated Krusell-Smith economies.

    Each epoch:
      1. Compute each economy's aggregate state (Z, K).
      2. Evaluate the policy of randomly sampled agents.
      3. Push every distribution forward one period under the current policy
         (no gradient) to get K'.
      4. Form the Euler expectation over Z' and e'.
      5. Minimize the squared Fischer-Burmeister residual of the Euler
         equation and the borrowing constraint.
      6. Advance the cloud (distribution, sampled Z').

    Args:
        model_physics: KrusellSmithModel with grids, prices and transitions.
        n_epochs: number of gradient steps (= simulated periods).
        batch_size: number of parallel economies in the cloud.
        n_agents_sample: agents sampled per economy for the loss.
        initial_dist: optional (nA, nE) starting histogram; defaults to the
            notebook's steady-state distribution `ss_dist`.

    Returns:
        (net, loss_history, current_dist): the trained network, the per-epoch
        losses and the cloud's (batch_size, nA, nE) distribution after the last
        simulated period.
    """
    dev = model_physics.device
    nA, nE, nTFP = model_physics.nA, model_physics.nE, model_physics.nTFP
    a_grid = model_physics.a_grid
    amin = model_physics.amin

    # Features: a, e, Z, K. Aggregate K is the only distribution moment fed to the network.
    net = KS_Network(n_input=4, n_hidden=64).to(dev)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

    n_cloud = batch_size

    # Every economy starts from the same histogram, shape [Cloud, nA, nE].
    if initial_dist is None:
        initial_dist = ss_dist
    dist0 = torch.as_tensor(initial_dist, dtype=torch.float32, device=dev)
    current_dist = dist0.unsqueeze(0).expand(n_cloud, -1, -1).clone()

    # Initial TFP state of each economy, drawn uniformly.
    current_tfp_idx = torch.randint(0, nTFP, (n_cloud,), device=dev)

    # (a, e) indices of every grid cell of every economy, ordered (economy, a, e);
    # constant across epochs, so built once.
    full_a_idx = torch.arange(nA, device=dev).view(1, nA, 1).expand(n_cloud, nA, nE).reshape(-1)
    full_e_idx = torch.arange(nE, device=dev).view(1, 1, nE).expand(n_cloud, nA, nE).reshape(-1)

    loss_history = []

    print("Starting Training...")
    start_time = time.time()

    for epoch in range(n_epochs):
        optimizer.zero_grad()

        # --- 1. Aggregate state (Z, K) of every economy ---
        K_agg = _aggregate_capital(current_dist, a_grid)
        Z_curr = model_physics.tfp_grid[current_tfp_idx]
        agg_state = torch.stack([Z_curr, K_agg], dim=1)  # [Batch, 2]

        # --- 2. Current policy of randomly sampled agents ---
        idx_a = torch.randint(0, nA, (n_cloud, n_agents_sample), device=dev)
        idx_e = torch.randint(0, nE, (n_cloud, n_agents_sample), device=dev)
        flat_idx_a = idx_a.reshape(-1)
        flat_idx_e = idx_e.reshape(-1)
        flat_agg = agg_state.repeat_interleave(n_agents_sample, dim=0)  # [Batch*Agents, 2]

        sav, con = model_physics.compute_savings(net, flat_agg, (flat_idx_a, flat_idx_e))
        mu = model_physics.utility_prime(con)

        # --- 3. Move every distribution forward to obtain K' ---
        # K' is taken as given (no gradient), as in a fixed-point iteration.
        with torch.no_grad():
            flat_full_agg = agg_state.repeat_interleave(nA * nE, dim=0)
            sav_grid, _ = model_physics.compute_savings(net, flat_full_agg, (full_a_idx, full_e_idx))
            next_dist = model_physics.transport_distribution(sav_grid.view(n_cloud, nA, nE), current_dist)
            K_next = _aggregate_capital(next_dist, a_grid)

        # --- 4. Euler right-hand side: beta * E[(1 + r') u'(c')] over Z' and e' ---
        a_prime = sav.detach()  # next-period assets of the sampled agents
        K_next_exp = K_next.repeat_interleave(n_agents_sample)
        rhs_expectation = torch.zeros_like(mu)
        for z_next_idx in range(nTFP):
            prob_z = model_physics.Pi_tfp[current_tfp_idx, z_next_idx].repeat_interleave(n_agents_sample)
            Z_next_exp = model_physics.tfp_grid[z_next_idx].expand_as(K_next_exp)

            expected_given_z = torch.zeros_like(mu)
            for e_prime_idx in range(nE):
                prob_e = model_physics.Pi_e[flat_idx_e, e_prime_idx]  # P(e -> e') per agent
                e_prime_val = model_physics.e_grid[e_prime_idx].expand_as(a_prime)
                c_prime, r_next = _consumption_at(
                    net, model_physics, a_prime, e_prime_val, Z_next_exp, K_next_exp
                )
                mu_prime = model_physics.utility_prime(c_prime)
                expected_given_z = expected_given_z + prob_e * (1 + r_next) * mu_prime

            rhs_expectation = rhs_expectation + prob_z * expected_given_z

        rhs = model_physics.beta * rhs_expectation

        # --- 5. Loss: Kuhn-Tucker conditions of the savings problem ---
        # euler_resid = 1 - beta*E[...]/u'(c) is the normalized multiplier on a' >= amin.
        # It must be >= 0, and it may be > 0 only when the constraint binds (savings share = 0).
        # The Fischer-Burmeister function is zero exactly when both conditions hold, so
        # constrained agents are not penalized for an Euler gap they cannot close.
        euler_resid = 1.0 - rhs / mu
        coh = sav + con
        sav_share = (sav - amin) / (coh - amin + 1e-8)
        kkt_resid = model_physics.fisher_burmeister(euler_resid, sav_share)
        loss = torch.mean(kkt_resid ** 2)

        # --- 6. Optimization ---
        loss.backward()
        optimizer.step()
        scheduler.step()
        loss_history.append(loss.item())

        # --- 7. Advance the cloud one period ---
        current_dist = next_dist
        probs = model_physics.Pi_tfp[current_tfp_idx]  # [Batch, nTFP]
        # squeeze(1) keeps shape [Batch] even for a single economy.
        current_tfp_idx = torch.multinomial(probs, 1).squeeze(1)

        if epoch % 100 == 0:
            print(f"Epoch {epoch} | Loss: {loss_history[-1]:.6f} | K_agg: {K_agg.mean().item():.3f}")

    final_loss = loss_history[-1] if loss_history else float("nan")
    print(f"Training Complete. Final Loss: {final_loss:.6f} | Time: {time.time() - start_time:.1f}s")
    return net, loss_history, current_dist


In [None]:
# Setup Physics
# Bundle the NumPy steady-state grids and the TFP chain into the dicts KrusellSmithModel reads.
ks_ss_data = {'a_grid': ss_a_grid, 'e_grid': ss_e_grid, 'Pi_e': ss_Pi_e}
tfp_data = {'grid': tfp_grid, 'Pi': Pi_tfp}

model = KrusellSmithModel(calibration, ks_ss_data, tfp_data, device)

# Run Training
# Yields the trained network, the per-epoch loss history and the cloud's final distribution.
trained_net, history, final_dist = train_model(model, n_epochs=2000)

## Part 4: Analysis and Plotting

In [None]:
# Training-loss curve on a log10 scale (one point per epoch).
plt.figure(figsize=(10, 6))
plt.plot(np.log10(history))
plt.title("Training Loss (Log10)")
plt.xlabel("Epochs")
plt.ylabel("Log Loss")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Check the Policy Function
def plot_policy(net, model, k_level_idx=2, z_level_idx=2, K_level=None):
    """Plot the learned savings policy a'(a, e) at a fixed aggregate state.

    Args:
        net: trained network whose sigmoid output is the savings rate.
        model: KrusellSmithModel providing grids, prices and bounds.
        k_level_idx: unused; kept only so existing calls keep working
            (there is no aggregate-capital grid to index into).
        z_level_idx: index into model.tfp_grid selecting the TFP level.
        K_level: aggregate capital at which to evaluate the policy; defaults
            to the notebook's steady-state capital `ss_K`.
    """
    a_vals = model.a_grid
    e_vals = model.e_grid
    n_a = len(a_vals)

    # Fix the aggregate state. Z and K are both 0-d tensors, so they can be stacked
    # (stacking a 0-d Z with a shape-[1] K would raise a RuntimeError).
    Z = model.tfp_grid[z_level_idx]
    K_value = ss_K if K_level is None else K_level
    K = torch.tensor(float(K_value), dtype=a_vals.dtype, device=a_vals.device)
    agg_vec = torch.stack([Z, K]).view(1, 2).expand(n_a, 2)

    # Prices and the normalized asset feature do not depend on e: compute them once.
    r, w = model.get_prices(K, Z)
    norm_a = (a_vals - model.amin) / (model.amax - model.amin)

    plt.figure(figsize=(10, 6))

    with torch.no_grad():
        for e in e_vals:
            nn_in = torch.cat([norm_a.unsqueeze(1),
                               e.expand(n_a).unsqueeze(1),
                               agg_vec], dim=1)
            sav_rate = torch.sigmoid(net(nn_in)).squeeze(-1)

            # Budget: a' = amin + s * (cash on hand - amin)
            coh = (1 + r) * a_vals + w * e
            sav = model.amin + sav_rate * (coh - model.amin)

            plt.plot(a_vals.cpu(), sav.cpu(), label=f"Productivity e={e:.2f}")

    a_cpu = a_vals.cpu()
    plt.plot(a_cpu, a_cpu, 'k--', alpha=0.5, label="45 degree")
    plt.xlabel("Assets (a)")
    plt.ylabel("Savings (a')")
    plt.title(f"Savings Policy Function (at K={float(K_value):.2f}, Z={Z:.2f})")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Plot the trained savings policy at the default aggregate state (TFP index 2, steady-state K).
plot_policy(trained_net, model)