In [19]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from torch.func import vmap, jacrev
from tqdm import tqdm
import os
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
import math
from pydmd import DMD
from sklearn.preprocessing import MinMaxScaler
import warnings

In [25]:
class MLP(nn.Module):
    """Plain fully connected network.

    The stack consists of ``Linear -> activation`` blocks (the first maps
    ``in_dim`` to ``hidden_dim``, every following one keeps ``hidden_dim``)
    and ends with a ``Linear`` projection to ``out_dim`` without activation.

    Args:
        in_dim: Size of the last input dimension.
        out_dim: Size of the last output dimension.
        hidden_dim: Width of every hidden layer.
        n_layers: Number of hidden ``Linear -> activation`` blocks; values
            below 1 still produce a single hidden block.
        activation: Zero-argument factory (e.g. an ``nn.Module`` class) that is
            instantiated once per hidden block.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=128, n_layers=2, activation=nn.ReLU):
        super().__init__()
        # One mandatory hidden block plus (n_layers - 1) optional extra ones.
        n_hidden_blocks = 1 + max(n_layers - 1, 0)
        widths = [in_dim] + [hidden_dim] * n_hidden_blocks

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(activation())
        modules.append(nn.Linear(hidden_dim, out_dim))

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)

class Encoder(nn.Module):
    """MLP that maps an ``input_dim``-wide observation to a ``latent_dim``-wide code.

    Args:
        input_dim: Size of the last dimension of the input.
        latent_dim: Size of the last dimension of the output.
        hidden_dim: Hidden width forwarded to ``MLP``.
        n_layers: Hidden block count forwarded to ``MLP``.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=128, n_layers=2):
        super().__init__()
        self.net = MLP(
            in_dim=input_dim,
            out_dim=latent_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
        )

    def forward(self, x):
        """Return the latent code produced by the wrapped MLP."""
        latent = self.net(x)
        return latent

class Decoder(nn.Module):
    """MLP that maps a ``latent_dim``-wide code back to ``output_dim`` features.

    Args:
        latent_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Hidden width forwarded to ``MLP``.
        n_layers: Hidden block count forwarded to ``MLP``.
    """

    def __init__(self, latent_dim, output_dim, hidden_dim=128, n_layers=2):
        super().__init__()
        self.net = MLP(
            in_dim=latent_dim,
            out_dim=output_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
        )

    def forward(self, y):
        """Return the reconstruction produced by the wrapped MLP."""
        decoded = self.net(y)
        return decoded
    
class ConvEncoder(nn.Module):
    """1-D convolutional feature extractor over the time axis.

    A first convolution with ``kernel_size`` taps (no padding) is followed by
    two pointwise (kernel size 1) convolutions; the output is squashed with
    ``tanh``.

    Args:
        num_visible: Number of input channels (features per time step).
        num_hidden: Number of output channels (features per time step).
        hidden_size: Channel count of the two intermediate feature maps.
        kernel_size: Temporal extent of the first convolution.
    """

    def __init__(self, num_visible, num_hidden, hidden_size=32, kernel_size=31):
        super().__init__()
        self.num_visible = num_visible
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv1d(num_visible, hidden_size, kernel_size)
        self.conv2 = nn.Conv1d(hidden_size, hidden_size, 1)
        self.conv3 = nn.Conv1d(hidden_size, num_hidden, 1)

    def forward(self, x):
        """Map ``[B, T, num_visible]`` to ``[B, T - kernel_size + 1, num_hidden]``."""
        # Conv1d wants channels before time, so swap the last two axes.
        feats = x.transpose(1, 2)
        feats = torch.relu(self.conv1(feats))
        feats = torch.relu(self.conv2(feats))
        feats = torch.tanh(self.conv3(feats))
        # Back to time-major layout for the caller.
        return feats.transpose(1, 2)
        
