In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from torch.func import vmap, jacrev
from tqdm import tqdm
import os
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
import math
from pydmd import DMD
from sklearn.preprocessing import MinMaxScaler
import warnings

In [2]:
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=128, n_layers=2, activation=nn.ReLU):
        super().__init__()
        layers = [nn.Linear(in_dim, hidden_dim), activation()]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
        layers.append(nn.Linear(hidden_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim, hidden_dim=128, n_layers=2):
        super().__init__()
        self.net = MLP(input_dim, latent_dim, hidden_dim, n_layers)

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim, hidden_dim=128, n_layers=2):
        super().__init__()
        self.net = MLP(latent_dim, output_dim, hidden_dim, n_layers)

    def forward(self, y):
        return self.net(y)

class AuxiliaryKoopmanNet(nn.Module):
    def __init__(self, koopman_type='complex', hidden_dim=64):
        super().__init__()
        self.koopman_type = koopman_type
        in_dim = 2 if koopman_type == 'complex' else 1  
        out_dim = 2 if koopman_type == 'complex' else 1
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, y):
        out = self.net(y)
        if self.koopman_type == 'complex':
            mu = out[:, 0]
            omega = out[:, 1]
            return mu, omega
        else:
            mu = out.squeeze(-1)
            return mu, None

def build_koopman_matrix(mu, omega):
    batch_size = mu.shape[0]
    dt = 1.0
    exp_mu = torch.exp(mu * dt)
    if omega is not None:
        cos_ = torch.cos(omega * dt)
        sin_ = torch.sin(omega * dt)
        K_blocks = torch.zeros(batch_size, 2, 2).to(mu.device)
        K_blocks[:, 0, 0] = exp_mu * cos_
        K_blocks[:, 0, 1] = -exp_mu * sin_
        K_blocks[:, 1, 0] = exp_mu * sin_
        K_blocks[:, 1, 1] = exp_mu * cos_
    else:
        K_blocks = torch.eye(2).unsqueeze(0).repeat(batch_size, 1, 1).to(mu.device)
        K_blocks = exp_mu.view(-1, 1, 1) * K_blocks
    return K_blocks

class DeepKoopman(nn.Module):
    def __init__(self, input_dim, num_real, num_complex, hidden_dim=128, n_layers=2):
        super().__init__()
        self.num_real = num_real
        self.num_complex = num_complex
        self.latent_dim = num_real * 1 + num_complex * 2

        self.encoder = Encoder(input_dim, self.latent_dim, hidden_dim, n_layers)
        self.decoder = Decoder(self.latent_dim, input_dim, hidden_dim, n_layers)
        self.aux_nets = nn.ModuleList()

        for _ in range(num_real):
            self.aux_nets.append(AuxiliaryKoopmanNet(koopman_type='real'))
        for _ in range(num_complex):
            self.aux_nets.append(AuxiliaryKoopmanNet(koopman_type='complex'))

    def split_latent(self, z):
        idx = 0
        parts = []
        for _ in range(self.num_real):
            parts.append(z[:, idx:idx+1])
            idx += 1
        for _ in range(self.num_complex):
            parts.append(z[:, idx:idx+2])
            idx += 2
        return parts

    def koopman_step(self, z):
        zs = self.split_latent(z)
        updated = []
        for i, (sub_z, net) in enumerate(zip(zs, self.aux_nets)):
            mu, omega = net(sub_z)
            if sub_z.shape[1] == 1:
                z_next = mu.unsqueeze(-1) * sub_z
            else:
                B = build_koopman_matrix(mu, omega)
                z_rot = torch.bmm(B, sub_z.unsqueeze(-1)).squeeze(-1)
                z_next = z_rot
            updated.append(z_next)
        return torch.cat(updated, dim=-1)

    def koopman_power(self, z, m):
        for _ in range(m):
            z = self.koopman_step(z)
        return z

    def forward(self, x, steps=1, reverse=False):
        if reverse:
            raise NotImplementedError("Reverse is not used here.")
        z0 = self.encoder(x)
        z1 = self.koopman_power(z0, steps)
        x_pred = self.decoder(z1)
        return x_pred, z0, z1

