In [2]:
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

## load path
import os
import sys
sys.path.append('../')
sys.path.append('../models')
from pathlib import Path

## load utils 
from util.util import *

# we provide two realizations of GP on sphere, the commented one is slower
#from util.ofm_OT_likelihood_sphere_seq_mino import * 

from util.ofm_OT_likelihood_seq_mino import *
import statsmodels.api as sm
from scipy.stats import binned_statistic
import matplotlib.tri as tri
from util.metrics import *
import time

## load modules 
from models.mino_unet import MINO
from models.mino_modules.decoder_perceiver import DecoderPerceiver
from models.mino_modules.encoder_supernodes_gno_cross_attention import EncoderSupernodes
from models.mino_modules.conditioner_timestep import ConditionerTimestep

from models.mino_modules.modules.unet_nD import UNetModelWrapper

In [3]:
# Latent grid the UNet processor works on; presumably must match the number of latent
# query points built by latent_query_sphere (32 * 16 = 512).
query_dims = [32, 16]
x_dim = 3  # presumably the coordinate dimension: positions are 3-D points on the unit sphere

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Checkpoint directory (created if missing).
spath = Path('./saved_models/MINO_U_Climate')
spath.mkdir(parents=True, exist_ok=True)
saved_model = True  # save model checkpoints during training

# Seed the global RNGs so parameter initialisation, DataLoader shuffling and sampling are
# reproducible under Restart & Run All (assumes OFMModel draws from the torch/numpy
# global RNGs — confirm).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# GP hyperparameters
kernel_length = 0.05  # smaller values (0.02 or 0.01) were also considered
kernel_variance = 1
nu = 0.5  # default

# Model hyperparameters
dim = 256  # supernode embedding width
num_heads = 4
unet_dims = (dim, *query_dims)
unet_channels = 96
num_res_blocks = 1
attention_res = '4'

# Training parameters
epochs = 480
sigma_min = 1e-4
batch_size = 48

In [4]:
_NUM_LONGS = 90
_NUM_LATS = 46

# Regular (longitude, colatitude) grid on the unit sphere. np.mgrid with a complex step
# includes both endpoints, so one extra longitude is requested and then dropped:
# 2*pi would duplicate the 0 meridian. `lats` runs from 0 to pi, measured from +z.
longs, lats = np.mgrid[0:2 * np.pi:(_NUM_LONGS + 1) * 1j, 0:np.pi:_NUM_LATS * 1j]
longs = longs[:-1, :]
lats = lats[:-1, :]

# Spherical -> Cartesian coordinates.
other_points_xs = np.sin(lats) * np.cos(longs)
other_points_ys = np.sin(lats) * np.sin(longs)
other_points_zs = np.cos(lats)

# One row per grid point, longitude-major (colatitude varies fastest): (90 * 46, 3).
other_points = np.stack(
    [other_points_xs.ravel(), other_points_ys.ravel(), other_points_zs.ravel()],
    axis=1,
)

n_pos = other_points


In [5]:
def latent_query_sphere(num_longs, num_lats, pole_trim=2):
    """Build Cartesian query points on a regular longitude/colatitude grid of the unit sphere.

    Longitudes are ``num_longs`` values in [0, 2*pi); the duplicate 2*pi meridian is dropped.
    Colatitudes are ``num_lats`` values in [0, pi]. ``pole_trim`` rows are then removed at
    each pole.

    Args:
        num_longs: number of longitudes kept.
        num_lats: number of colatitude samples before trimming.
        pole_trim: colatitude rows dropped at each end. The default of 2 is the value the
            function previously hard-coded. ``0`` keeps the poles.

    Returns:
        np.ndarray of shape (num_longs * max(num_lats - 2 * pole_trim, 0), 3), ordered
        longitude-major (colatitude varies fastest).

    Raises:
        ValueError: if ``pole_trim`` is negative.

    Side effect: prints the shape of the trimmed longitude grid.
    """
    if pole_trim < 0:
        raise ValueError(f"pole_trim must be >= 0, got {pole_trim}")

    longs, lats = np.mgrid[0:2 * np.pi:(num_longs + 1) * 1j, 0:np.pi:num_lats * 1j]

    # An explicit end index is used because `[t:-t]` would be empty for t == 0.
    lat_slice = slice(pole_trim, max(num_lats - pole_trim, 0))
    longs, lats = longs[:-1, lat_slice], lats[:-1, lat_slice]

    print("longs:{}".format(longs.shape))
    other_points_xs = np.sin(lats) * np.cos(longs)
    other_points_ys = np.sin(lats) * np.sin(longs)
    other_points_zs = np.cos(lats)

    query_pos = np.stack(
        [other_points_xs.ravel(), other_points_ys.ravel(), other_points_zs.ravel()],
        axis=1,
    )
    return query_pos

