In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from notebook_helpers import RandomChunkDataset
import pandas as pd
import math 

# Prefer the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# VI (Categorical) + HMM ELBO — Structured by Components

## Encoder (produces logits and posterior)

At each time step $t$, the encoder outputs logits $\ell_t \in \mathbb{R}^K$ and a categorical posterior:

$$
q_\phi(z_t = k \mid x_{1:T}) = \mathrm{softmax}(\ell_t)_k
\quad\Rightarrow\quad
q_{t,k} = \frac{e^{\ell_{t,k}}}{\sum_{j=1}^K e^{\ell_{t,j}}}.
$$

Collecting over time (mean-field form):

$$
q_\phi(z_{1:T} \mid x_{1:T}) = \prod_{t=1}^T q_\phi(z_t \mid x_{1:T}), 
\qquad
q \in \mathbb{R}^{T \times K}.
$$

---

## Prior (HMM: initial distribution and transitions)

Initial state distribution:

$$
p_\theta(z_1 = k) = \pi_k, 
\qquad 
\sum_{k=1}^K \pi_k = 1.
$$

Transitions (stationary or input-conditioned):

$$
p_\theta(z_t = j \mid z_{t-1} = i,\, u_{1:T}) = A_t[i,j],
\qquad
\sum_{j=1}^K A_t[i,j] = 1 \ \text{for each row } i.
$$

Mean-field expected log prior under $q$:

$$
\mathbb{E}_q[\log p_\theta(z_{1:T} \mid u_{1:T})]
=
\sum_{k=1}^K q_{1,k}\, \log \pi_k
\;+\;
\sum_{t=2}^T \sum_{i=1}^K \sum_{j=1}^K q_{t-1,i}\, q_{t,j}\, \log A_t[i,j].
$$

Parameterization via logits (normalization):

$$
\pi = \mathrm{softmax}(\alpha), \quad \alpha \in \mathbb{R}^K,
\qquad
A_t[i,\cdot] = \mathrm{softmax}\big(M_t[i,\cdot]\big).
$$

Stationary:
$$
M_t \equiv M \in \mathbb{R}^{K\times K}.
$$

Input-conditioned: 
$$
M_t = g_\theta^{\text{trans}}(u_t) \in \mathbb{R}^{K\times K}.
$$

---

## Decoder (Gaussian emissions via state embedding)

State embedding matrix $E \in \mathbb{R}^{K \times D}$ and per-time embedding:

$$
e_t = q_t^\top E \in \mathbb{R}^D,
\qquad
e \in \mathbb{R}^{T \times D}.
$$

Emission parameters (diagonal Gaussian):

$$
(\mu_t,\, \log \sigma_t^2) = g_\theta(e_t),
\qquad
\mu_t,\, \sigma_t \in \mathbb{R}^d.
$$

Per-time log likelihood (diagonal Gaussian):

$$
\log p_\theta(x_t \mid z_t) \approx \log \mathcal{N}\!\big(x_t;\, \mu_t,\, \mathrm{diag}(\sigma_t^2)\big)
= -\tfrac{1}{2}\!\left[
d\log(2\pi)
+ \sum_{j=1}^d \log \sigma^2_{t,j}
+ \sum_{j=1}^d \frac{(x_{t,j}-\mu_{t,j})^2}{\sigma^2_{t,j}}
\right].
$$

Expected reconstruction term under $q$:

$$
\mathbb{E}_q[\log p_\theta(x_{1:T} \mid z_{1:T})]
=
\sum_{t=1}^T \sum_{k=1}^K q_{t,k}\, \log p_\theta(x_t \mid z_t = k)
\;\;\approx\;\;
\sum_{t=1}^T \log \mathcal{N}\!\big(x_t;\, \mu_t,\, \mathrm{diag}(\sigma_t^2)\big).
$$

(The approximation uses $e_t = q_t^\top E$ to condition $\mu_t,\sigma_t$ directly on $q_t$.)

---

## Main Model (ELBO, entropy, and loss)

Entropy (sum over time and states):

$$
-\mathbb{E}_q[\log q_\phi(z_{1:T} \mid x_{1:T})]
=
\sum_{t=1}^T \sum_{k=1}^K \big(- q_{t,k}\, \log q_{t,k}\big).
$$

Full ELBO:

$$
\mathcal{L}(\theta,\phi)
=
\underbrace{\mathbb{E}_q[\log p_\theta(x_{1:T} \mid z_{1:T})]}_{\text{reconstruction}}
+
\underbrace{\mathbb{E}_q[\log p_\theta(z_{1:T} \mid u_{1:T})]}_{\text{HMM prior}}
+
\underbrace{\big(-\mathbb{E}_q[\log q_\phi(z_{1:T} \mid x_{1:T})]\big)}_{\text{entropy}}.
$$

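Optional $\beta$-weighting (warm-up), where $\mathrm{Entropy} = -\mathbb{E}_q[\log q_\phi(z_{1:T} \mid x_{1:T})] \ge 0$ enters with a plus sign:

$$
\mathcal{L}_\beta(\theta,\phi)
=
\mathrm{Recon}
+
\beta \,(\mathrm{Prior} + \mathrm{Entropy}),
\qquad
\beta \in [0,1].
$$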

Training objective (minimize negative ELBO):

$$
\mathcal{J}(\theta,\phi) = -\,\mathcal{L}_\beta(\theta,\phi).
$$

---

## Notes on symbols

- $K$: number of discrete hidden states (regimes).
- $T$: number of time steps.
- $d$: data dimension per time step.
- $q_{t,k}$: variational posterior probability of state $k$ at time $t$.
- $\pi$: initial state distribution (with $\sum_{k=1}^K \pi_k = 1$).
- $A_t$: transition probabilities at time $t$ (each row sums to $1$).
- $E \in \mathbb{R}^{K\times D}$: state embedding matrix; $D$ is embedding size.
- $g_\theta$: decoder network producing $(\mu_t,\log\sigma_t^2)$ from $e_t$.
- $u_t$: optional exogenous inputs for input-conditioned transitions.

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that turns a sequence into per-timestep state logits.

    Two kernel-size-3 convolutions (padding 1, so the time axis keeps its
    length), each followed by a ReLU, feed a 1x1 convolution that emits K
    unnormalized scores for every time step.

    Args:
        input_dim: number of channels of the input sequence.
        hidden_dim: output channels of the first convolution.
        hidden_dim2: output channels of the second convolution.
        K: number of discrete states (output channels of the logits head).
    """

    def __init__(self, input_dim, hidden_dim, hidden_dim2, K):
        super().__init__()
        # Input layout: (batch_size, input_dim, T)
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim,
                               kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim2,
                               kernel_size=3, padding=1)
        # Pointwise head; output layout: (batch_size, K, T)
        self.to_logits = nn.Conv1d(in_channels=hidden_dim2, out_channels=K,
                                   kernel_size=1)

    def forward(self, x):
        """Return unnormalized state logits of shape (batch_size, K, T)."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        return self.to_logits(hidden)

class Prior(nn.Module):
    """HMM prior: learnable initial-state logits plus input-conditioned transitions.

    Args:
        K: number of discrete states.
        u_dim: feature size of the exogenous input u_t. Must be given; the
            stationary (input-free) variant is not implemented.
        trans_hidden: hidden width of the MLP that maps u_t to K*K logits.

    Raises:
        ValueError: if ``u_dim`` is None.
    """

    def __init__(self, K, u_dim=None, trans_hidden=128):
        super().__init__()
        self.k_clusters = K
        self.u_dim = u_dim
        # Unnormalized initial-state scores; forward() turns them into log pi
        # with a log-softmax. Zeros give a uniform initial distribution.
        self.log_prior = nn.Parameter(torch.zeros(K), requires_grad=True)

        # Only the input-conditioned variant exists, so bail out early.
        if u_dim is None:
            raise ValueError('Not supporting stationary transitions')

        # Small MLP: u_t -> K*K transition logits.
        mlp_layers = [
            nn.Linear(u_dim, trans_hidden),
            nn.ReLU(),
            nn.Linear(trans_hidden, K * K),
        ]
        self.transition_net = nn.Sequential(*mlp_layers)

    def forward(self, u=None):
        """Compute log pi and the per-timestep log transition matrices.

        Args:
            u: 3-D tensor of inputs. If its second axis equals ``u_dim`` it is
                treated as (B, u_dim, T) and transposed; otherwise it is used
                as (B, T, U) unchanged.

        Returns:
            Tuple ``(log_pi, log_A)`` with ``log_pi`` of shape (K,) and
            ``log_A`` of shape (B, T, K, K); each row of ``log_A`` is
            log-normalized over its last axis.

        Raises:
            ValueError: if ``u`` is None (or the module has no ``u_dim``).
        """
        log_pi = F.log_softmax(self.log_prior, dim=-1)

        if u is None or self.u_dim is None:
            raise ValueError('Not supporting stationary transitions')

        channels_first = u.dim() == 3 and u.shape[1] == self.u_dim
        if channels_first:
            # (B, U_dim, T) -> (B, T, U_dim)
            u = u.permute(0, 2, 1)

        batch, steps, feats = u.shape
        flat_scores = self.transition_net(u.reshape(batch * steps, feats))  # (B*T, K*K)
        scores = flat_scores.view(batch, steps, self.k_clusters, self.k_clusters)
        # Normalize each row (the "from" state) over the destination states.
        return log_pi, F.log_softmax(scores, dim=-1)

class Decoder(nn.Module):
    """Gaussian emission decoder driven by a soft state embedding.

    The posterior probabilities are mixed with a learned per-state embedding
    table (e_t = q_t^T E); two kernel-size-3 convolutions with ReLU and a 1x1
    head then produce the mean and log-variance of a diagonal Gaussian.

    Args:
        K: number of discrete states (rows of the embedding table).
        latent_dim: embedding size D.
        hidden_dim: channels of the two hidden convolutions.
        output_dim: data channels; the head emits ``2 * output_dim`` channels.
    """

    def __init__(self, K, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.K = K
        self.latent_dim = latent_dim
        # Only the weight matrix (K, latent_dim) of this table is used in forward().
        self.E = nn.Embedding(K, latent_dim)
        self.conv1 = nn.Conv1d(in_channels=latent_dim, out_channels=hidden_dim,
                               kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim,
                               kernel_size=3, padding=1)
        self.to_output = nn.Conv1d(in_channels=hidden_dim, out_channels=output_dim * 2,
                                   kernel_size=1)

    def forward(self, q):
        """Decode state probabilities into Gaussian emission parameters.

        Args:
            q: tensor of shape (B, K, T) holding per-timestep state
                probabilities (a 3-D input is required).

        Returns:
            Tuple ``(mu, logvar)``: the first and second half of the head's
            output channels, each of shape (B, output_dim, T).
        """
        # Unpacking enforces a 3-D input, exactly as a (B, K, T) layout requires.
        _batch, _states, _steps = q.shape

        # (B, T, K) @ (K, D) -> (B, T, D), then back to channels-first (B, D, T).
        soft_embedding = torch.matmul(q.permute(0, 2, 1), self.E.weight).permute(0, 2, 1)

        hidden = F.relu(self.conv1(soft_embedding))
        hidden = F.relu(self.conv2(hidden))
        params = self.to_output(hidden)

        half = params.shape[1] // 2
        return params[:, :half, :], params[:, half:, :]
    

class VAE_HMM(nn.Module):
    """Categorical VAE with an input-conditioned HMM prior over discrete states.

    Combines an ``Encoder`` (per-timestep state logits), a ``Prior`` (initial
    distribution and transition log-probabilities) and a ``Decoder``
    (diagonal-Gaussian emission parameters) and exposes the negative
    beta-weighted ELBO as a training loss.

    Args:
        input_dim: number of data channels.
        hidden_dim: encoder first-layer width; also used as the decoder's
            embedding size and hidden width.
        K: number of discrete states.
        hidden_dim2: encoder second-layer width.
        u_dim: feature size of the exogenous inputs (forwarded to ``Prior``).
        trans_hidden: hidden width of the transition MLP.
    """

    def __init__(self, input_dim, hidden_dim, K, hidden_dim2, u_dim=None, trans_hidden=128):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, hidden_dim2, K)
        self.prior = Prior(K, u_dim=u_dim, trans_hidden=trans_hidden)
        self.decoder = Decoder(K, latent_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=input_dim)
        self.K = K

    def encode(self, x):
        """Return the encoder's state logits for ``x``."""
        return self.encoder(x)

    def decode(self, q):
        """Return the decoder's ``(mu, logvar)`` for state probabilities ``q``."""
        return self.decoder(q)

    def compute_loss(self, x, u=None, lengths=None, beta=1.0):
        """Negative beta-weighted ELBO for a zero-padded batch.

        loss = recon_nll + beta * (-E_q[log p(z | u)] - H[q])

        Args:
            x: data tensor of shape (B, C, T).
            u: exogenous inputs passed straight to the prior.
            lengths: 1-D tensor with the number of valid timesteps per
                sequence; positions at or beyond the length are masked out.
            beta: weight on the prior and entropy terms.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``lengths`` is None.

        Note:
            The reconstruction term is a mean over valid (timestep, channel)
            entries, whereas the prior and entropy terms are sums over time
            averaged over the batch, so the terms live on different scales.
            NOTE(review): confirm this normalization is intended.
        """
        B, C, T = x.shape
        if lengths is None:
            raise ValueError('lengths must be provided')
        # (B, T) boolean mask of valid timesteps.
        mask = torch.arange(T, device=x.device)[None, :] < lengths[:, None].to(x.device)

        log_pi, log_A = self.prior(u)
        logits = self.encoder(x)            # (B, K, T)
        q_probs = F.softmax(logits, dim=1)  # (B, K, T)
        mu, logvar = self.decode(q_probs)

        # --- Reconstruction: diagonal-Gaussian negative log-likelihood ---
        # log(2*pi*sigma^2) is expanded analytically as log(2*pi) + logvar.
        # Taking torch.log of exp(logvar) overflows to +inf once exp(logvar)
        # leaves the float range, which made the whole loss infinite.
        recon_mask = mask.to(dtype=x.dtype).unsqueeze(1)  # (B, 1, T)
        squared = (mu - x) ** 2
        neg_log_likelihood = 0.5 * (math.log(2.0 * math.pi) + logvar + squared / logvar.exp())
        # Float count of valid entries; clamped so an all-padding batch does not divide by zero.
        num_valid = (recon_mask.sum() * C).clamp(min=1.0)
        recon_loss = (neg_log_likelihood * recon_mask).sum() / num_valid

        # --- HMM prior: E_q[log pi] + sum_t E_q[log A_t] under mean-field q ---
        q_first = q_probs[:, :, 0]
        initial_term = (q_first * log_pi.unsqueeze(0)).sum(dim=1)

        # q_{t-1} as a column (B, T-1, K, 1) and q_t as a row (B, T-1, 1, K);
        # their product against log A_t (B, T-1, K, K) gives the pairwise expectation.
        q_prev = q_probs[:, :, :-1].permute(0, 2, 1).unsqueeze(-1)
        q_next = q_probs[:, :, 1:].permute(0, 2, 1).unsqueeze(-2)
        log_A_steps = log_A[:, 1:, :, :]

        trans_per_step = (q_prev * q_next * log_A_steps).sum(dim=(2, 3))  # (B, T-1)
        # A transition counts only if both of its endpoints are valid timesteps.
        trans_mask = (mask[:, 1:] & mask[:, :-1]).float()
        trans_term = (trans_per_step * trans_mask).sum(dim=1)

        prior_loss = -(initial_term + trans_term).mean()

        # --- Entropy of q: sum over valid timesteps, averaged over the batch ---
        log_q = F.log_softmax(logits, dim=1)  # (B, K, T)
        per_step_entropy = -(q_probs * log_q).sum(dim=1)  # (B, T)
        entropy = (per_step_entropy * mask.float()).sum() / B

        return recon_loss + beta * (prior_loss - entropy)

    def forward(self, x):
        """Encode ``x`` and decode the resulting soft state assignment.

        Returns:
            Tuple ``(recon_x, q)`` where ``recon_x`` is the decoder's
            ``(mu, logvar)`` pair and ``q`` holds the state probabilities
            (softmax of the logits over axis 1).
        """
        logits = self.encode(x)
        q = F.softmax(logits, dim=1)
        recon_x = self.decode(q)
        return recon_x, q

In [45]:
def train_model(model, dataloader, num_epochs=10, lr=1e-3):
    """Train ``model`` with Adam on its own ``compute_loss`` objective.

    The prior/entropy weight ``beta`` is warmed up linearly and reaches 1.0
    halfway through training: ``beta = min(1, 2 * (epoch + 1) / num_epochs)``.

    Args:
        model: module exposing ``compute_loss(x, u, lengths, beta)``.
        dataloader: iterable yielding ``(x, u, lengths)`` batches.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        The same ``model`` instance, updated in place.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    # Batches are aligned to the device of the model's parameters so that a
    # loader producing tensors elsewhere does not cause a device mismatch.
    model_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        # increase beta for first half of epochs gradually to 1
        beta = min(1.0, 2.0 * (epoch + 1) / num_epochs)
        for x, u, lengths in dataloader:
            x = x.to(model_device)
            if torch.is_tensor(u):
                u = u.to(model_device)

            optimizer.zero_grad()
            loss = model.compute_loss(x, u, lengths, beta)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Counting batches (instead of len(dataloader)) also supports iterables
        # without __len__ and turns an empty loader into a clear error.
        if num_batches == 0:
            raise ValueError('dataloader yielded no batches')
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return model

In [46]:
def collate_fn(batch):
    """Zero-pad variable-length samples into dense batch tensors.

    Args:
        batch: list of samples where ``sample[0]`` is an x sequence,
            ``sample[1]`` the matching u sequence (both channels-first, time
            on the last axis) and ``sample[2]`` the sequence length.

    Returns:
        Tuple ``(x_batch, u_batch, lengths)``: ``x_batch`` is (B, C, max_len)
        and ``u_batch`` is (B, U, max_len), both allocated on the module-level
        ``device`` and zero beyond each sample's length; ``lengths`` is a
        long tensor of shape (B,).
    """
    lengths = torch.tensor([int(sample[2]) for sample in batch], dtype=torch.long)
    longest = lengths.max().item()

    # Channel counts are taken from the first sample.
    n_samples = len(batch)
    x_channels = batch[0][0].shape[0]
    u_channels = batch[0][1].shape[0]

    x_batch = torch.zeros(n_samples, x_channels, longest, device=device)
    u_batch = torch.zeros(n_samples, u_channels, longest, device=device)

    # Copy each sequence into the leading slots of its row; the tail stays zero.
    for row, (sample, n_steps) in enumerate(zip(batch, lengths.tolist())):
        x_batch[row, :, :n_steps] = sample[0]
        u_batch[row, :, :n_steps] = sample[1]

    return x_batch, u_batch, lengths

In [47]:
# Read the scaled training table once; both feature views derive from it.
train_df = pd.read_csv('train_dataset_scaled.csv')
# x: every column except the date; u: the same minus 'historical_vol'.
x_data = train_df.drop(columns=['date']).values
u_data = train_df.drop(columns=['date', 'historical_vol']).values

# One long sequence per list, stored channels-first as (features, time).
x_sequences = [torch.tensor(x_data, dtype=torch.float).permute(1, 0)]
u_sequences = [torch.tensor(u_data, dtype=torch.float).permute(1, 0)]

# NOTE(review): RandomChunkDataset comes from notebook_helpers; presumably it
# samples chunks of between min_len and max_len timesteps — confirm there.
dataset = RandomChunkDataset(x_sequences, u_sequences, min_len=20, max_len=200)

In [67]:
# create dataloader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn, drop_last=False)
# Channel counts are read off the channels-first sequences: (features, time).
u_dim = u_sequences[0].shape[0]
C = x_sequences[0].shape[0]
num_clusters = 3
hidden0 = 64
hidden1 = 32

model = VAE_HMM(input_dim=C, hidden_dim=hidden0, K=num_clusters, hidden_dim2=hidden1, u_dim=u_dim, trans_hidden=64)
# collate_fn allocates batches on `device`, so the model has to live there too;
# otherwise training fails with a device mismatch whenever CUDA is available.
model = model.to(device)
trained = train_model(model, dataloader, num_epochs=150, lr=1e-5)

Epoch 1/150, Loss: 1.4494
Epoch 2/150, Loss: 1.4697
Epoch 2/150, Loss: 1.4697
Epoch 3/150, Loss: 1.4930
Epoch 3/150, Loss: 1.4930
Epoch 4/150, Loss: 1.5040
Epoch 4/150, Loss: 1.5040
Epoch 5/150, Loss: 1.4767
Epoch 5/150, Loss: 1.4767
Epoch 6/150, Loss: 1.4756
Epoch 6/150, Loss: 1.4756
Epoch 7/150, Loss: 1.4798
Epoch 7/150, Loss: 1.4798
Epoch 8/150, Loss: 1.4467
Epoch 8/150, Loss: 1.4467
Epoch 9/150, Loss: 1.4524
Epoch 9/150, Loss: 1.4524
Epoch 10/150, Loss: 1.4412
Epoch 10/150, Loss: 1.4412
Epoch 11/150, Loss: 1.3972
Epoch 11/150, Loss: 1.3972
Epoch 12/150, Loss: 1.3724
Epoch 12/150, Loss: 1.3724
Epoch 13/150, Loss: 1.3520
Epoch 13/150, Loss: 1.3520
Epoch 14/150, Loss: 1.3129
Epoch 14/150, Loss: 1.3129
Epoch 15/150, Loss: 1.2590
Epoch 15/150, Loss: 1.2590
Epoch 16/150, Loss: 1.2098
Epoch 16/150, Loss: 1.2098
Epoch 17/150, Loss: 1.1527
Epoch 17/150, Loss: 1.1527
Epoch 18/150, Loss: 1.0776
Epoch 18/150, Loss: 1.0776
Epoch 19/150, Loss: 1.0164
Epoch 19/150, Loss: 1.0164
Epoch 20/150, Loss

In [68]:
# Persist only the trained encoder, bundled with the constructor arguments
# needed to rebuild the same architecture at load time.
encoder_path = 'encoder_saved.pth'
encoder_config = {
    'input_dim': C,
    'hidden_dim': hidden0,
    'hidden_dim2': hidden1,
    'K': num_clusters,
}
encoder_checkpoint = {
    'model_state_dict': trained.encoder.state_dict(),
    'config': encoder_config,
}
torch.save(encoder_checkpoint, encoder_path)

In [None]:
import yfinance as yf
import numpy as np

end_day = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Read the test table once; keep only the first `end_day` rows of each view.
test_df = pd.read_csv('test_data.csv')
x_data = test_df.drop(columns=['Date']).values[:end_day]
u_data = test_df.drop(columns=['Date', 'historical_vol']).values[:end_day]

ckpt = torch.load('encoder_saved.pth', map_location='cpu')
enc_cfg = ckpt.get('config', {})
enc_state = ckpt['model_state_dict']
# Fallbacks are read from the saved weight shapes (Conv1d weights are
# (out_channels, in_channels, kernel)), so this cell does not depend on
# variables left over from the training cells and always matches the checkpoint.
input_dim = enc_cfg.get('input_dim', enc_state['conv1.weight'].shape[1])
hidden_dim = enc_cfg.get('hidden_dim', enc_state['conv1.weight'].shape[0])
hidden_dim2 = enc_cfg.get('hidden_dim2', enc_state['conv2.weight'].shape[0])
K = enc_cfg.get('K', enc_state['to_logits.weight'].shape[0])

# Recreate encoder architecture
encoder_loaded = Encoder(input_dim, hidden_dim, hidden_dim2, K)
encoder_loaded.load_state_dict(enc_state)
encoder_loaded.to(device)
encoder_loaded.eval()

x_tensor = torch.tensor(x_data, dtype=torch.float).permute(1, 0).unsqueeze(0)  # (1, C, T)
x_tensor = x_tensor.to(device)

with torch.no_grad():
    logits = encoder_loaded(x_tensor)  # (1, K, T)
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()  # (K, T)
    regimes = np.argmax(probs, axis=0)  # (T,) most probable state per timestep

print('probs.shape (K, T):', probs.shape)
print('regimes.shape:', regimes.shape)

# align with SP500 history (example)
sp500 = yf.Ticker("^GSPC")
sp500_data = sp500.history(start="2020-01-01", end="2024-12-31")
prices = sp500_data['Close'].values[:end_day]

# Trim prices and their index to the number of decoded regimes.
prices_aligned = prices[:len(regimes)]
index_aligned = sp500_data.index[:len(regimes)]

probs.shape (K, T): (5, 100)
regimes.shape: (100,)