In [3]:
def koopman_rollout_prediction(model, X):
    """
    Koopman rollout prediction over full sequence length T.

    Args:
        X: shape [B, T, dim]
    Returns:
        x_preds: shape [B, T, dim]
        y_preds: shape [B, T, latent_dim]
    """
    B, T, dim = X.shape
    x0 = X[:, 0, :]  
    y = model.encoder(x0)  # [B, latent_dim]

    x_preds = []
    y_preds = []

    for t in range(T):
        x_hat = model.decoder(y)
        x_preds.append(x_hat)
        y_preds.append(y)

        mu, omega = model.aux_net(y)
        K = build_koopman_matrix(mu, omega)  # [B, 2, 2]

        y_next = y.clone()
        y_rot = torch.bmm(K, y[:, :2].unsqueeze(-1)).squeeze(-1)
        y_next[:, :2] = y_rot
        y = y_next

    x_preds = torch.stack(x_preds, dim=1)  # [B, T, dim]
    y_preds = torch.stack(y_preds, dim=1)  # [B, T, latent_dim]
    return x_preds, y_preds

In [14]:
class TrainModel(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-3,
                 alpha1=0.1, alpha2=1e-7, alpha3=1e-13, 
                 path="deepkoop_ckpt"):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        self.criterion_mse = nn.MSELoss()
        self.criterion_linf = lambda a, b: torch.max(torch.abs(a - b))
        self.path = path + ".ckpt"

        self.best_val_loss = float("inf")
        self.validation_outputs = []
        self.train_losses = []

    def forward(self, x):
        return self.model(x)

    def compute_loss(self, x):
        B, T, dim = x.shape
        x1 = x[:, 0]
        x2 = x[:, 1]
        z1 = self.model.encoder(x1)
        x1_recon = self.model.decoder(z1)

        # Reconstruction Loss
        L_recon = self.criterion_mse(x1, x1_recon)

        # Future Prediction Loss
        L_pred = 0
        for m in range(1, T):
            z_pred = self.model.koopman_power(z1, m)  # K^m φ(x1)
            x_pred = self.model.decoder(z_pred)
            L_pred += self.criterion_mse(x[:, m], x_pred)
        L_pred /= (T - 1)

        # Linearity Loss
        L_lin = 0
        z_all = self.model.encoder(x)
        for m in range(T - 1):
            z_next_pred = self.model.koopman_step(z_all[:, m])
            L_lin += self.criterion_mse(z_all[:, m + 1], z_next_pred)
        L_lin /= (T - 1)

        # L_inf Loss
        x2_pred = self.model.decoder(self.model.koopman_power(z1, 1))
        L_inf = self.criterion_linf(x1, x1_recon) + self.criterion_linf(x2, x2_pred)

        # L2 regularization
        l2_reg = sum(torch.norm(p, 2) ** 2 for p in self.model.parameters() if p.requires_grad)

        loss = self.alpha1 * (L_recon + L_pred) + L_lin + self.alpha2 * L_inf + self.alpha3 * l2_reg
        return loss, L_recon, L_pred, L_lin, L_inf

    def training_step(self, batch, batch_idx):
        X = batch[0]  # shape: [B, T, dim]
        loss, L_recon, L_pred, L_lin, L_inf = self.compute_loss(X)
        self.log_dict({
            'train_loss': loss,
            'train_L_recon': L_recon,
            'train_L_pred': L_pred,
            'train_L_lin': L_lin,
            'train_L_inf': L_inf
        }, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X = batch[0] 
        _, _, L_pred, _, _ = self.compute_loss(X)
        self.validation_outputs.append(L_pred)
        self.log('val_loss', L_pred)
        return L_pred

    def test_step(self, batch, batch_idx):
        X = batch[0]  # shape: [B, T, dim]
        _, _, L_pred, _, _ = self.compute_loss(X)
        self.log("test_loss", L_pred)
        return L_pred

    def on_fit_start(self):
        if self.trainer.is_global_zero:
            if os.path.exists("loss_log.txt"):
                os.remove("loss_log.txt")
            if os.path.exists(self.path):
                os.remove(self.path)

    def on_train_epoch_end(self):
        if self.trainer.is_global_zero:
            avg_train_loss = self.trainer.callback_metrics.get("train_loss")
            if avg_train_loss is not None:
                self.train_losses.append(avg_train_loss.item())
                print(f"Epoch {self.current_epoch}: Average Training Loss = {avg_train_loss.item()}")

    def on_validation_epoch_end(self):
        avg_val_loss = torch.stack(self.validation_outputs).mean()
        self.log('avg_val_loss', avg_val_loss)
        self.validation_outputs.clear()
        print(f"Validation loss: {avg_val_loss}")
        with open("loss_log.txt", "a") as f:
            f.write(f"{avg_val_loss.item()}\n")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-8, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            },
            "gradient_clip_val": 1.0,
            "gradient_clip_algorithm": "norm"
        }

