In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from torch.func import vmap, jacrev
from tqdm import tqdm
import os
import random
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback
import math
from pydmd import DMD
from sklearn.preprocessing import MinMaxScaler
import warnings

In [2]:
class ResidualFlow(nn.Module):
    """One invertible block of the network: a thin wrapper around a single ``Flow``.

    Blocks with an even ``block_id`` construct their flow with ``flip=True`` and
    odd ones with ``flip=False``. ``hidden_dim`` is forwarded to the flow;
    ``n_layers`` and ``dropout`` are accepted but not used by this class, while
    ``input_dim`` and ``LDJ`` are only stored as attributes.
    """

    def __init__(self, dim, hidden_dim, n_layers, input_dim=0, dropout=0, LDJ=False, block_id=0):
        super().__init__()
        self.dim = dim
        self.input_dim = input_dim
        self.LDJ = LDJ
        self.block_id = block_id

        # Even-numbered blocks get the flipped flow, odd-numbered ones do not.
        self.flow = Flow(dim, hidden_dim, flip=(block_id % 2 == 0))

    def forward(self, x, reverse=False):
        """Apply the wrapped flow to ``x``.

        Returns ``(y, 0)`` in the forward direction (the log-determinant slot is
        always the constant ``0``) and only ``y`` when ``reverse`` is true.
        """
        if reverse:
            return self.flow(x, reverse=True)
        return self.flow(x, reverse=False), 0

class Flow(nn.Module):
    """Single flow step that delegates all work to one ``AffineCoupling`` layer."""

    def __init__(self, in_channel, hidden_dim, flip=False):
        super().__init__()
        self.coupling = AffineCoupling(dim=in_channel, hidden_dim=hidden_dim, flip=flip)

    def forward(self, x, reverse=False):
        """Run the coupling layer on ``x`` in the requested direction."""
        out = self.coupling(x, reverse)
        return out