class AuxiliaryKoopmanNet(nn.Module):
    """Small network that predicts eigenvalue parameters from a latent block.

    Args:
        in_dim: Width of the latent block fed to the network.
        koopman_type: ``'complex'`` makes the network emit two numbers per
            sample (``mu`` and ``omega``); any other value makes it emit one
            (``mu``).
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, in_dim, koopman_type='complex', hidden_dim=64):
        super().__init__()
        self.koopman_type = koopman_type
        n_outputs = 1 if koopman_type != 'complex' else 2
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_outputs),
        )

    def forward(self, y):
        """Return ``(mu, omega)``; ``omega`` is ``None`` for non-complex blocks.

        For the complex case the input is indexed as ``[:, k]``, i.e. the
        network output is expected to be two-dimensional ``[N, 2]``.
        """
        params = self.net(y)
        if self.koopman_type != 'complex':
            return params.squeeze(-1), None
        return params[:, 0], params[:, 1]

def build_koopman_matrix(mu, omega, dt=1.0):
    """Build one 2x2 Koopman block per sample from eigenvalue parameters.

    With ``omega`` given, each block is the scaled rotation::

        exp(mu * dt) * [[cos(omega * dt), -sin(omega * dt)],
                        [sin(omega * dt),  cos(omega * dt)]]

    With ``omega is None`` each block is ``exp(mu * dt) * I``.

    The result is assembled from ``mu`` / ``omega`` directly, so it inherits
    their dtype and device (a pre-allocated default-dtype buffer would
    silently downcast float64 inputs to float32).

    Args:
        mu: 1-D tensor ``[N]`` of growth/decay rates.
        omega: 1-D tensor ``[N]`` of angular frequencies, or ``None``.
        dt: Time step; defaults to ``1.0``.

    Returns:
        Tensor of shape ``[N, 2, 2]``.
    """
    growth = torch.exp(mu * dt)

    if omega is None:
        identity = torch.eye(2, dtype=growth.dtype, device=growth.device)
        return growth.view(-1, 1, 1) * identity

    real_part = growth * torch.cos(omega * dt)
    imag_part = growth * torch.sin(omega * dt)
    top_row = torch.stack([real_part, -imag_part], dim=-1)     # [N, 2]
    bottom_row = torch.stack([imag_part, real_part], dim=-1)   # [N, 2]
    return torch.stack([top_row, bottom_row], dim=-2)          # [N, 2, 2]

class DeepKoopman(nn.Module):
    """Koopman autoencoder with a convolutional estimator for unobserved features.

    Pipeline of ``forward``: a ``ConvEncoder`` turns the visible coordinates
    into ``num_hidden`` extra per-time-step features, these are concatenated
    with the window-aligned visible coordinates, encoded into a latent vector,
    advanced by block-wise Koopman updates and decoded back to the visible
    coordinates.

    The latent vector is laid out as ``num_real`` blocks of width 1 followed by
    ``num_complex`` blocks of width 2; every block owns one
    ``AuxiliaryKoopmanNet`` in ``self.aux_nets`` (same order).

    Args:
        input_dim: Kept for interface compatibility; it is not used (the
            encoder input width is ``visible_dim + num_hidden``).
        num_real: Number of width-1 latent blocks.
        num_complex: Number of width-2 latent blocks.
        hidden_dim: Hidden width of encoder/decoder MLPs and channel count of
            the convolutional encoder.
        n_layers: Hidden block count of encoder/decoder MLPs.
        visible_dim: Number of leading input features that are observed.
        kernel_size: Temporal kernel size of the convolutional encoder.
        num_hidden: Number of features produced by the convolutional encoder.
    """

    def __init__(self, input_dim, num_real, num_complex,
                 hidden_dim=128, n_layers=2,
                 visible_dim=2, kernel_size=31, num_hidden=1):
        super().__init__()
        self.num_real = num_real
        self.num_complex = num_complex
        self.latent_dim = num_real * 1 + num_complex * 2
        self.visible_dim = visible_dim
        self.num_hidden = num_hidden
        self.encoder = Encoder(visible_dim + num_hidden, self.latent_dim, hidden_dim, n_layers)
        self.extra_encoder = ConvEncoder(visible_dim, num_hidden,
                                         hidden_size=hidden_dim, kernel_size=kernel_size)
        self.decoder = Decoder(self.latent_dim, visible_dim, hidden_dim, n_layers)
        # Order matters: real blocks first, then complex ones (see split_latent).
        self.aux_nets = nn.ModuleList()
        for _ in range(num_real):
            self.aux_nets.append(AuxiliaryKoopmanNet(in_dim=1, koopman_type='real'))
        for _ in range(num_complex):
            self.aux_nets.append(AuxiliaryKoopmanNet(in_dim=2, koopman_type='complex'))

    def split_latent(self, z):
        """Split a 2-D latent ``[N, D]`` into its real ``[N, 1]`` and complex ``[N, 2]`` blocks."""
        widths = [1] * self.num_real + [2] * self.num_complex
        parts = []
        start = 0
        for width in widths:
            parts.append(z[:, start:start + width])
            start += width
        return parts

    def koopman_step(self, z):
        """Advance a latent state by one Koopman step.

        Accepts ``[B, D]`` as well as ``[B, T, D]`` (any number of leading
        dimensions is flattened for the update and restored afterwards); every
        sample is updated independently.

        NOTE(review): width-1 blocks are multiplied by the raw network output
        ``mu``, whereas ``build_koopman_matrix`` uses ``exp(mu * dt)`` for its
        ``omega is None`` case — confirm which parameterisation is intended.
        """
        lead_shape = z.shape[:-1]
        flat = z.reshape(-1, z.shape[-1])

        updated = []
        for sub_z, net in zip(self.split_latent(flat), self.aux_nets):
            mu, omega = net(sub_z)
            if sub_z.shape[1] == 1:
                updated.append(mu.unsqueeze(-1) * sub_z)
            else:
                block = build_koopman_matrix(mu, omega)   # [N, 2, 2]
                updated.append(torch.bmm(block, sub_z.unsqueeze(-1)).squeeze(-1))

        z_new = torch.cat(updated, dim=-1)
        return z_new.reshape(*lead_shape, z_new.shape[-1])

    def koopman_power(self, z, m):
        """Apply ``koopman_step`` ``m`` times (``m <= 0`` returns ``z`` unchanged)."""
        for _ in range(m):
            z = self.koopman_step(z)
        return z

    def forward(self, x, steps=1, reverse=False):
        """Encode, advance ``steps`` Koopman steps, and decode.

        Args:
            x: Tensor ``[B, T, F]`` whose first ``visible_dim`` features are used.
            steps: Number of Koopman steps applied to every latent state.
            reverse: Unsupported; ``True`` raises ``NotImplementedError``.

        Returns:
            ``(x_pred, z0, z1)`` with shapes ``[B, T', visible_dim]``,
            ``[B, T', latent_dim]`` and ``[B, T', latent_dim]`` where ``T'`` is
            the length of the convolutional encoder output.
        """
        if reverse:
            raise NotImplementedError("Reverse is not used here.")
        x_in = x[:, :, :self.visible_dim]
        x_hid = self.extra_encoder(x_in)                           # [B, T', H]
        # Align the visible samples with the conv output. Slicing by the
        # output length (instead of ``interval:-interval``) also works for
        # kernel_size == 1, where ``-0`` would select an empty range, and for
        # even kernel sizes, where the two lengths would otherwise differ.
        interval = (self.extra_encoder.kernel_size - 1) // 2
        x_vis = x_in[:, interval:interval + x_hid.shape[1]]        # [B, T', V]
        x_aug = torch.cat([x_vis, x_hid], dim=-1)
        z0 = self.encoder(x_aug)
        z1 = self.koopman_power(z0, steps)
        x_pred = self.decoder(z1)                                  # [B, T', V]

        return x_pred, z0, z1


In [21]:
def koopman_rollout_prediction(model, X):
    """Roll a model forward from the first time step of ``X``.

    The first sample ``X[:, 0, :]`` is encoded with ``model.encoder`` and the
    latent state is then advanced one step at a time; every state is decoded
    with ``model.decoder``.

    The latent update uses ``model.koopman_step`` when the model provides it
    (``DeepKoopman`` does; it has no ``aux_net`` attribute). Models without
    ``koopman_step`` fall back to the older scheme, which queries
    ``model.aux_net`` and rotates only the first two latent coordinates.

    Args:
        model: Object exposing ``encoder``, ``decoder`` and either
            ``koopman_step`` or ``aux_net``.
        X: Tensor ``[B, T, dim]``; ``dim`` must be what ``model.encoder``
            accepts.

    Returns:
        x_preds: ``[B, T, out_dim]`` decoded states (step 0 is the plain
            reconstruction of ``X[:, 0]``).
        y_preds: ``[B, T, latent_dim]`` latent states.
    """
    B, T, dim = X.shape  # also validates that X is 3-D

    def _legacy_step(latent):
        # Single 2x2 block acting on the first two latent coordinates only.
        mu, omega = model.aux_net(latent)
        K = build_koopman_matrix(mu, omega)  # [B, 2, 2]
        advanced = latent.clone()
        advanced[:, :2] = torch.bmm(K, latent[:, :2].unsqueeze(-1)).squeeze(-1)
        return advanced

    advance = model.koopman_step if hasattr(model, "koopman_step") else _legacy_step

    y = model.encoder(X[:, 0, :])  # [B, latent_dim]

    x_preds = []
    y_preds = []
    for t in range(T):
        x_preds.append(model.decoder(y))
        y_preds.append(y)
        # The state after the last recorded step is never used.
        if t + 1 < T:
            y = advance(y)

    x_preds = torch.stack(x_preds, dim=1)  # [B, T, out_dim]
    y_preds = torch.stack(y_preds, dim=1)  # [B, T, latent_dim]
    return x_preds, y_preds

In [27]:
class TrainModel(pl.LightningModule):
    """Lightning wrapper that trains a ``DeepKoopman``-style model.

    The wrapped ``model`` must expose ``extra_encoder.kernel_size``,
    ``visible_dim``, ``decoder``, ``koopman_step`` and a ``forward`` returning
    ``(x_pred, z0, z1)`` where ``z1`` is ``z0`` advanced by ``steps`` Koopman
    steps.

    Args:
        model: The network to train.
        learning_rate: Adam learning rate.
        alpha1: Weight of the reconstruction + prediction losses.
        alpha2: Weight of the L-infinity loss.
        alpha3: Weight of the squared-L2 parameter penalty.
        path: Checkpoint stem; ``path + ".ckpt"`` is deleted at fit start.
    """

    def __init__(self, model, learning_rate=1e-3,
                 alpha1=0.1, alpha2=1e-7, alpha3=1e-13, 
                 path="deepkoop_ckpt"):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        self.criterion_mse = nn.MSELoss()
        self.criterion_linf = lambda a, b: torch.max(torch.abs(a - b))
        self.path = path + ".ckpt"

        self.best_val_loss = float("inf")
        self.validation_outputs = []
        self.train_losses = []
        # Per-step training losses of the running epoch (cleared every epoch).
        self._train_step_losses = []

    def forward(self, x):
        """Delegate to the wrapped model with its default arguments."""
        return self.model(x)

    def compute_loss(self, x):
        """Compute the composite training loss for a batch of windows.

        Args:
            x: Tensor ``[B, T, F]``; the first ``visible_dim`` features are the
                regression target.

        Returns:
            ``(loss, L_recon, L_pred, L_lin, L_inf)`` where ``loss`` is
            ``alpha1 * (L_recon + L_pred) + L_lin + alpha2 * L_inf + alpha3 * l2``.

        Raises:
            ValueError: If fewer than two time steps survive the convolutional
                encoder (prediction and linearity losses need a successor).
        """
        # z1 is z0 advanced by exactly one Koopman step at every time index.
        _, z0, z1 = self.model(x, steps=1)   # z0, z1: [B, T', D]
        B, T_prime, D = z0.shape
        if T_prime < 2:
            raise ValueError(
                f"compute_loss needs at least 2 encoded time steps, got {T_prime}"
            )

        # Slice by the encoded length rather than ``interval:-interval`` so a
        # zero interval (kernel_size == 1) does not yield an empty target.
        interval = (self.model.extra_encoder.kernel_size - 1) // 2
        target = x[:, interval:interval + T_prime, :self.model.visible_dim]  # [B, T', V]

        # Reconstruction loss
        x_recon = self.model.decoder(z0.reshape(B * T_prime, D)).reshape(B, T_prime, -1)
        L_recon = self.criterion_mse(target, x_recon)

        # Future prediction loss: roll the first latent state forward once per
        # horizon instead of restarting the rollout for every horizon.
        L_pred = 0
        x2_pred = None
        z_roll = z0[:, 0, :]
        for m in range(1, T_prime):
            z_roll = self.model.koopman_step(z_roll)         # [B, D]
            x_m = self.model.decoder(z_roll)                 # [B, V]
            if m == 1:
                x2_pred = x_m                                # reused by the L_inf term
            L_pred = L_pred + self.criterion_mse(target[:, m], x_m)
        L_pred = L_pred / (T_prime - 1)

        # Linearity loss: one-step prediction of every latent state against the
        # encoding of its successor. The forward pass already produced those
        # one-step predictions as z1; all per-step slices have equal size, so
        # a single MSE equals the mean of the per-step MSEs.
        L_lin = self.criterion_mse(z0[:, 1:], z1[:, :-1])

        # L_inf loss on the first reconstruction and the first prediction
        L_inf = (self.criterion_linf(target[:, 0], x_recon[:, 0])
                 + self.criterion_linf(target[:, 1], x2_pred))

        # Squared L2 norm of all trainable parameters
        l2_reg = sum(p.pow(2).sum() for p in self.model.parameters() if p.requires_grad)

        loss = self.alpha1 * (L_recon + L_pred) + L_lin + self.alpha2 * L_inf + self.alpha3 * l2_reg
        return loss, L_recon, L_pred, L_lin, L_inf

    def training_step(self, batch, batch_idx):
        """Log all loss terms per step and return the total loss."""
        X = batch[0]  # [B, T, V]
        loss, L_recon, L_pred, L_lin, L_inf = self.compute_loss(X)
        self.log_dict({
            'train_loss': loss,
            'train_L_recon': L_recon,
            'train_L_pred': L_pred,
            'train_L_lin': L_lin,
            'train_L_inf': L_inf
        }, on_step=True, on_epoch=False, prog_bar=True)
        self._train_step_losses.append(loss.detach())
        return loss

    def validation_step(self, batch, batch_idx):
        """Use the prediction loss as validation metric."""
        X = batch[0]
        _, _, L_pred, _, _ = self.compute_loss(X)
        self.validation_outputs.append(L_pred.detach())
        self.log('val_loss', L_pred, sync_dist=True)
        return L_pred

    def test_step(self, batch, batch_idx):
        """Use the prediction loss as test metric."""
        X = batch[0]
        _, _, L_pred, _, _ = self.compute_loss(X)
        self.log("test_loss", L_pred, sync_dist=True)
        return L_pred

    def on_fit_start(self):
        """Remove the loss log and the previous checkpoint (rank 0 only)."""
        if self.trainer.is_global_zero:
            if os.path.exists("loss_log.txt"):
                os.remove("loss_log.txt")
            if os.path.exists(self.path):
                os.remove(self.path)

    def on_train_epoch_end(self):
        """Record and print the mean of this epoch's per-step training losses.

        ``train_loss`` is logged per step only, so the callback metric holds
        the last step's value; the average is therefore computed from the
        losses collected in ``training_step`` (per rank, not all-reduced).
        """
        if not self._train_step_losses:
            return
        avg_train_loss = torch.stack(self._train_step_losses).mean().item()
        self._train_step_losses.clear()
        if self.trainer.is_global_zero:
            self.train_losses.append(avg_train_loss)
            print(f"Epoch {self.current_epoch}: Average Training Loss = {avg_train_loss}")

    def on_validation_epoch_end(self):
        """Log the mean validation loss and append it to ``loss_log.txt``."""
        if not self.validation_outputs:
            return
        avg_val_loss = torch.stack(self.validation_outputs).mean()
        self.log('avg_val_loss', avg_val_loss, sync_dist=True)
        self.validation_outputs.clear()
        # Only one process may print / append, otherwise every rank writes a line.
        if self.trainer.is_global_zero:
            print(f"Validation loss: {avg_val_loss}")
            with open("loss_log.txt", "a") as f:
                f.write(f"{avg_val_loss.item()}\n")

    def configure_optimizers(self):
        """Adam with a per-epoch exponential-style decay (``StepLR``, gamma 0.92).

        Gradient clipping is deliberately absent here: Lightning does not read
        clipping options from the optimizer configuration (unknown keys are
        ignored with a warning). Pass ``gradient_clip_val`` /
        ``gradient_clip_algorithm`` to ``pl.Trainer`` instead.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-8, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }

