In [1]:
import numpy as np
import pandas as pd

In [2]:
import seaborn as sns
import matplotlib.pylab as plt

In [3]:
import torch
import torch.nn.functional as F
import pyro

In [4]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [5]:
from monteloanco import model, guide, LearnedTransition, Template

### A deep state-space model for a consumer credit risk portfolio

This notebook outlines the development of a deep state-space model for consumer credit risk, built using [pyro.ai](https://pyro.ai/). At its core, the model employs Monte Carlo simulations for each loan, progressing through monthly timesteps. The hidden state at each step represents the loan’s status, with all accounts initially starting as current. From there, loans may transition to early payoff, arrears, or more commonly, remain current and advance to the next month.

The model requires 5 inputs: 
- `loan_amnt` the initial advance to the customer.
- `int_rate` the annual interest rate (as a percentage).
- `installment` the monthly payment according to the initial schedule.
- `total_pre_chargeoff` the total value of payments made against the account excluding recoveries.
- `num_timesteps` the number of months observed to date if training, or the desired length of the simulation.

The output used for validation is a simulation of hidden states (loan statuses) and payments, plus how those payments are attributed to principal and interest. Behind the scenes, the model also trains an embedding based on the loan account identifier, which effectively captures the performance characteristics of each specific loan. This embedding may serve several purposes, including:
- Simulating the performance of the existing portfolio.
- Extending the installment schedule to maturity to estimate the portfolio’s value if allowed to run off.
- Providing a low-dimensional representation of loan performance, enabling broader analysis beyond traditional good/bad account classifications for training applicant-level models.
- Reducing to a single risk dimension that represents the probability of default over any given time horizon.

For speed, we use only a subset of the 2+ million accounts available.

In [6]:
df_train = pd.read_json(
    'training.jsonl.gz',  # gzip compression is inferred from the file suffix
    lines=True,           # JSON Lines: one record per line
)
# Sanity check: the frame must carry the default 0..n-1 RangeIndex.
pd.testing.assert_index_equal(df_train.index, pd.RangeIndex(start=0, stop=len(df_train)))

The model has been designed so that it can train/simulate a large number of accounts in parallel on a GPU. If your machine doesn't have a suitable GPU, set `device` in the next cell to `cpu` instead of `cuda:0`.

The data is fed into the model in batches of `batch_size` accounts using a standard PyTorch `DataLoader`. Each account's observed length (`n_report_d`) is passed to the model as a vector via `num_timesteps`.

In [7]:
# Size of the per-loan embedding.
# NOTE(review): not passed to the training call in this notebook — TODO confirm where it is consumed.
embedding_size = 3
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without a suitable GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [8]:
batch_size = 100_000
# Columns consumed by the training loop; `n_report_d` is passed to the model as `num_timesteps`.
model_columns = ['id', 'loan_amnt', 'int_rate', 'installment', 'n_report_d', 'total_pre_chargeoff', 'last_pymnt_amnt']
# A list of per-loan dicts; the DataLoader's default collate turns a batch of these
# into a dict of tensors keyed by column name.
dataset = df_train[model_columns].to_dict(orient='records')
# Show one example record.
dataset[42]

{'id': 79661304,
 'loan_amnt': 10000,
 'int_rate': 7.39,
 'installment': 310.56,
 'n_report_d': 33,
 'total_pre_chargeoff': 10008.44,
 'last_pymnt_amnt': 10012.55}

### Train the model

With the batches defined, it's time to run the optimisation process and tune the parameters. The loss here is the difference between the total value of payments made on each account and those from the MC simulation.

In [9]:
%%time

# Clear the param store in case we're in a REPL
pyro.clear_param_store()

# Create partial functions with their respective parameters
provider = LearnedTransition(
    name="tmat_logits",
    init_logits=Template.DEMO_LOGITS,
    trainable=True,
    offset_scale=1.0,
    total_size=100_000
)

# Set up the optimizer and inference algorithm
optimizer = pyro.optim.Adam({"lr": 0.001})
svi = pyro.infer.SVI(model=model, guide=guide,
                     optim=optimizer, loss=pyro.infer.Trace_ELBO())

# Run inference
num_iterations = 50_000

with tqdm(total=num_iterations, desc="Epochs", position=0) as epoch_pbar:
    for step in range(num_iterations):
        losses = []
        for batch_id, batch in enumerate(DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)):
            loss = svi.step(
                batch_id=batch_id,
                batch_idx=torch.arange(len(batch['id'])).to(device),
                installments=batch['installment'].to(device), 
                loan_amnt=batch['loan_amnt'].to(device), 
                int_rate=batch['int_rate'].to(device),
                tmat_provider=provider,
                total_pre_chargeoff=batch['total_pre_chargeoff'].to(device),
                last_pymnt_amnt=batch['last_pymnt_amnt'].to(device),
                num_timesteps=batch['n_report_d'].to(device)  # Changed: pass the entire vector!
            )
            losses.append(loss)
            
        if step % np.ceil(num_iterations/100) == 0:
            print(f"Step {step} : Loss = {np.sum(losses)}")
        epoch_pbar.update(1)

Epochs:   0%|          | 0/50000 [00:00<?, ?it/s]

Step 0 : Loss = 597236444.9172
Step 500 : Loss = 464187869.89796084
Step 1000 : Loss = 427266235.5993371
Step 1500 : Loss = 392869875.03837883
Step 2000 : Loss = 354731807.48382545
Step 2500 : Loss = 322873076.83136064
Step 3000 : Loss = 292455793.3325752
Step 3500 : Loss = 270402432.66654205
Step 4000 : Loss = 251354314.38665706
Step 4500 : Loss = 234693202.05227005
Step 5000 : Loss = 221575866.61705273
Step 5500 : Loss = 208854819.28905648
Step 6000 : Loss = 200501124.80334136
Step 6500 : Loss = 189599764.52852094
Step 7000 : Loss = 183900422.06435233
Step 7500 : Loss = 176093928.95018524
Step 8000 : Loss = 171307071.50657308
Step 8500 : Loss = 163472202.97959086
Step 9000 : Loss = 157281545.45482084
Step 9500 : Loss = 155357648.2534263
Step 10000 : Loss = 149801962.13883007
Step 10500 : Loss = 148391317.95101106
Step 11000 : Loss = 148100569.6280383
Step 11500 : Loss = 145942569.53204912
Step 12000 : Loss = 143553790.92743838
Step 12500 : Loss = 141701444.11966205
Step 13000 : Loss 

### Save the model

Save model parameters to a file for inference in another notebook.

In [10]:
# Write the contents of Pyro's global param store to disk so another notebook can load them for inference.
pyro.get_param_store().save(filename='param_store.pt')