In [1]:
import os
import sys
sys.path.append(os.path.abspath("..")) 

from model import *

import math
import time
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from tcunet import Unet2D
# from YourDataset import YourDataset  # Import your custom dataset here
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary
import torchprofile

import pickle

torch.manual_seed(23)

scaler = GradScaler()

DTYPE = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "serif"

import scipy.stats as stats

Using device: cuda


In [2]:
def compute_power(true, pred, inp):
    BS, nt, nx, ny = true.shape
    
    # Compute the Fourier transforms and amplitude squared for both true and pred
    fourier_true = torch.fft.fftn(true, dim=(-2, -1))
    fourier_pred = torch.fft.fftn(pred, dim=(-2, -1))
    fourier_inp  = torch.fft.fftn(inp , dim=(-2, -1))

    # Get the squared amplitudes
    amplitudes_true = torch.abs(fourier_true) #** 2
    amplitudes_pred = torch.abs(fourier_pred) #** 2
    amplitudes_inp = torch.abs(fourier_inp) #** 2

    # Create the k-frequency grids
    kfreq_y = torch.fft.fftfreq(ny) * ny
    kfreq_x = torch.fft.fftfreq(nx) * nx
    kfreq2D_x, kfreq2D_y = torch.meshgrid(kfreq_x, kfreq_y, indexing='ij')
    
    # Compute the wavenumber grid
    knrm = torch.sqrt(kfreq2D_x ** 2 + kfreq2D_y ** 2).to(true.device)
    
    # Define the bins for the wavenumber
    kbins = torch.arange(0.5, nx // 2 + 1, 1.0, device=true.device)
    
    # Digitize knrm to bin indices
    knrm_flat = knrm.flatten()
    bin_indices = torch.bucketize(knrm_flat, kbins)

    # Reshape and flatten the amplitudes
    amplitudes_true_flat = amplitudes_true.view(BS, nt, nx * ny)
    amplitudes_pred_flat = amplitudes_pred.view(BS, nt, nx * ny)
    amplitudes_inp_flat  = amplitudes_inp.view(BS, nt, nx * ny)

    # Initialize Abins
    Abins_true = torch.zeros((BS, nt, len(kbins) - 1), device=true.device)
    Abins_pred = torch.zeros((BS, nt, len(kbins) - 1), device=pred.device)
    Abins_inp  = torch.zeros((BS, nt, len(kbins) - 1), device= inp.device)

    # Vectorized binning: sum up the values in each bin
    for bin_idx in range(1, len(kbins)):
        mask = (bin_indices == bin_idx).unsqueeze(0).unsqueeze(0)  # Create a mask for each bin
        Abins_true[:, :, bin_idx - 1] = (amplitudes_true_flat * mask).sum(dim=-1) / mask.sum(dim=-1)
        Abins_pred[:, :, bin_idx - 1] = (amplitudes_pred_flat * mask).sum(dim=-1) / mask.sum(dim=-1)
        Abins_inp[:,  :, bin_idx - 1] = (amplitudes_inp_flat  * mask).sum(dim=-1) / mask.sum(dim=-1)

    # Scale the binned amplitudes
    scaling_factor = torch.pi * (kbins[1:] ** 2 - kbins[:-1] ** 2)
    Abins_true *= scaling_factor
    Abins_pred *= scaling_factor
    Abins_inp  *= scaling_factor

    return Abins_true, Abins_pred, Abins_inp

def plot_power_spectrum(power_inp, power_true, power_pred, inp, true, pred, epoch, err):
    f = 2
    fig, axes = plt.subplots(1, 4, figsize=(4*f, 1*f))

    sample_id=8
    t_id = 0
    for i in range(1):
        x = torch.arange(true.shape[-2]//2)
        axes[0].loglog(x, power_true[sample_id,t_id], label='true', c='black')
        axes[0].loglog(x, power_inp[sample_id,t_id], label='NO', c='blue')
        axes[0].loglog(x, power_pred[sample_id,t_id], label='adv. NO', c='red')
        # axes[i].set_title(f"t: {0}")
        axes[0].set_xlabel(r'$k$')
        if i==0:
            axes[0].legend()
        if i==0:
            axes[0].set_ylabel(r'$P(k)$')
    

    inp_sample = inp[sample_id, t_id]
    true_sample = true[sample_id, t_id]
    pred_sample = pred[sample_id, t_id]
    vmin, vmax = true_sample.min(), true_sample.max()
    im1 = axes[1].imshow(true_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[1].set_title("True")
    axes[1].set_xticks([])
    axes[1].set_yticks([])

    im = axes[2].imshow(inp_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[2].set_title("NO")
    axes[2].set_xticks([])
    axes[2].set_yticks([])

    im = axes[3].imshow(pred_sample, vmin=vmin, vmax=vmax, cmap=CMAP)
    axes[3].set_title("adv NO")
    axes[3].set_xticks([])
    axes[3].set_yticks([])
    fig.colorbar(im1, ax=axes[3])
    plt.tight_layout()


    fig.suptitle(f"Epoch: {epoch}, MSE: {err:.2e}", fontsize=22, y=1.2)
    plt.savefig(f"power_spectrum/{epoch}.png", dpi=150, bbox_inches='tight')
    plt.close()

def error_metric(inp, pred,true, epoch, Par, is_plot=True):
    power_inp, power_true, power_pred = compute_power(inp, true, pred)
    err = torch.mean( (torch.log(power_true)-torch.log(power_pred) )**2 )
    f_err = torch.norm(true-pred, p=2)/torch.norm(true, p=2)
    ref_err = torch.norm(true-inp, p=2)/torch.norm(true, p=2)
    if is_plot:
        plot_power_spectrum(power_inp.detach().cpu().numpy(), power_true.detach().cpu().numpy(), power_pred.detach().cpu().numpy(), inp.detach().cpu().numpy(), true.detach().cpu().numpy(), pred.detach().cpu().numpy(), epoch, err)
    return err, f_err, ref_err


# Define your custom loss function here
class CustomLoss(nn.Module):
    def __init__(self, Par):
        super(CustomLoss, self).__init__()
        self.Par = Par

    def forward(self, y_pred, y_true):
        y_true = (y_true - self.Par["out_shift"])/self.Par["out_scale"]
        y_pred = (y_pred - self.Par["out_shift"])/self.Par["out_scale"]
        loss = torch.norm(y_true-y_pred, p=2)/torch.norm(y_true, p=2)
        return loss

class YourDataset(Dataset):
    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x_sample = self.x[idx]
        y_sample = self.y[idx]

        if self.transform:
            x_sample, y_sample = self.transform(x_sample, y_sample)

        return x_sample, y_sample


def preprocess(x,y):
    # x,y - [bs, nt, nx, ny]

    B,T,X,Y = x.shape
    x = x.reshape(-1,1,X,Y)
    y = y.reshape(-1,1,X,Y)

    print(f"x: {x.shape}")
    print(f"y: {y.shape}")

    return x,y


In [3]:
temp="TRAIN"
x_train = np.load(f"../{temp}_PRED.npy")
y_train = np.load(f"../{temp}_TRUE.npy")

temp="VAL"
x_val = np.load(f"../{temp}_PRED.npy")
y_val = np.load(f"../{temp}_TRUE.npy")

temp="TEST"
x_test = np.load(f"../{temp}_PRED.npy")
y_test = np.load(f"../{temp}_TRUE.npy")

inp_min = np.min(x_train[:,0])
inp_max = np.max(x_train[:,0])
out_min = np.min(y_train[:,0])
out_max = np.max(y_train[:,0])

print("Train")
x_train, y_train = preprocess(x_train, y_train)
print("Val")
x_val, y_val = preprocess(x_val, y_val)
print("Test")
x_test, y_test = preprocess(x_test, y_test)

Par = {}
Par["inp_shift"] = inp_min
Par["inp_scale"] = inp_max - inp_min
Par["out_shift"] = out_min
Par["out_scale"] = out_max - out_min
Par["nf"] = x_train.shape[1]

# Create custom datasets
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)

x_val_tensor   = torch.tensor(x_val,   dtype=torch.float32)
y_val_tensor   = torch.tensor(y_val,   dtype=torch.float32)

x_test_tensor  = torch.tensor(x_test,  dtype=torch.float32)
y_test_tensor  = torch.tensor(y_test,  dtype=torch.float32)

train_dataset = YourDataset(x_train_tensor, y_train_tensor)
val_dataset = YourDataset(x_val_tensor, y_val_tensor)
test_dataset = YourDataset(x_test_tensor, y_test_tensor)

# Define data loaders
train_batch_size = 50
val_batch_size   = 50
test_batch_size  = 50
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)

print(Par)

Train
x: (380, 1, 128, 256)
y: (380, 1, 128, 256)
Val
x: (30, 1, 128, 256)
y: (30, 1, 128, 256)
Test
x: (30, 1, 128, 256)
y: (30, 1, 128, 256)
{'inp_shift': 0.031219482, 'inp_scale': 0.65286255, 'out_shift': 0.05078125, 'out_scale': 0.72265625, 'nf': 1}


In [4]:
model = GeneratorRRDB(Par["nf"], Par).to(device).to(torch.float32) #Unet2D(dim=16, Par=Par, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)

path_model = f'saved_models/best_model.pth'
model.load_state_dict(torch.load(path_model))

print(summary(model, input_size=(1,)+x_train.shape[1:] ) )

# Adjust the dimensions as per your model's input size
dummy_x = x_train_tensor[0:1].to(device)
# dummy_t = t_train_tensor[0:1].to(device)
dummy_input = dummy_x

# Profile the model
model.eval()
flops = 2 * torchprofile.profile_macs(model, dummy_input)
print(f"FLOPs: {flops:.5e}")

# Define loss function and optimizer
criterion = CustomLoss(Par)

Layer (type:depth-idx)                             Output Shape              Param #
GeneratorRRDB                                      [1, 1, 128, 256]          --
├─Conv2d: 1-1                                      [1, 32, 128, 256]         320
├─Sequential: 1-2                                  [1, 32, 128, 256]         --
│    └─ResidualInResidualDenseBlock: 2-1           [1, 32, 128, 256]         --
│    │    └─Sequential: 3-1                        [1, 32, 128, 256]         415,200
│    └─ResidualInResidualDenseBlock: 2-2           [1, 32, 128, 256]         --
│    │    └─Sequential: 3-2                        [1, 32, 128, 256]         415,200
│    └─ResidualInResidualDenseBlock: 2-3           [1, 32, 128, 256]         --
│    │    └─Sequential: 3-3                        [1, 32, 128, 256]         415,200
│    └─ResidualInResidualDenseBlock: 2-4           [1, 32, 128, 256]         --
│    │    └─Sequential: 3-4                        [1, 32, 128, 256]         415,200
├─Conv2d: 1-3 

In [5]:
from tcunet import Unet2D

with open('../Par.pkl', 'rb') as f:
    Par_no = pickle.load(f)

no = Unet2D(dim=16, Par=Par_no, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)
path_model = '../models/best_model.pt'
no.load_state_dict(torch.load(path_model))

<All keys matched successfully>

# Inference time

In [6]:
test_data_loader = DataLoader(test_dataset, batch_size=1)

In [7]:
inference_time_ls = []

inp_x = torch.rand(size=(1,2,128,256), dtype=DTYPE, device=device)
inp_t = torch.rand(size=(1,), dtype=DTYPE, device=device)

no.eval()
model.eval()

for i in range(15):
    begin_time = time.time()
    with torch.no_grad():
        no_pred = no(inp_x, inp_t)
        no_gan  = model(no_pred)
        # for x, y_true in test_data_loader:
        #     y_pred = model(x.to(device))
        #     break

    end_time = time.time()
    inference_time = end_time - begin_time
    print(f"Inference time: {inference_time:.5f}")
    inference_time_ls.append(inference_time)

print()
print(f"mean: {np.mean(inference_time_ls[5:])}")

Inference time: 0.21941
Inference time: 0.01355
Inference time: 0.01404
Inference time: 0.01407
Inference time: 0.01419
Inference time: 0.01408
Inference time: 0.01414
Inference time: 0.01407
Inference time: 0.01413
Inference time: 0.01407
Inference time: 0.01417
Inference time: 0.01406
Inference time: 0.01410
Inference time: 0.01409
Inference time: 0.01417

mean: 0.014107966423034668


# PeakVRAM

In [8]:
torch.backends.cudnn.benchmark = False  # keep runs reproducible

no.eval()
model.eval()


# Warmup
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_gan  = model(no_pred)


torch.cuda.synchronize()

torch.cuda.reset_peak_memory_stats() 
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_gan  = model(no_pred)

torch.cuda.synchronize()


# ---- Read peaks (bytes) and report in GB ----
peak_alloc_GB   = torch.cuda.max_memory_allocated()  / 1e9
peak_resvd_GB   = torch.cuda.max_memory_reserved()   / 1e9
print(f"Peak VRAM (allocated): {peak_alloc_GB:.4f} GB")
print(f"Peak VRAM (reserved) : {peak_resvd_GB:.4f} GB")
print("Config: batch=1, dtype=", DTYPE, ", device=", device)

Peak VRAM (allocated): 0.2964 GB
Peak VRAM (reserved) : 1.3044 GB
Config: batch=1, dtype= torch.float32 , device= cuda


# Sanity Check

In [9]:
y_true_ls = []
y_pred_ls = []

model.eval()
val_loss = 0.0
spec_err = 0.0
field_err = 0.0
ref_err = 0.0
plot_flag = False
with torch.no_grad():
    for x, y_true in val_loader:
        if True:
            y_pred = model(x.to(device))
            loss   = criterion(y_pred, y_true.to(device))
        s_err, f_err, r_err = error_metric(x.to(device), y_pred, y_true.to(device), 0, Par, plot_flag)
        val_loss += loss
        spec_err += s_err.item()
        field_err += f_err.item()
        ref_err += r_err.item()
        y_true_ls.append(y_true.detach().cpu().numpy())
        y_pred_ls.append(y_pred.detach().cpu().numpy())

val_loss /= len(val_loader)
spec_err /= len(val_loader)
field_err /= len(val_loader)
ref_err /= len(val_loader)
print(f"Val Loss: {val_loss:.4e}")
print(f" Val Loss: {val_loss:.4e}, spec err: {spec_err:.4e}, field err: {field_err:.4e}, ref err: {ref_err:.4e}" )

VAL_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)
VAL_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)

