In [1]:
import os
import sys
sys.path.append(os.path.abspath("..")) 

from model import *

import math
import time
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from matcho import Unet2D
# from YourDataset import YourDataset  # Import your custom dataset here
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary
import torchprofile

import pickle

torch.manual_seed(23)

scaler = GradScaler()

DTYPE = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "serif"

import scipy.stats as stats

Using device: cuda


In [2]:
def compute_power(true, pred, inp):
    BS, nt, nx, ny = true.shape
    
    # Compute the Fourier transforms and amplitude squared for both true and pred
    fourier_true = torch.fft.fftn(true, dim=(-2, -1))
    fourier_pred = torch.fft.fftn(pred, dim=(-2, -1))
    fourier_inp  = torch.fft.fftn(inp , dim=(-2, -1))

    # Get the squared amplitudes
    amplitudes_true = torch.abs(fourier_true) #** 2
    amplitudes_pred = torch.abs(fourier_pred) #** 2
    amplitudes_inp = torch.abs(fourier_inp) #** 2

    # Create the k-frequency grids
    kfreq_y = torch.fft.fftfreq(ny) * ny
    kfreq_x = torch.fft.fftfreq(nx) * nx
    kfreq2D_x, kfreq2D_y = torch.meshgrid(kfreq_x, kfreq_y, indexing='ij')
    
    # Compute the wavenumber grid
    knrm = torch.sqrt(kfreq2D_x ** 2 + kfreq2D_y ** 2).to(true.device)
    
    # Define the bins for the wavenumber
    kbins = torch.arange(0.5, nx // 2 + 1, 1.0, device=true.device)
    
    # Digitize knrm to bin indices
    knrm_flat = knrm.flatten()
    bin_indices = torch.bucketize(knrm_flat, kbins)

    # Reshape and flatten the amplitudes
    amplitudes_true_flat = amplitudes_true.view(BS, nt, nx * ny)
    amplitudes_pred_flat = amplitudes_pred.view(BS, nt, nx * ny)
    amplitudes_inp_flat  = amplitudes_inp.view(BS, nt, nx * ny)

    # Initialize Abins
    Abins_true = torch.zeros((BS, nt, len(kbins) - 1), device=true.device)
    Abins_pred = torch.zeros((BS, nt, len(kbins) - 1), device=pred.device)
    Abins_inp  = torch.zeros((BS, nt, len(kbins) - 1), device= inp.device)

    # Vectorized binning: sum up the values in each bin
    for bin_idx in range(1, len(kbins)):
        mask = (bin_indices == bin_idx).unsqueeze(0).unsqueeze(0)  # Create a mask for each bin
        Abins_true[:, :, bin_idx - 1] = (amplitudes_true_flat * mask).sum(dim=-1) / mask.sum(dim=-1)
        Abins_pred[:, :, bin_idx - 1] = (amplitudes_pred_flat * mask).sum(dim=-1) / mask.sum(dim=-1)
        Abins_inp[:,  :, bin_idx - 1] = (amplitudes_inp_flat  * mask).sum(dim=-1) / mask.sum(dim=-1)

    # Scale the binned amplitudes
    scaling_factor = torch.pi * (kbins[1:] ** 2 - kbins[:-1] ** 2)
    Abins_true *= scaling_factor
    Abins_pred *= scaling_factor
    Abins_inp  *= scaling_factor

    return Abins_true, Abins_pred, Abins_inp

def plot_power_spectrum(power_inp, power_true, power_pred, inp, true, pred, epoch, err):
    f = 2
    fig, axes = plt.subplots(1, 4, figsize=(4*f, 1*f))
    # t_ls = np.arange(power_true.shape[1])
    # skip_t = 12
    # time_ls = t_ls[::skip_t][1:]

    sample_id=8
    t_id = 0
    for i in range(1):
        x = torch.arange(true.shape[-2]//2)
        axes[0].loglog(x, power_true[sample_id,t_id], label='true', c='black')
        axes[0].loglog(x, power_inp[sample_id,t_id], label='NO', c='blue')
        axes[0].loglog(x, power_pred[sample_id,t_id], label='adv. NO', c='red')
        # axes[i].set_title(f"t: {0}")
        axes[0].set_xlabel(r'$k$')
        if i==0:
            axes[0].legend()
        if i==0:
            axes[0].set_ylabel(r'$P(k)$')
    

    inp_sample = inp[sample_id, t_id]
    true_sample = true[sample_id, t_id]
    pred_sample = pred[sample_id, t_id]
    vmin, vmax = true_sample.min(), true_sample.max()
    im1 = axes[1].imshow(true_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[1].set_title("True")
    axes[1].set_xticks([])
    axes[1].set_yticks([])

    im = axes[2].imshow(inp_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[2].set_title("NO")
    axes[2].set_xticks([])
    axes[2].set_yticks([])

    im = axes[3].imshow(pred_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[3].set_title("NO+VAE")
    axes[3].set_xticks([])
    axes[3].set_yticks([])
    fig.colorbar(im1, ax=axes[3])
    plt.tight_layout()


    fig.suptitle(f"Epoch: {epoch}, MSE: {err:.2e}", fontsize=22, y=1.2)
    plt.savefig(f"power_spectrum/{epoch}.png", dpi=150, bbox_inches='tight')
    plt.close()

def error_metric(inp, pred,true, epoch, Par={}, is_plot=True):
    #re-normalize
    # true = true*Par['out_scale'] + Par['out_shift']
    # true = true*Par['out_scale'] + Par['out_shift']

    # inp = inp*Par['inp_scale'] + Par['inp_shift']
    # inp = (inp - Par['out_shift'])/Par['out_scale']

    power_inp, power_true, power_pred = compute_power(inp, true, pred)
    err = torch.mean( (torch.log(power_true)-torch.log(power_pred) )**2 )
    f_err = torch.norm(true-pred, p=2)/torch.norm(true, p=2)
    ref_err = torch.norm(true-inp, p=2)/torch.norm(true, p=2)
    if is_plot:
        plot_power_spectrum(power_inp.detach().cpu().numpy(), power_true.detach().cpu().numpy(), power_pred.detach().cpu().numpy(), inp.detach().cpu().numpy(), true.detach().cpu().numpy(), pred.detach().cpu().numpy(), epoch, err)
    return err, f_err, ref_err

######################## VO input #############################

# Custom Dataset
class SuperResDataset(Dataset):
    def __init__(self, input_data, target_data):
        self.input_data = torch.tensor(input_data, dtype=torch.float32)
        self.target_data = torch.tensor(target_data, dtype=torch.float32)

        # # ðŸ”¥ Ensure data is in (BS, 20, Nx, Ny) shape
        # if self.input_data.dim() == 5 and self.input_data.shape[1] == 1:
        #     self.input_data = self.input_data.squeeze(1)  # Remove singleton channel
        #     self.target_data = self.target_data.squeeze(1)  # Remove singleton channel

        print(f"input_data : {self.input_data.shape}")
        print(f"target_data: {self.target_data.shape}")

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        return self.input_data[idx], self.target_data[idx]



# Load Data from .npy Files
def load_data():
    train_inputs = np.load('../TRAIN_PRED.npy',allow_pickle=True)  # Shape: (BS, Nt, Nx, Ny)
    B,T,X,Y = train_inputs.shape
    train_inputs = train_inputs.reshape(-1,1,X,Y)

    train_targets = np.load('../TRAIN_TRUE.npy',allow_pickle=True)  # Shape: (BS, Nt, Nx, Ny)
    train_targets = train_targets.reshape(-1,1,X,Y)

    val_inputs = np.load('../VAL_PRED.npy',allow_pickle=True)  # Shape: (BS, Nt, Nx, Ny)
    val_inputs = val_inputs.reshape(-1,1,X,Y)

    val_targets = np.load('../VAL_TRUE.npy',allow_pickle=True) 
    val_targets = val_targets.reshape(-1,1,X,Y)

    test_inputs = np.load('../TEST_PRED.npy',allow_pickle=True)  # Shape: (BS, Nt, Nx, Ny)
    test_inputs = test_inputs.reshape(-1,1,X,Y)

    test_targets = np.load('../TEST_TRUE.npy',allow_pickle=True) 
    test_targets = test_targets.reshape(-1,1,X,Y)
    
    return train_inputs, train_targets, val_inputs, val_targets, test_inputs, test_targets


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
train_inputs, train_targets, val_inputs, val_targets, test_inputs, test_targets = load_data()

print(f"train_inputs : {train_inputs.shape}")
print(f"train_targets: {train_targets.shape}")
print(f"val_inputs   : {val_inputs.shape}")
print(f"val_targets  : {val_targets.shape}")
print(f"test_inputs  : {test_inputs.shape}")
print(f"test_targets : {test_targets.shape}")

os.makedirs("Params", exist_ok=True)
os.makedirs("power_spectrum", exist_ok=True)


# Create DataLoaders
batch_size = 20

print(f"Train Dataset prep")
train_dataset = SuperResDataset(train_inputs, train_targets)
print(f"Val Dataset prep")
val_dataset = SuperResDataset(val_inputs, val_targets)
print(f"Test Dataset prep")
test_dataset = SuperResDataset(test_inputs, test_targets)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train_inputs : (3980, 1, 128, 256)
train_targets: (3980, 1, 128, 256)
val_inputs   : (480, 1, 128, 256)
val_targets  : (480, 1, 128, 256)
test_inputs  : (480, 1, 128, 256)
test_targets : (480, 1, 128, 256)
Train Dataset prep
input_data : torch.Size([3980, 1, 128, 256])
target_data: torch.Size([3980, 1, 128, 256])
Val Dataset prep
input_data : torch.Size([480, 1, 128, 256])
target_data: torch.Size([480, 1, 128, 256])
Test Dataset prep
input_data : torch.Size([480, 1, 128, 256])
target_data: torch.Size([480, 1, 128, 256])


In [4]:
model = VQVAE(in_channels=1, hidden_channels=21, embedding_dim=64, num_embeddings=256, commitment_cost=0.25)
path_model = f'Params/best_model.pt'
model.load_state_dict(torch.load(path_model))


print(summary(model, input_size=(1,)+train_inputs.shape[1:] ) )

# Adjust the dimensions as per your model's input size
dummy_x = torch.tensor(train_inputs[0:1], dtype=DTYPE, device=device)
dummy_input = dummy_x

# Profile the model
model.eval()
flops = 2 * torchprofile.profile_macs(model, dummy_input)
print(f"FLOPs: {flops:.5e}")


Layer (type:depth-idx)                                  Output Shape              Param #
VQVAE                                                   [1, 1, 128, 256]          --
â”œâ”€Encoder: 1-1                                          [1, 21, 64, 128]          --
â”‚    â””â”€Conv2d: 2-1                                      [1, 21, 64, 128]          357
â”‚    â””â”€Sequential: 2-2                                  [1, 21, 64, 128]          --
â”‚    â”‚    â””â”€ResidualBlockCA: 3-1                        [1, 21, 64, 128]          8,044
â”‚    â”‚    â””â”€ResidualBlockCA: 3-2                        [1, 21, 64, 128]          8,044
â”‚    â””â”€Conv2d: 2-3                                      [1, 42, 32, 64]           14,154
â”‚    â””â”€Sequential: 2-4                                  [1, 42, 32, 64]           --
â”‚    â”‚    â””â”€ResidualBlockCA: 3-3                        [1, 42, 32, 64]           32,048
â”‚    â”‚    â””â”€ResidualBlockCA: 3-4                        [1, 42, 32, 64



In [5]:
from matcho import Unet2D

with open('../Par.pkl', 'rb') as f:
    Par_no = pickle.load(f)

no = Unet2D(dim=16, Par=Par_no, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)
path_model = '../models/best_model.pt'
no.load_state_dict(torch.load(path_model))

<All keys matched successfully>

# Inference time

In [6]:
test_data_loader = DataLoader(test_dataset, batch_size=1)

In [7]:
inference_time_ls = []

inp_x = torch.rand(size=(1,2,128,256), dtype=DTYPE, device=device)
inp_t = torch.rand(size=(1,), dtype=DTYPE, device=device)

no.eval()
model.eval()

for i in range(15):
    begin_time = time.time()
    with torch.no_grad():
        no_pred = no(inp_x, inp_t)
        no_vae  = model(no_pred)
        # for x, y_true in test_data_loader:
        #     y_pred = model(x.to(device))
        #     break

    end_time = time.time()
    inference_time = end_time - begin_time
    print(f"Inference time: {inference_time:.5f}")
    inference_time_ls.append(inference_time)

print()
print(f"mean: {np.mean(inference_time_ls[5:])}")

Inference time: 0.16827
Inference time: 0.01427
Inference time: 0.01396
Inference time: 0.01394
Inference time: 0.01398
Inference time: 0.01403
Inference time: 0.01393
Inference time: 0.01396
Inference time: 0.01398
Inference time: 0.01402
Inference time: 0.01392
Inference time: 0.01398
Inference time: 0.01391
Inference time: 0.01391
Inference time: 0.01396

mean: 0.013959908485412597


# PeakVRAM

In [8]:
torch.backends.cudnn.benchmark = False  # keep runs reproducible

no.eval()
model.eval()


# Warmup
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_vae  = model(no_pred)


torch.cuda.synchronize()

torch.cuda.reset_peak_memory_stats() 
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_vae  = model(no_pred)

torch.cuda.synchronize()


# ---- Read peaks (bytes) and report in GB ----
peak_alloc_GB   = torch.cuda.max_memory_allocated()  / 1e9
peak_resvd_GB   = torch.cuda.max_memory_reserved()   / 1e9
print(f"Peak VRAM (allocated): {peak_alloc_GB:.4f} GB")
print(f"Peak VRAM (reserved) : {peak_resvd_GB:.4f} GB")
print("Config: batch=1, dtype=", DTYPE, ", device=", device)

Peak VRAM (allocated): 0.2982 GB
Peak VRAM (reserved) : 0.4404 GB
Config: batch=1, dtype= torch.float32 , device= cuda


# Sanity Check

In [9]:
recon_criterion = nn.MSELoss()

y_true_ls = []
y_pred_ls = []

# Validation Step
model.eval()
total_val_loss = 0
spec_err = 0.0
field_err = 0.0
ref_err = 0.0
plot_flag = False
with torch.no_grad():
    for x_lr, x_hr in val_loader:
        x_lr, x_hr = x_lr.to(device), x_hr.to(device)
        outputs, vq_loss = model(x_lr)
        recon_loss = recon_criterion(outputs, x_hr)
        loss = recon_loss + vq_loss
        
        total_val_loss += loss.item()

        s_err, f_err, r_err = error_metric(x_lr, outputs, x_hr, 0, is_plot=plot_flag)
        spec_err += s_err.item()
        field_err += f_err.item()
        ref_err += r_err.item()
        plot_flag = False

        y_true_ls.append(x_hr.detach().cpu().numpy())
        y_pred_ls.append(outputs.detach().cpu().numpy())

avg_val_loss = total_val_loss / len(val_loader)
spec_err /= len(val_loader)
field_err /= len(val_loader)
ref_err /= len(val_loader)
print(f" Val Loss: {avg_val_loss:.4e}, spec err: {spec_err:.4e}, field err: {field_err:.4e}, ref err: {ref_err:.4e}" )

VAL_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)
VAL_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)

print(f"VAL_TRUE: {VAL_TRUE.shape}, DTYPE: {VAL_TRUE.dtype}")
print(f"VAL_PRED: {VAL_PRED.shape}, DTYPE: {VAL_PRED.dtype}")



y_true_ls = []
y_pred_ls = []

# Test Step
model.eval()
total_test_loss = 0
spec_err = 0.0
field_err = 0.0
ref_err = 0.0
plot_flag = False
with torch.no_grad():
    for x_lr, x_hr in test_loader:
        x_lr, x_hr = x_lr.to(device), x_hr.to(device)
        outputs, vq_loss = model(x_lr)
        recon_loss = recon_criterion(outputs, x_hr)
        loss = recon_loss + vq_loss
        
        total_test_loss += loss.item()

        s_err, f_err, r_err = error_metric(x_lr, outputs, x_hr, 0, is_plot=plot_flag)
        spec_err += s_err.item()
        field_err += f_err.item()
        ref_err += r_err.item()
        plot_flag = False

        y_true_ls.append(x_hr.detach().cpu().numpy())
        y_pred_ls.append(outputs.detach().cpu().numpy())

avg_test_loss = total_test_loss / len(test_loader)
spec_err /= len(test_loader)
field_err /= len(test_loader)
ref_err /= len(test_loader)
print(f" Test Loss: {avg_test_loss:.4e}, spec err: {spec_err:.4e}, field err: {field_err:.4e}, ref err: {ref_err:.4e}" )

TEST_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)
TEST_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)

print(f"TEST_TRUE: {TEST_TRUE.shape}, DTYPE: {TEST_TRUE.dtype}")
print(f"TEST_PRED: {TEST_PRED.shape}, DTYPE: {TEST_PRED.dtype}")

 Val Loss: 2.1493e-03, spec err: 2.7974e-01, field err: 1.4656e-01, ref err: 1.2629e-01
VAL_TRUE: (96, 5, 128, 256), DTYPE: float32
VAL_PRED: (96, 5, 128, 256), DTYPE: float32
 Test Loss: 2.2392e-03, spec err: 2.8998e-01, field err: 1.4820e-01, ref err: 1.2915e-01
TEST_TRUE: (96, 5, 128, 256), DTYPE: float32
TEST_PRED: (96, 5, 128, 256), DTYPE: float32


In [10]:
np.save("TEST_TRUE.npy", TEST_TRUE)
np.save("TEST_PRED.npy", TEST_PRED)