In [6]:
# Latent query points: 32 longitudes x (20 - 2*2) colatitudes = 512 points (the function
# prints longs:(32, 16)). This presumably has to match query_dims from the config cell,
# since unet_dims = (dim, *query_dims). Fail fast if the two drift apart.
query_pos_input = latent_query_sphere(num_longs=32, num_lats=20)
assert query_pos_input.shape[0] == int(np.prod(query_dims)), (
    f"{query_pos_input.shape[0]} latent query points, "
    f"but query_dims={query_dims} implies {int(np.prod(query_dims))}"
)

longs:(32, 16)


In [7]:
# Training data: keep only channel 2 as the modelled field. The original note reads
# "0:2 (longitude, latitude)", so channels 0-1 presumably hold coordinates — confirm.
x_train = torch.as_tensor(
    np.load('../dataset/weather/train_climate.npy')[:, 2:3], dtype=torch.float32
)
# Swap the two grid axes, then flatten them into one point axis: (N, 1, n_points).
# Presumably this makes the flattening longitude-major so that it matches the ordering
# of n_pos (raw layout assumed (N, C, lat, lon) — TODO confirm).
x_train = torch.flatten(x_train.permute(0, 1, 3, 2), start_dim=2)

# Convert positions to float32 tensors. Both conversions are safe to re-run: as_tensor
# is a no-op on an existing float32 tensor, and the isinstance guard stops a second
# run from transposing query_pos_input back.
n_pos = torch.as_tensor(n_pos, dtype=torch.float32)
if isinstance(query_pos_input, np.ndarray):
    # (n_query, 3) -> (3, n_query)
    query_pos_input = torch.as_tensor(query_pos_input, dtype=torch.float32).permute(1, 0)

# Every sample shares the same input grid: (N, 3, n_points).
pos_data = n_pos.unsqueeze(0).repeat(len(x_train), 1, 1).permute(0, 2, 1)

train_dataset = SimDataset(x_train, pos_data, query_pos_input)

loader_tr = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=SimulationCollator,
)

## Model Initialization

In [1]:
# Assemble MINO from its four parts. Judging by the module names, these are a timestep
# conditioner, a supernode GNO/cross-attention encoder, a UNet processor on the latent
# grid, and a perceiver decoder.
# The parts are built in the same order as before (conditioner, encoder, processor,
# decoder), so random parameter initialisation consumes the RNG identically.
conditioner = ConditionerTimestep(dim=dim)

encoder = EncoderSupernodes(
    input_dim=1,
    ndim=3,
    radius=0.2,
    enc_dim=dim,
    enc_num_heads=num_heads,
    enc_depth=2,
    cond_dim=conditioner.cond_dim,
)

processor = UNetModelWrapper(
    dim=unet_dims,
    num_channels=unet_channels,
    num_res_blocks=num_res_blocks,
    num_heads=num_heads,
    set_cond=False,
    attention_resolutions=attention_res,
)

decoder = DecoderPerceiver(
    input_dim=dim,
    output_dim=1,
    ndim=3,
    dim=dim,
    num_heads=num_heads,
    depth=2,
    unbatch_mode="dense_to_sparse_unpadded",
    cond_dim=conditioner.cond_dim,
)

model = MINO(
    conditioner=conditioner,
    encoder=encoder,
    processor=processor,
    decoder=decoder,
).to(device)

In [9]:
# AdamW plus a step-decay schedule: every 25 scheduler.step() calls, the learning rate is
# multiplied by 0.8. How often fmot.train steps the scheduler is not visible here.
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=25)

# OFM trainer/sampler wrapping the network, configured with the GP hyperparameters from
# the config cell and the fixed input grid n_pos.
fmot = OFMModel(
    model,
    kernel_length=kernel_length,
    kernel_variance=kernel_variance,
    nu=nu,
    sigma_min=sigma_min,
    device=device,
    x_dim=x_dim,
    n_pos=n_pos,
)

In [2]:
# Train for `epochs` epochs. eval_int=0 presumably disables periodic evaluation.
# save_int=epochs ties the checkpoint interval to the epoch count, so the final epoch is
# saved even if `epochs` changes (it was hard-coded to 480). The evaluation section
# reloads that final checkpoint.
fmot.train(
    loader_tr,
    optimizer,
    epochs=epochs,
    scheduler=scheduler,
    eval_int=0,
    save_int=epochs,
    generate=False,
    save_path=spath,
    saved_model=saved_model,
)

## Evaluation

In [11]:
# Freeze all parameters: this section only runs inference.
model.requires_grad_(False)
# Put dropout/normalisation layers (if any) into inference mode before sampling.
model.eval()

# Final-epoch checkpoint; presumably named epoch_<n>.pt by fmot.train — confirm.
model_path = os.path.join(spath, f'epoch_{epochs}.pt')
# weights_only=False unpickles arbitrary Python objects: only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