class AffineCoupling(nn.Module):
    """Affine coupling layer acting on the last axis of its input.

    The last axis is split into a conditioning part of ``dim // 2`` entries and
    a transformed part with the remaining ``dim - dim // 2`` entries. A small MLP
    maps the conditioning part to a tanh-bounded log-scale and a shift, which
    are applied to the transformed part; the conditioning part passes through
    unchanged. With ``flip=True`` the transformed part is taken from the front
    of the axis and the conditioning part from the back.
    """

    def __init__(self, dim, hidden_dim, flip=False):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.flip = flip

        self.split_idx = dim // 2
        self.rest_dim = dim - self.split_idx

        # The MLP emits 2 * rest_dim values: one half scales, the other shifts.
        layers = [
            nn.Linear(self.split_idx, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * self.rest_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, reverse=False):
        """Apply the coupling (``reverse=False``) or its inverse (``reverse=True``)."""
        if self.flip:
            moved, cond = torch.split(x, [self.rest_dim, self.split_idx], dim=-1)
        else:
            cond, moved = torch.split(x, [self.split_idx, self.rest_dim], dim=-1)

        raw_scale, shift = torch.chunk(self.net(cond), 2, dim=-1)
        log_scale = torch.tanh(raw_scale)

        if reverse:
            moved = (moved - shift) * torch.exp(-log_scale)
        else:
            moved = moved * torch.exp(log_scale) + shift

        # Reassemble in the same order the input was split.
        pieces = [moved, cond] if self.flip else [cond, moved]
        return torch.cat(pieces, dim=-1)
    
class InvertibleNN(nn.Module):
    """A stack of ``n_blocks`` ``ResidualFlow`` blocks applied one after another.

    Block ``i`` is created with ``block_id=i``. The ``u`` argument of ``forward``
    is accepted but not used.
    """

    def __init__(self, dim, hidden_dim, n_blocks, n_layers, input_dim=0, dropout=0, LDJ=False):
        super(InvertibleNN, self).__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.input_dim = input_dim

        flow_blocks = []
        for idx in range(self.n_blocks):
            flow_blocks.append(
                ResidualFlow(self.dim, self.hidden_dim, self.n_layers, self.input_dim, dropout, LDJ, block_id=idx)
            )
        self.blocks = nn.ModuleList(flow_blocks)

    def forward(self, x, u=None, reverse=False):
        """Run the stack forwards, or backwards through the blocks in reverse order.

        Forward returns ``(y, ldj_total)`` where ``ldj_total`` accumulates the
        second value returned by every block; reverse returns only the tensor.
        """
        if reverse:
            for block in reversed(self.blocks):
                x = block(x, reverse)
            return x

        ldj_total = 0
        for block in self.blocks:
            x, ldj = block(x, reverse)
            ldj_total += ldj
        return x, ldj_total
    
class CombinedNetwork(nn.Module):
    """Zero-pad the state to the lifted dimension and run it through the INN.

    Forward: the last axis of ``x`` is extended with ``lifted_dim`` zeros and
    passed to ``inn_model``. Reverse: ``inn_model`` is run in reverse and only
    the first ``input_dim`` entries of the last axis are returned.

    Both directions operate on the last axis, so inputs may carry any number
    of leading axes (e.g. ``(batch, time, features)`` or ``(samples, features)``)
    and ``reverse`` undoes ``forward`` on the same layout.

    Args:
        inn_model: module called as ``inn_model(x, u, reverse)``; returns
            ``(y, ldj)`` when ``reverse`` is false and ``y`` otherwise.
        input_dim: number of leading features kept by the reverse pass.
        lifted_dim: number of zero features appended by the forward pass.
    """

    def __init__(self, inn_model, input_dim, lifted_dim):
        super(CombinedNetwork, self).__init__()
        self.input_dim = input_dim
        self.inn_model = inn_model
        self.lifted_dim = lifted_dim

    def forward(self, x, u=None, reverse=False):
        """Lift ``x`` (forward) or project a lifted tensor back (reverse).

        Returns ``(lifted, ldj)`` in the forward direction and the recovered
        tensor in the reverse direction. The input is cast to float32 first.
        """
        x = x.float()
        if not reverse:
            # Pad along the feature axis only; the leading axes are whatever
            # the caller supplied rather than a hard-coded (batch, time) pair.
            zero_pad = torch.zeros(*x.shape[:-1], self.lifted_dim, device=x.device, dtype=x.dtype)
            x = torch.cat((x, zero_pad), dim=-1)
            x, ldj = self.inn_model(x, u, reverse)
            return x, ldj
        else:
            x = self.inn_model(x, u, reverse)
            # Slice the feature (last) axis. ``x[:, :input_dim]`` would cut the
            # second axis, which is the time axis for 3-D inputs.
            x = x[..., :self.input_dim]
            return x

In [3]:
def dmd(model, X, rank):
    """Lift a batch of trajectories with ``model`` and fit one DMD per trajectory.

    Args:
        model: callable returning ``(GX, ldj)`` for ``X``. ``GX`` is indexed as
            ``GX[i, :, :]``, so it must be 3-D with one trajectory per entry of
            axis 0.
        X: batch handed to ``model``; ``X.shape[0]`` trajectories are processed.
        rank: forwarded to ``DMD(svd_rank=...)``.

    Returns:
        ``(GX, GX_pred, ldj)``:
        ``GX`` is each trajectory's lifted data transposed and concatenated
        along the last axis; ``GX_pred`` is the DMD reconstruction in the same
        layout, as a float32 tensor on the same device as the lifted data;
        ``ldj`` is returned unchanged from ``model``.

    Note:
        ``GX_pred`` is rebuilt from NumPy arrays, so it is not connected to the
        autograd graph; gradients can only flow through ``GX``.
    """
    GX, ldj = model(X)
    n_traj = X.shape[0]

    # One device->host copy for the whole batch instead of one per trajectory.
    GX_host = GX.detach().cpu().numpy()

    recon_list = []
    for i in range(n_traj):
        # Named ``dmd_model`` so the local does not shadow this function.
        dmd_model = DMD(svd_rank=rank, exact=True, sorted_eigs='abs')
        dmd_model.fit(GX_host[i].T)
        recon_list.append(np.array(dmd_model.reconstructed_data.real, dtype=np.float32))

    # Put the reconstruction wherever the lifted data lives; a hard-coded
    # ``.cuda()`` fails on CPU-only runs and ignores the input's actual device.
    GX_pred = torch.from_numpy(np.concatenate(recon_list, axis=-1)).to(GX.device)
    GX = torch.cat([GX[i, :, :].T for i in range(n_traj)], dim=-1)

    return GX, GX_pred, ldj

In [4]:
class TrainModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` so its lifted trajectories are DMD-reconstructible.

    Args:
        model: network called as ``model(x)`` -> ``(lifted, ldj)`` and
            ``model(z, reverse=True)`` -> reconstructed state.
        rank: SVD rank handed to the module-level ``dmd`` helper.
        learning_rate: Adam learning rate.
        lamb: weight of the log-det-Jacobian term in the training loss.
        path: checkpoint file name without the ``.ckpt`` suffix; the best
            checkpoint at ``path + '.ckpt'`` is reloaded at every epoch start.
    """

    def __init__(self, model, rank, learning_rate=1e-3, lamb=1, path="model_checkpoint_Van"):
        super(TrainModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = nn.MSELoss()
        self.best_val_loss = float('inf')
        self.validation_outputs = []
        self.lamb = lamb
        self.train_losses = []
        self.rank = rank
        self.path = path+'.ckpt'

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """MSE between lifted data and its DMD reconstruction, minus ``lamb`` times the scaled log-det term."""
        X_batch = batch[0]
        GY, GY_pred, ldj = dmd(self.model, X_batch, self.rank)

        loss_lin = self.criterion(GY, GY_pred)
        loss_LDJ = ldj / X_batch.numel()

        loss = loss_lin - self.lamb * loss_LDJ
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def _reconstruction_loss(self, batch):
        """Shared validation/test metric.

        Lifts the batch, fits DMD, maps both the lifted data and its DMD
        reconstruction back through the inverse network and returns their MSE.
        """
        Z_batch = batch[0]
        Z1, Z_pred, _ = dmd(self.model, Z_batch, self.rank)
        Z_pred = self.model(Z_pred.T, reverse=True)
        Z1 = self.model(Z1.T, reverse=True)
        return self.criterion(Z_pred, Z1)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and remember it for the epoch-level average."""
        valid_loss = self._reconstruction_loss(batch)

        self.validation_outputs.append(valid_loss)
        self.log('val_loss', valid_loss)
        return valid_loss

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` using the same metric as validation."""
        test_loss = self._reconstruction_loss(batch)

        self.log('test_loss', test_loss)
        return test_loss

    def on_fit_start(self):
        """Remove the previous run's loss log and checkpoint (rank 0 only)."""
        if self.trainer.is_global_zero:
            if os.path.exists("loss_log.txt"):
                os.remove("loss_log.txt")
            if os.path.exists(self.path):
                os.remove(self.path)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Rescale every ``nn.Linear`` weight whose norm exceeds 1 to norm ``1 - 1e-3``."""
        with torch.no_grad():
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Linear):
                    if name == "linear":
                        continue
                    weight = module.weight
                    # NOTE(review): torch.norm(p=2) without ``dim`` presumably
                    # treats the weight as a flat vector (Frobenius norm), not
                    # the largest singular value the name suggests - confirm
                    # which constraint is intended before changing it.
                    sigma_max = torch.norm(weight, p=2)
                    if sigma_max > 1:
                        scale = (1 - 1e-3) / sigma_max
                        weight.mul_(scale)

    def on_train_epoch_start(self):
        """Restart each epoch from the best checkpoint saved so far, if one exists."""
        if os.path.exists(self.path):
            # map_location keeps every rank's copy on its own device; without
            # it the tensors are restored onto the device they were saved from.
            checkpoint = torch.load(self.path, map_location=self.device)
            self.load_state_dict(checkpoint["state_dict"])

    def on_train_epoch_end(self):
        """Record and print the ``train_loss`` callback metric (rank 0 only)."""
        if self.trainer.is_global_zero:
            # ``train_loss`` is logged with on_step=True / on_epoch=False, so
            # this is whatever value the trainer currently holds for it.
            avg_train_loss = self.trainer.callback_metrics.get("train_loss")
            if avg_train_loss is not None:
                self.train_losses.append(avg_train_loss.item())
                print(f"Epoch {self.current_epoch}: Average Training Loss = {avg_train_loss.item()}")

    def on_validation_epoch_end(self):
        """Average the per-batch validation losses, log them and append to ``loss_log.txt``."""
        avg_val_loss = torch.stack(self.validation_outputs).mean()
        self.log('avg_val_loss', avg_val_loss)
        self.validation_outputs.clear()
        print(f"Validation loss: {avg_val_loss}")
        # Only rank 0 writes, matching on_fit_start where only rank 0 deletes
        # the file; otherwise every process appends its own line per epoch.
        if self.trainer.is_global_zero:
            with open("loss_log.txt", "a") as f:
                f.write(f"{avg_val_loss.item()}\n")

    def configure_optimizers(self):
        """Adam with a per-epoch StepLR decay (gamma 0.92).

        Gradient clipping is not configured here: Lightning only reads the
        optimizer / lr_scheduler (plus monitor / frequency) entries of this
        dict, so clipping has to be requested through
        ``Trainer(gradient_clip_val=..., gradient_clip_algorithm=...)``.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-08,
                                            weight_decay=0)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=1,
            gamma=0.92
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [5]:
# --- Experiment configuration -------------------------------------------
# Size of the original state; passed to CombinedNetwork as input_dim.
dim = 2  
# Width of the hidden layers inside every coupling network.
hidden_dim = 50  
# Forwarded to InvertibleNN(input_dim=...), where it is only stored.
input_dim = 0
# Number of ResidualFlow blocks stacked in InvertibleNN.
n_blocks = 3 
# Forwarded to InvertibleNN / ResidualFlow; not used by the layers shown here.
n_layers = 1
# Number of zero features appended to the state (CombinedNetwork lifted_dim).
n_feature = 48
# svd_rank used for every DMD fit.
rank = 8
# Batch size of the training and validation loaders.
batch_size = 2048
# Number of trajectories the train / valid / test arrays are cut into.
n_train = 50000
n_valid = 5000
n_test = 5000
# Forwarded to InvertibleNN / ResidualFlow; not used by the layers shown here.
dropout = 0.5
# Passed to the Trainer as max_epochs.
num_epochs = 99999  
# Weight of the log-det term in the training loss; also sets LDJ = (lamb > 0).
lamb = 0
# Adam learning rate.
learning_rate = 1e-3  
# Preferred torch device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Each CSV holds the trajectories side by side: one row per state variable and
# `length` consecutive columns per trajectory.
X_train = pd.read_csv('Van_train.csv', header=None).values
X_valid = pd.read_csv('Van_valid.csv', header=None).values
X_test = pd.read_csv('Van_test.csv', header=None).values