print(f"VAL_TRUE: {VAL_TRUE.shape}, DTYPE: {VAL_TRUE.dtype}")
print(f"VAL_PRED: {VAL_PRED.shape}, DTYPE: {VAL_PRED.dtype}")



y_true_ls = []
y_pred_ls = []

model.eval()
test_loss = 0.0
with torch.no_grad():
    for x, y_true in test_loader:
        if True:
            y_pred = model(x.to(device) )
            loss   = criterion(y_pred, y_true.to(device))
        test_loss += loss.item()
        y_true_ls.append(y_true.detach().cpu().numpy())
        y_pred_ls.append(y_pred.detach().cpu().numpy())

test_loss /= len(test_loader)
print(f"Test Loss: {test_loss:.4e}")

TEST_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)
TEST_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, 5, 128, 256).astype(np.float32)

print(f"TEST_TRUE: {TEST_TRUE.shape}, DTYPE: {TEST_TRUE.dtype}")
print(f"TEST_PRED: {TEST_PRED.shape}, DTYPE: {TEST_PRED.dtype}")

Val Loss: 1.8116e-01
 Val Loss: 1.8116e-01, spec err: 1.7461e-02, field err: 1.5506e-01, ref err: 1.2854e-01
VAL_TRUE: (6, 5, 128, 256), DTYPE: float32
VAL_PRED: (6, 5, 128, 256), DTYPE: float32
Test Loss: 1.7635e-01
TEST_TRUE: (6, 5, 128, 256), DTYPE: float32
TEST_PRED: (6, 5, 128, 256), DTYPE: float32


In [22]:
np.save("TEST_TRUE.npy", TEST_TRUE)
np.save("TEST_PRED.npy", TEST_PRED)