# With strict=False, key mismatches do not raise. Report them here instead of silently
# evaluating a partially (or entirely) un-loaded model.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys and len(load_result.missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"no parameter in {model_path} matched the model; check the checkpoint format"
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict: {len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

fmot = OFMModel(
    model,
    kernel_length=kernel_length,
    kernel_variance=kernel_variance,
    nu=nu,
    sigma_min=sigma_min,
    device=device,
    x_dim=x_dim,
    n_pos=n_pos,
)


In [15]:

# Alias of the latent query positions. After the data-loading cell, this is the
# transposed (3, n_query) tensor.
# NOTE(review): query_pos_latent is not referenced anywhere else in the visible notebook;
# it is a candidate for removal.
query_pos_latent = query_pos_input

In [19]:
def gen_meta_info(batch_size, query_pos, n_pos):
    """Build a model input batch in which every sample shares the same positions.

    Args:
        batch_size: number of copies stacked along a new leading axis.
        query_pos: 2-D tensor of latent query positions, copied unchanged for each sample
            (the notebook passes it as (3, n_query)).
        n_pos: 2-D tensor of input point positions, transposed for each sample
            (the notebook passes it as (n_points, 3)).

    Returns:
        dict with "input_pos" of shape (batch_size, n_pos.shape[1], n_pos.shape[0]) and
        "query_pos" of shape (batch_size, *query_pos.shape). Both are real copies made
        with repeat(), not broadcast views.
    """
    stacked_inputs = n_pos.unsqueeze(0).repeat(batch_size, 1, 1)
    stacked_queries = query_pos.unsqueeze(0).repeat(batch_size, 1, 1)
    return {
        "input_pos": stacked_inputs.transpose(1, 2),
        "query_pos": stacked_queries,
    }

In [None]:
# Draw n_gen_batches * gen_batch_size unconditional samples for the metrics below.
# NOTE(review): 26 * 100 = 2600 samples, presumably chosen to match len(x_test) — confirm.
n_gen_batches = 26
gen_batch_size = 100  # used both as the batch size of the position tensors and as n_samples
n_eval_steps = 5  # forwarded as n_eval to fmot.sample (presumably solver evaluation steps)

start = time.perf_counter()

with torch.no_grad():
    # The positional inputs are identical for every batch, so build them and move them
    # to the device once instead of on every iteration.
    collated_batch = gen_meta_info(batch_size=gen_batch_size, n_pos=n_pos, query_pos=query_pos_input)
    pos = collated_batch['input_pos'].to(device)
    query_pos = collated_batch['query_pos'].to(device)

    sample_batches = []
    for _ in range(n_gen_batches):
        X_temp = fmot.sample(
            pos=pos,
            query_pos=query_pos,
            n_samples=gen_batch_size,
            n_channels=1,
            n_eval=n_eval_steps,
        ).cpu()
        sample_batches.append(X_temp)

    # Stack the batches and squeeze out singleton axes (e.g. the single channel).
    X_alt = torch.vstack(sample_batches).squeeze()

end = time.perf_counter()
print(f"sampling time: {end - start:.2f} s")

## Metrics

In [None]:
# Held-out data, preprocessed the same way as the training set: keep channel 2 and flatten
# the two grid axes into one point axis, giving (N_test, 1, n_points).
x_test = torch.as_tensor(
    np.load('../dataset/weather/test_climate.npy')[:, 2:3], dtype=torch.float32
)
# Swap the last two axes before flattening. The original note says the resulting order is
# (longitude, latitude), presumably to match the longitude-major ordering of n_pos.
x_test = torch.flatten(x_test.permute(0, 1, 3, 2), start_dim=2)

# Generated samples are produced on the n_pos grid, so the test points must line up with it.
assert x_test.shape[-1] == n_pos.shape[0], (
    f"test grid has {x_test.shape[-1]} points but n_pos has {n_pos.shape[0]}"
)


In [None]:
# Presumably the sliced Wasserstein distance (util.metrics.swd_stable) between the
# generated samples and the test set. Lower is better.
# NOTE(review): X_alt is (N, n_points) after squeeze(), while x_test is (N, 1, n_points);
# this presumably relies on swd_stable flattening per-sample features — confirm.
swd_value = swd_stable(X=X_alt, Y=x_test)
swd_value

In [None]:
# Presumably an unbiased squared-MMD estimate (util.metrics.unbiased_mmd2_torch) between
# the generated samples and the test set. Lower is better.
# NOTE(review): same X_alt (N, n_points) vs x_test (N, 1, n_points) shape difference as
# for SWD — confirm the metric flattens its inputs.
mmd_value = unbiased_mmd2_torch(X=X_alt, Y=x_test, device=device)
mmd_value