In [23]:
# ---- problem / architecture sizes ----
dim = 3            # total number of coordinates; used to find the left-out one
num_visible = 2
num_hidden = 1
hidden_dim = 10
hidden_size = 32
kernel_size = 31
n_layers = 10
n_blocks = 3
n_feature = 2
input_dim = 0
num_real = 3
num_complex = 1
rank = 3
dropout = 0
lamb = 0

# ---- dataset sizes (number of trajectories per split) ----
n_train = 10000
n_valid = 1000
n_test = 1000

# ---- optimisation ----
batch_size = 512
num_epochs = 20
learning_rate = 1e-3

# ---- hardware ----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# The CSV files carry no header row, so every line is treated as data.
X_train, X_valid, X_test = (
    pd.read_csv(csv_name, header=None)
    for csv_name in ('KO_train.csv', 'KO_valid.csv', 'KO_test.csv')
)
# All three splits joined along the last axis, in train / test / valid order.
X_result = np.concatenate((X_train, X_test, X_valid), axis=-1)

In [None]:
warnings.filterwarnings("ignore")


def _to_trajectories(frame, rows, n_traj):
    """Cut the selected rows of ``frame`` into ``n_traj`` equal windows.

    Columns are split into ``n_traj`` consecutive chunks of
    ``frame.shape[1] // n_traj`` samples (a remainder is dropped).

    Returns:
        Array of shape ``[n_traj, length, len(rows)]``.
    """
    values = frame.values[rows, :]
    length = values.shape[1] // n_traj
    windows = values[:, :n_traj * length].reshape(len(rows), n_traj, length)
    return np.ascontiguousarray(windows.transpose(1, 2, 0))


