In [1]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

## load path
import os
import sys
sys.path.append('../')
sys.path.append('../models')
from pathlib import Path

## load utils 
from util.util import *

# we provide two realizations of GP on sphere, the commented one is slower
#from util.ofm_OT_likelihood_sphere_seq_mino import * 
from util.ofm_OT_likelihood_seq_mino import *
from util.metrics import *
import time

## load modules 
from models.mino_transformer import MINO
from models.mino_modules.decoder_perceiver import DecoderPerceiver
from models.mino_modules.encoder_supernodes_gno_cross_attention import EncoderSupernodes
from models.mino_modules.conditioner_timestep import ConditionerTimestep

In [2]:
query_dims = [16, 16]  # resolution of the 2-D query grid built by make_2d_grid
x_dim = 2  # dimension of the spatial domain (coordinates per node)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
spath = Path('./saved_models/MINO_T_Cylinder')

spath.mkdir(parents=True, exist_ok=True)
saved_model = True  # save model checkpoints during training

# Seed the RNGs used below (model initialization, DataLoader shuffling and,
# presumably, the GP/noise sampling inside OFMModel) so that
# Restart Kernel -> Run All reproduces the same run.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# GP hyperparameters (prior used by OFMModel)
kernel_length = 0.01
kernel_variance = 1
nu = 0.5  # default; presumably the Matern smoothness parameter — confirm in util

# model hyperparameters
## hidden width shared by the timestep conditioner, encoder and decoder
dim = 256
num_heads = 4  # attention heads in encoder and decoder

## training parameters
epochs = 300
sigma_min = 1e-4
batch_size = 96

In [3]:
# Training fields: after the permute below, x_train is [n_samples, n_chan, n_seq]
# (so the array on disk is [n_samples, n_seq, n_chan]).
# pos_data: [n_samples, x_dim, n_seq]
x_train = np.load('../dataset/cylinder/x_train.npy')
x_train = torch.Tensor(x_train).permute(0, 2, 1)

# Normalized node coordinates shared by every sample: [n_seq, x_dim]
n_pos = np.load('../dataset/cylinder/pos_normalized.npy')
n_pos = torch.Tensor(n_pos)
# One copy of the coordinates per training sample: [n_samples, x_dim, n_seq]
pos_data = n_pos.unsqueeze(0).repeat(len(x_train), 1, 1).permute(0, 2, 1)

# Fail fast if the coordinate file disagrees with the configured domain
# dimension or with the number of nodes in the fields.
assert pos_data.shape[1] == x_dim, (
    f"expected {x_dim}-D coordinates, got n_pos of shape {tuple(n_pos.shape)}"
)
assert pos_data.shape[2] == x_train.shape[2], (
    f"node count mismatch: pos {pos_data.shape[2]} vs fields {x_train.shape[2]}"
)

query_pos = make_2d_grid(query_dims).permute(1, 0)  # [x_dim, prod(query_dims)], i.e. [2, 16x16]
train_dataset = SimDataset(x_train, pos_data, query_pos)

loader_tr = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=SimulationCollator,
)

## Model Initialization

In [4]:
# batch_size -> 

In [1]:
# Number of physical channels in each field (co-domain). The encoder reads and
# the decoder writes the same channels.
n_channels = 3

conditioner = ConditionerTimestep(
    dim=dim
)
# Sub-modules are constructed in the same order as before (conditioner,
# encoder, decoder), so parameter initialization consumes the RNG identically.
model = MINO(
    conditioner=conditioner,
    encoder=EncoderSupernodes(
        input_dim=n_channels,  # co-domain
        ndim=x_dim,  # dimension of domain, shared with the config cell
        radius=0.07,  # neighbourhood radius, presumably in normalized-coordinate units
        enc_dim=dim,
        enc_num_heads=num_heads,
        enc_depth=5,
        cond_dim=conditioner.cond_dim,
    ),
    decoder=DecoderPerceiver(
        input_dim=dim,
        output_dim=n_channels,
        ndim=x_dim,
        dim=dim,
        num_heads=num_heads,
        depth=2,  # 2 layers
        unbatch_mode="dense_to_sparse_unpadded",
        cond_dim=conditioner.cond_dim,
    ),
)
model = model.to(device)
print(f"parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

In [6]:
# AdamW with step decay: the learning rate is multiplied by 0.8 every 25
# scheduler steps (presumably one step per epoch inside fmot.train — confirm).
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=25, gamma=0.8)

# Flow-matching wrapper around the network, configured with the GP
# hyperparameters and node coordinates from the cells above.
fmot = OFMModel(
    model,
    kernel_length=kernel_length,
    kernel_variance=kernel_variance,
    nu=nu,
    sigma_min=sigma_min,
    device=device,
    x_dim=x_dim,
    n_pos=n_pos,
)


In [19]:
# Train for `epochs` epochs. eval_int=0 and generate=False presumably disable
# in-training evaluation/sampling (confirm in OFMModel.train).
# save_int is tied to `epochs` (instead of a hard-coded 300) so the final-epoch
# checkpoint is still written if `epochs` is changed in the config cell.
# Consider a smaller save_int if long runs may be interrupted.
fmot.train(
    loader_tr,
    optimizer,
    epochs=epochs,
    scheduler=scheduler,
    eval_int=0,
    save_int=epochs,
    generate=False,
    save_path=spath,
    saved_model=saved_model,
)

tr @ epoch 1/300 | Loss 0.705979 | 149.53 (s)
tr @ epoch 2/300 | Loss 0.439543 | 152.90 (s)
tr @ epoch 3/300 | Loss 0.275217 | 153.47 (s)
tr @ epoch 4/300 | Loss 0.231169 | 153.75 (s)
tr @ epoch 5/300 | Loss 0.209339 | 153.79 (s)
tr @ epoch 6/300 | Loss 0.189943 | 153.80 (s)
tr @ epoch 7/300 | Loss 0.164093 | 153.92 (s)
tr @ epoch 8/300 | Loss 0.136082 | 154.02 (s)
tr @ epoch 9/300 | Loss 0.120323 | 154.04 (s)
tr @ epoch 10/300 | Loss 0.111568 | 154.08 (s)
tr @ epoch 11/300 | Loss 0.105005 | 154.13 (s)
tr @ epoch 12/300 | Loss 0.100770 | 154.27 (s)
tr @ epoch 13/300 | Loss 0.096420 | 154.29 (s)
tr @ epoch 14/300 | Loss 0.091613 | 154.22 (s)
tr @ epoch 15/300 | Loss 0.089617 | 154.35 (s)
tr @ epoch 16/300 | Loss 0.087622 | 154.27 (s)
tr @ epoch 17/300 | Loss 0.084136 | 154.21 (s)
tr @ epoch 18/300 | Loss 0.083936 | 154.13 (s)
tr @ epoch 19/300 | Loss 0.081392 | 154.07 (s)
tr @ epoch 20/300 | Loss 0.081543 | 153.83 (s)
tr @ epoch 21/300 | Loss 0.076952 | 153.73 (s)
tr @ epoch 22/300 | Lo

tr @ epoch 174/300 | Loss 0.032361 | 153.67 (s)
tr @ epoch 175/300 | Loss 0.033135 | 153.67 (s)
tr @ epoch 176/300 | Loss 0.030696 | 153.62 (s)
tr @ epoch 177/300 | Loss 0.030979 | 153.60 (s)
tr @ epoch 178/300 | Loss 0.029967 | 153.63 (s)
tr @ epoch 179/300 | Loss 0.031005 | 153.66 (s)
tr @ epoch 180/300 | Loss 0.030272 | 153.61 (s)
tr @ epoch 181/300 | Loss 0.031102 | 153.65 (s)
tr @ epoch 182/300 | Loss 0.032430 | 153.89 (s)
tr @ epoch 183/300 | Loss 0.029687 | 154.14 (s)
tr @ epoch 184/300 | Loss 0.031529 | 154.21 (s)
tr @ epoch 185/300 | Loss 0.031119 | 154.14 (s)
tr @ epoch 186/300 | Loss 0.032615 | 154.09 (s)
tr @ epoch 187/300 | Loss 0.030944 | 154.10 (s)
tr @ epoch 188/300 | Loss 0.031199 | 154.06 (s)
tr @ epoch 189/300 | Loss 0.030461 | 154.06 (s)
tr @ epoch 190/300 | Loss 0.032551 | 153.95 (s)
tr @ epoch 191/300 | Loss 0.031372 | 154.00 (s)
tr @ epoch 192/300 | Loss 0.030019 | 153.98 (s)
tr @ epoch 193/300 | Loss 0.031050 | 154.01 (s)
tr @ epoch 194/300 | Loss 0.032434 | 154

## Evaluation

In [7]:
# Freeze the weights and switch to inference mode. Without .eval(), a model left
# in training mode would keep any dropout/normalization layers in train behaviour
# while sampling.
model.requires_grad_(False)
model.eval()

# Checkpoint from the final training epoch (named from `epochs`, not a literal 300).
model_path = os.path.join(spath, f'epoch_{epochs}.pt')
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# checkpoints you produced yourself.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
# strict=False tolerates key mismatches. Report them rather than ignoring them, so
# a checkpoint with an unexpected structure (which could load nothing) is noticed.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")

fmot = OFMModel(model, kernel_length=kernel_length, kernel_variance=kernel_variance, nu=nu, sigma_min=sigma_min, device=device, x_dim=x_dim, n_pos=n_pos)


In [8]:
def gen_meta_info(batch_size, query_dims, n_pos=n_pos):
    """Build the position-only batch dict used for sampling.

    Args:
        batch_size: number of stacked copies of each coordinate set.
        query_dims: grid resolution forwarded to ``make_2d_grid``.
        n_pos: node coordinates, assumed ``[n_seq, x_dim]`` as in the training
            cell. The default is the global ``n_pos`` bound when this cell runs.

    Returns:
        dict with ``'input_pos'`` of shape ``[batch_size, x_dim, n_seq]`` and
        ``'query_pos'`` of shape ``[batch_size, x_dim, n_query]``.
    """
    input_pos = n_pos.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)

    grid = make_2d_grid(query_dims)  # [n_query, x_dim]
    query_pos = grid.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)

    return {"input_pos": input_pos, "query_pos": query_pos}

In [None]:
# Start the wall-clock timer for the sampling cell that follows. It is read back
# with time.time() in the cell after sampling, so both cells must use the same clock.
# The measured span includes any idle time between running the cells.
start = time.time()

In [None]:
# 25 batches x 200 samples = 5000 generated fields, matching the 5000 test fields
# loaded in the Metrics section.
n_batches = 25
samples_per_batch = 200

with torch.no_grad():
    # The position metadata is identical for every batch, so build it and move it
    # to the device once. Distinct names avoid clobbering the training `query_pos`
    # grid defined in the data-loading cell.
    # NOTE(review): the metadata batch size (100) differs from n_samples (200);
    # confirm how fmot.sample uses the leading dimension of pos/query_pos.
    sample_meta = gen_meta_info(batch_size=100, n_pos=n_pos, query_dims=query_dims)
    sample_pos = sample_meta['input_pos'].to(device)
    sample_query_pos = sample_meta['query_pos'].to(device)

    X_alt = []
    for _ in range(n_batches):
        X_temp = fmot.sample(
            pos=sample_pos,
            query_pos=sample_query_pos,
            n_samples=samples_per_batch,
            n_channels=3,
            n_eval=2,
        ).cpu()
        X_alt.append(X_temp)

    X_alt = torch.vstack(X_alt).squeeze()

In [None]:
end = time.time()
# Wall-clock seconds since `start`. This includes any idle time between running
# the timing cells.
print(f"sampling wall time: {end - start:.2f} s")

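## Metrics

Compare the generated samples `X_alt` with the held-out test fields `x_test` using `swd_stable` and `unbiased_mmd2_torch`.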

In [None]:
# Held-out test fields, permuted to [n_samples, n_chan, n_seq] like x_train.
# Only the first n_test are kept, matching the 25 x 200 generated samples.
n_test = 5000
x_test = torch.Tensor(
    np.load('../dataset/cylinder/x_test.npy')[:n_test]
).permute(0, 2, 1)

In [None]:
# Generated and test samples must share the per-sample layout [n_chan, n_seq];
# otherwise the distance would silently compare mismatched coordinates.
assert X_alt.shape[1:] == x_test.shape[1:], (
    f"layout mismatch: generated {tuple(X_alt.shape)} vs test {tuple(x_test.shape)}"
)
swd_value = swd_stable(X=X_alt, Y=x_test)  # presumably sliced Wasserstein distance
swd_value

In [None]:
# Unbiased squared-MMD estimate between generated and test samples. The trailing
# expression displays the result in the notebook.
mmd_value = unbiased_mmd2_torch(X=X_alt, Y=x_test, device=device)
mmd_value