# Columns per trajectory, derived from the training file and reused for validation.
length = X_train.shape[1] // n_train


def _split_trajectories(flat, n_traj, steps):
    """Cut ``flat`` into ``n_traj`` consecutive column blocks of ``steps`` columns.

    Each block is transposed, so the result is indexed as
    ``[trajectory, step, row_of_flat]``.
    """
    return np.stack([flat[:, i * steps:(i + 1) * steps].T for i in range(n_traj)], axis=0)


H_train = _split_trajectories(X_train, n_train, length)
H_valid = _split_trajectories(X_valid, n_valid, length)

train_dataset = TensorDataset(torch.tensor(H_train))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
valid_dataset = TensorDataset(torch.tensor(H_valid))
# Validation is not shuffled: a fixed batch composition keeps the averaged
# validation loss (and therefore checkpoint selection) repeatable.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Per-row extrema over all three splits.
X_result = np.concatenate([X_train, X_test, X_valid], axis=-1)
Xmax = torch.tensor(np.max(X_result, axis=-1), dtype=torch.float)
Xmin = torch.tensor(np.min(X_result, axis=-1), dtype=torch.float)

In [None]:
warnings.filterwarnings("ignore")
path = "model_checkpoint_Van_flowdmd"
# Keep only the single best checkpoint, ranked by the epoch-level validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="avg_val_loss",
    dirpath="./",
    filename=path,
    save_top_k=1,
    mode="min",
)
inn_model = InvertibleNN(dim=dim+n_feature, hidden_dim=hidden_dim, n_blocks=n_blocks, n_layers=n_layers, input_dim=input_dim, dropout=dropout, LDJ=lamb>0)
model = CombinedNetwork(inn_model=inn_model, input_dim=dim, lifted_dim=n_feature)
lightning_model = TrainModel(model=model, rank=rank, learning_rate=learning_rate, lamb=lamb, path=path)
# Clip the global gradient norm at 1.0. Lightning takes this setting from the
# Trainer, so it has to be passed here to take effect.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp_notebook",
    max_epochs=num_epochs,
    callbacks=[checkpoint_callback],
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)