# Train one model per pair of observed coordinates; the remaining coordinate
# is the one the convolutional encoder has to infer.
idx_list = [[0, 1], [0, 2], [1, 2]]
for idx in idx_list:
    # Each split derives its window length from its own size.
    H_train = _to_trajectories(X_train, idx, n_train)
    H_valid = _to_trajectories(X_valid, idx, n_valid)

    train_dataset = TensorDataset(torch.tensor(H_train, dtype=torch.float32))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    valid_dataset = TensorDataset(torch.tensor(H_valid, dtype=torch.float32))
    # Evaluation order is irrelevant, so the validation loader is not shuffled.
    valid_loader = DataLoader(valid_dataset, batch_size=99999, shuffle=False, num_workers=8, pin_memory=True)

    missing_dim = (set(range(dim)) - set(idx)).pop()
    path = f"model_checkpoint_hidden_{missing_dim}_deepkoopman"
    checkpoint_callback = ModelCheckpoint(
        monitor="avg_val_loss",
        dirpath="./",
        filename=path,
        save_top_k=1,
        mode="min",
    )
    model = DeepKoopman(
        input_dim=num_visible + num_hidden,
        hidden_dim=hidden_dim,
        n_layers=n_layers,
        num_real=num_real,
        num_complex=num_complex,
        visible_dim=num_visible,
        kernel_size=kernel_size,
        num_hidden=num_hidden
    )
    lightning_model = TrainModel(model=model, learning_rate=learning_rate, path=path)
    # Gradient clipping is a Trainer option; Lightning does not pick it up
    # from the dict returned by configure_optimizers.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        strategy="ddp_notebook",
        max_epochs=num_epochs,
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
    )
    trainer.fit(lightning_model, train_loader, valid_loader)