In [5]:
dim = 3  
hidden_dim = 40  
input_dim = 0
n_blocks = 3  
n_layers = 3
num_real = 3
num_complex = 1
batch_size = 512
n_train = 10000
n_valid = 1000
n_test = 1000
dropout = 0
num_epochs = 20  
lamb = 0
learning_rate = 1e-3  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
X_train = pd.read_csv('KO_sample_train.csv', header=None).values
X_valid = pd.read_csv('KO_sample_valid.csv', header=None).values
X_test = pd.read_csv('KO_sample_test.csv', header=None).values

length = X_train.shape[1] // n_train
H_train = []
for i in range(n_train):
    H_train.append(X_train[:, i*length:(i+1)*length])
H_train = np.stack([H_train[idx].T for idx in range(n_train)], axis=0)
H_valid = []
for i in range(n_valid):
    H_valid.append(X_valid[:, i*length:(i+1)*length])
H_valid = np.stack([H_valid[idx].T for idx in range(n_valid)], axis=0)

In [None]:
warnings.filterwarnings("ignore")
sample_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
# sample_list = [1]
for i in sample_list:
    path = f"model_checkpoint_KO_sample_{i}"
    checkpoint_callback = ModelCheckpoint(
        monitor="avg_val_loss",   
        dirpath="./", 
        filename=path,  
        save_top_k=1,  
        mode="min",    
    )
    H_train_tensor = torch.tensor(H_train[:, 0:11*i:i, :], dtype=torch.float32)
    train_dataset = TensorDataset(H_train_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    H_valid_tensor = torch.tensor(H_valid[:, 0:11*i:i, :], dtype=torch.float32)
    valid_dataset = TensorDataset(H_valid_tensor)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    model = DeepKoopman(input_dim=dim, hidden_dim=hidden_dim, n_layers=n_layers, num_real=num_real, num_complex=num_complex)
    lightning_model = TrainModel(model=model, learning_rate=learning_rate, path=path)
    trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp_notebook", max_epochs=num_epochs, callbacks=[checkpoint_callback])

    trainer.fit(lightning_model, train_loader, valid_loader)

In [12]:
X_test = pd.read_csv('KO_sample_test.csv', header=None).values
length = X_test.shape[1] // n_test
H_test = []
for i in range(n_test):
    H_test.append(X_test[:, i*length:(i+1)*length])
H_test = np.stack([H_test[idx].T for idx in range(n_test)], axis=0)

In [None]:
warnings.filterwarnings("ignore")
sample_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
error_list = []

for i in sample_list:
    path = f"model_checkpoint_KO_sample_{i}.ckpt"
    model = DeepKoopman(input_dim=dim, hidden_dim=hidden_dim, n_layers=n_layers, num_real=num_real, num_complex=num_complex)
    lightning_model = TrainModel.load_from_checkpoint(path, model=model, learning_rate=learning_rate, map_location="cpu")
    trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp_notebook", max_epochs=num_epochs)
    H_test_tensor = torch.tensor(H_test[:, 0:11*i:i, :], dtype=torch.float32)
    test_dataset = TensorDataset(H_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=9999, shuffle=True)
    error_list.append(trainer.test(lightning_model, dataloaders=test_loader)[0]['test_loss'])

df = pd.DataFrame(error_list)
df.to_csv("sample.csv", index=False)

In [12]:
error_list

[0.34521573781967163,
 0.34112581610679626,
 0.3411276042461395,
 0.35121089220046997,
 0.3354300558567047,
 0.34621691703796387,
 0.34132254123687744,
 0.3509877324104309,
 0.3417041301727295,
 0.34397822618484497,
 0.35606205463409424,
 0.3498133718967438,
 0.3536776006221771,
 0.34368571639060974,
 0.35125532746315,
 0.34510818123817444,
 0.3392789363861084,
 0.3365575969219208,
 0.3411751687526703]