trainer.fit(lightning_model, train_loader, valid_loader)

In [8]:
# Rebuild the architecture with the same hyperparameters, then restore the
# trained weights from the saved checkpoint onto the CPU.
path = "model_checkpoint_Van_flowdmd.ckpt"
inn_model = InvertibleNN(
    dim=dim+n_feature,
    hidden_dim=hidden_dim,
    n_blocks=n_blocks,
    n_layers=n_layers,
    input_dim=input_dim,
    dropout=dropout,
    LDJ=lamb>0,
)
model = CombinedNetwork(inn_model=inn_model, input_dim=dim, lifted_dim=n_feature)
lightning_model = TrainModel.load_from_checkpoint(
    path,
    model=model,
    rank=rank,
    learning_rate=learning_rate,
    map_location="cpu",
)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp_notebook",
    max_epochs=num_epochs,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Cut the test array into n_test consecutive column blocks, one per trajectory,
# and transpose each block so the result is indexed [trajectory, step, row].
length = X_test.shape[1] // n_test
H_test = np.stack([X_test[:, i*length:(i+1)*length].T for i in range(n_test)], axis=0)
test_dataset = TensorDataset(torch.tensor(H_test))
# Evaluation order is fixed (no shuffling) so the reported test loss is repeatable.
test_loader = DataLoader(test_dataset, batch_size=9999, shuffle=False)

In [10]:
# Silence all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
# Run the test loop on the restored model; the returned list of metric dicts is
# this cell's output.
# NOTE(review): the recorded output of this cell reports test_loss = inf, so the
# reconstruction loss overflowed somewhere in the DMD / inverse pass - the
# number should not be read as a valid score until that is investigated.
trainer.test(lightning_model, dataloaders=test_loader)

You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 

Testing DataLoader 0: 100%|██████████| 1/1 [00:03<00:00,  0.29it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss                   inf
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': inf}]