In [16]:
# Rebuild the architecture used by the training cell and load the checkpoint
# it wrote for observed coordinates [0, 1] (hidden coordinate 2). The weights
# are restored into `model` through TrainModel's state dict.
model = DeepKoopman(
    input_dim=num_visible + num_hidden,
    hidden_dim=hidden_dim,
    n_layers=n_layers,
    num_real=num_real,
    num_complex=num_complex,
    visible_dim=num_visible,
    kernel_size=kernel_size,
    num_hidden=num_hidden
)
path = "model_checkpoint_hidden_2_deepkoopman.ckpt"
lightning_model = TrainModel.load_from_checkpoint(path, model=model, learning_rate=learning_rate, map_location="cpu")
# Evaluate on a single device so every test sample is scored exactly once.
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="ddp_notebook", max_epochs=num_epochs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
# Build the test windows for observed coordinates 0 and 1.
X_v_test = X_test.values[[0, 1], :].reshape(2, -1)
length = X_v_test.shape[1] // n_test
H_test = np.stack(
    [X_v_test[:, i*length:(i+1)*length].T for i in range(n_test)],
    axis=0,
)  # [n_test, length, 2]
# Cast to float32 explicitly: the values read from CSV would otherwise become
# a float64 tensor, which does not match the model's float32 weights.
test_dataset = TensorDataset(torch.tensor(H_test, dtype=torch.float32))
# Evaluation does not depend on sample order, so shuffling is disabled.
test_loader = DataLoader(test_dataset, batch_size=9999, shuffle=False)

In [None]:
# Suppress all warnings, then evaluate the loaded model on the test loader.
warnings.filterwarnings(action="ignore")
trainer.test(model=lightning_model, dataloaders=test_loader)