In [1]:
import os
import sys

import math
import time
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
# from YourDataset import YourDataset  # Import your custom dataset here
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary
import torchprofile

from utils.architecture import Unet
from utils.diffusion import ElucidatedDiffusion

# Seed torch's global random number generator with a fixed value.
torch.manual_seed(23)
import pickle

# Floating-point dtype used for the normalization constants and the dummy inputs created below.
DTYPE = torch.float32

# NOTE(review): no visible cell uses `scaler`; it looks like a leftover from a
# training script — confirm before removing.
scaler = GradScaler()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

import matplotlib.pyplot as plt
import matplotlib
# Global figure defaults: 200 dpi rendering and a serif font family.
matplotlib.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "serif"

import scipy.stats as stats
from sklearn.decomposition import TruncatedSVD

Using device: cuda


In [2]:
def error_metric(pred, true, Par):
    """Return the relative L2 error ``||true - pred||_2 / ||true||_2``.

    Both norms are taken over every element of the tensors, so the result is
    a single scalar tensor. ``Par`` is accepted for call-site compatibility
    but is not read: the error is measured on the values exactly as they are
    passed in, without any re-normalization.
    """
    residual = true - pred
    numerator = torch.norm(residual, p=2)
    denominator = torch.norm(true, p=2)
    return numerator / denominator

class MyDataset(Dataset):
    """Map-style dataset that pairs ``x[idx]`` with ``y[idx]``.

    If a truthy ``transform`` is supplied it is called as
    ``transform(x_sample, y_sample)`` and must return a two-element result,
    which replaces the raw pair.
    """

    def __init__(self, x, y, transform=None):
        self.x, self.y = x, y
        self.transform = transform

    def __len__(self):
        # The length follows the inputs; ``y`` is not consulted.
        return len(self.x)

    def __getitem__(self, idx):
        pair = (self.x[idx], self.y[idx])
        if not self.transform:
            return pair

        # Unpack explicitly so the result is always returned as a 2-tuple.
        first, second = self.transform(*pair)
        return first, second
    
def preprocess(x, y, Par):
    """Cut 4-axis arrays into windows along axis 1 and flatten them into samples.

    For both ``x`` and ``y`` the first ``Par['lb'] - 1`` entries of axis 1 are
    dropped, windows of length ``Par['lf']`` are taken along that axis, and
    the leading axis and the window index are merged, giving arrays of shape
    ``(-1, Par['lf'], Par['nx'], Par['ny'])``. The resulting shapes are
    printed before the pair is returned.
    """
    window_len = Par['lf']
    first_step = Par['lb'] - 1

    def to_samples(arr):
        # sliding_window_view appends the window axis last; move it next to
        # the window index before merging the two leading axes.
        windows = sliding_window_view(arr[:, first_step:, :, :], window_shape=window_len, axis=1)
        windows = windows.transpose(0, 1, 4, 2, 3)
        return windows.reshape(-1, window_len, Par['nx'], Par['ny'])

    x = to_samples(x)
    y = to_samples(y)

    print('x: ', x.shape)
    print('y: ', y.shape)
    print()
    return x, y


In [3]:
# Data pipeline: load the pre-split arrays, min-max normalize them with
# training-set statistics, window them into samples, and wrap them in
# datasets / data loaders.

# Only referenced by the commented-out legacy paths below.
res = 128
begin_time = time.time()
# Legacy loading path, kept for reference: one low-/high-fidelity pair that
# was split into train/val/test by the commented-out block further down.
# inp = np.load(f"/oscar/data/gk/voommen/no_diffusion/kolmogrov/res_{res}/matcho/Y_PRED.npy") #low-fidelity
# out = np.load(f"/oscar/data/gk/voommen/no_diffusion/kolmogrov/res_{res}/matcho/Y_TRUE.npy") #high-fidelity
# The *_PRED files become the inputs (x) and the *_TRUE files the targets (y).
x_train = np.load("../TRAIN_PRED.npy")
y_train = np.load("../TRAIN_TRUE.npy")

x_val = np.load("../VAL_PRED.npy")
y_val = np.load("../VAL_TRUE.npy")

x_test = np.load("../TEST_PRED.npy")
y_test = np.load("../TEST_TRUE.npy")
print(f"Data Loading Time: {time.time() - begin_time:.1f}s")



# Legacy 80/10/10 split of the single `inp` / `out` pair (disabled; the
# arrays above are already split).
# # Train-Val-Test Split
# idx1 = int(0.8 * inp.shape[0])
# idx2 = int(0.9 * inp.shape[0])

# x_train = inp[:idx1]
# x_val   = inp[idx1:idx2]
# x_test  = inp[idx2:]

# y_train = out[:idx1]
# y_val   = out[idx1:idx2]
# y_test  = out[idx2:]



# Normalization statistics come from the training split only.
inp_min = np.min(x_train)
inp_max = np.max(x_train)
out_min = np.min(y_train)
out_max = np.max(y_train)



# Shared parameter dictionary. The shift/scale entries are tensors on `device`;
# nx / ny are taken from axes 2 and 3 of the raw training array.
Par = {"inp_shift" : torch.tensor(inp_min, dtype=DTYPE, device=device),
       "inp_scale" : torch.tensor(inp_max - inp_min, dtype=DTYPE, device=device),
       "out_shift" : torch.tensor(out_min, dtype=DTYPE, device=device),
       "out_scale" : torch.tensor(out_max - out_min, dtype=DTYPE, device=device),
       "nx"        : x_train.shape[2],
       "ny"        : x_train.shape[3],
       "nf"        : 1,
       "lb"        : 1,
       "lf"        : 1,
       "num_epochs": 1000
       }

# Min-max normalization: the training split maps onto [0,1]; validation and
# test reuse the training shift/scale, so their values may fall slightly
# outside that interval.
shift = Par['inp_shift'].detach().cpu().numpy()
scale = Par['inp_scale'].detach().cpu().numpy()
x_train = (x_train - shift)/scale
x_val = (x_val - shift)/scale
x_test = (x_test - shift)/scale

# `shift` / `scale` are rebound here to the target (output) statistics.
shift = Par['out_shift'].detach().cpu().numpy()
scale = Par['out_scale'].detach().cpu().numpy()
y_train = (y_train - shift)/scale
y_val = (y_val - shift)/scale
y_test = (y_test - shift)/scale

# Standard deviation of the normalized training targets; passed to the
# diffusion model as `sigma_data` in a later cell.
Par["sigma_data"] = np.std(y_train)

# Traj splitting
# `preprocess` windows axis 1 and flattens the result into individual samples.
begin_time = time.time()
print('\nTrain Dataset')
x_train, y_train = preprocess(x_train, y_train, Par)
print('\nValidation Dataset')
x_val, y_val = preprocess(x_val, y_val, Par)
print('\nTest Dataset')
x_test, y_test = preprocess(x_test, y_test, Par)
print(f"Data Preprocess Time: {time.time() - begin_time:.1f}s")

# The channel count is axis 1 of the windowed samples (equal to Par['lf']).
Par.update({"channels"       : x_train.shape[1],
            "self_condition" : True
            })

# NOTE(review): this prints the literal string "Par", not the dictionary —
# confirm whether print(Par) was intended.
print("Par")
# NOTE(review): Par holds tensors created on `device`, so this pickle may not
# load on a machine without that device — confirm.
with open('Par.pkl', 'wb') as f:
    pickle.dump(Par, f)

# Convert to float32 tensors (the same dtype as DTYPE); these stay on the CPU
# and batches are moved to `device` where they are consumed.
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)

x_val_tensor   = torch.tensor(x_val,   dtype=torch.float32)
y_val_tensor   = torch.tensor(y_val,   dtype=torch.float32)

x_test_tensor  = torch.tensor(x_test,  dtype=torch.float32)
y_test_tensor  = torch.tensor(y_test,  dtype=torch.float32)

train_dataset = MyDataset(x_train_tensor, y_train_tensor)
val_dataset = MyDataset(x_val_tensor, y_val_tensor)
test_dataset = MyDataset(x_test_tensor, y_test_tensor)

# Define data loaders
# Only the training loader shuffles; validation and test keep dataset order.
train_batch_size = 100 #16
val_batch_size   = 100
test_batch_size  = 100
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)

Data Loading Time: 0.4s

Train Dataset
x:  (3980, 1, 128, 256)
y:  (3980, 1, 128, 256)


Validation Dataset
x:  (480, 1, 128, 256)
y:  (480, 1, 128, 256)


Test Dataset
x:  (480, 1, 128, 256)
y:  (480, 1, 128, 256)

Data Preprocess Time: 0.0s
Par


In [4]:
# Build the denoising U-Net from a keyword table, then move it to the target
# device and cast it to float32.
unet_kwargs = dict(
    dim=16,
    dim_mults=(1, 2, 4, 8),
    channels=Par["channels"],
    self_condition=Par["self_condition"],
    flash_attn=True,
)
net = Unet(**unet_kwargs)
net = net.to(device)
net = net.to(torch.float32)

# Summarize the network for a single sample plus a one-element second input.
summary_input_sizes = ((1,) + x_train.shape[1:], (1,))
print(summary(net, input_size=summary_input_sizes))

# Wrap the U-Net in the diffusion model used for sampling.
model = ElucidatedDiffusion(
    net,
    channels=Par["channels"],
    image_size_h=Par["nx"],
    image_size_w=Par["ny"],
    sigma_data=Par["sigma_data"],
)





class Wrapper(nn.Module):
    """Expose ``base.sample`` through ``forward`` so the sampler can be traced.

    ``forward(x)`` calls ``base.sample(x, num_sample_steps=num_steps,
    seed=None)`` and always hands back a single object: when ``sample``
    returns a tuple or list, only its first element is returned.
    """

    def __init__(self, base, num_steps=32):
        super().__init__()
        self.base = base
        self.num_steps = num_steps

    def forward(self, x):
        # torchprofile needs a Tensor result in order to trace the graph.
        sampled = self.base.sample(x, num_sample_steps=self.num_steps, seed=None)
        return sampled[0] if isinstance(sampled, (tuple, list)) else sampled


# Profile the cost of one full sampling call, then load the trained weights.
wrapper_model = Wrapper(model)

# Adjust the dimensions as per your model's input size
dummy_x = torch.randn(1, Par["channels"], Par["nx"], Par["ny"],   dtype=DTYPE, device=device)
# Kept as an alias of dummy_x; no visible cell reads it.
dummy_input = dummy_x

# Profile the model
# profile_macs counts multiply-accumulates for one traced forward pass of the
# wrapper (a complete sample() call); FLOPs are reported as 2 x MACs.
wrapper_model.eval()
flops = 2 * torchprofile.profile_macs(wrapper_model, (dummy_x,))
print(f"FLOPs: {flops:.5e}")

# Other checkpoints that were tried:
# path_model = 'models/best_model.pt'
# path_model = 'models/model_1230.pt'
path_model = 'models/model_160.pt'
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run still loads when `device` is the CPU.
model.load_state_dict(torch.load(path_model, map_location=device))

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda
Layer (type:depth-idx)                   Output Shape              Param #
Unet                                     [1, 1, 128, 256]          --
├─Conv2d: 1-1                            [1, 16, 128, 256]         1,584
├─Sequential: 1-2                        [1, 64]                   --
│    └─SinusoidalPosEmb: 2-1             [1, 16]                   --
│    └─Linear: 2-2                       [1, 64]                   1,088
│    └─GELU: 2-3                         [1, 64]                   --
│    └─Linear: 2-4                       [1, 64]                   4,160
├─ModuleList: 1-3                        --                        --
│    └─ModuleList: 2-5                   --                        --
│    │    └─ResnetBlock: 3-1             [1, 16, 128, 256]         6,752
│    │    └─ResnetBlock: 3-2             [1, 16, 128, 256]         6,752
│    │    └─LinearAttention: 3-3         [1, 16, 128

sampling time step: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s]


FLOPs: 5.49957e+11


<All keys matched successfully>

In [5]:
# Make the parent directory importable so the `matcho` module can be found.
# The membership check keeps sys.path from growing when the cell is re-run.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from matcho import Unet2D


# NOTE: unpickling can execute arbitrary code — only load a Par.pkl that you
# produced yourself.
with open('../Par.pkl', 'rb') as f:
    Par_no = pickle.load(f)

# Second network (`no`), built from the parent directory's parameters; its
# output is fed to the diffusion sampler in the timing cells below.
no = Unet2D(dim=16, Par=Par_no, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)
path_model = '../models/best_model.pt'
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU run still loads when `device` is the CPU.
no.load_state_dict(torch.load(path_model, map_location=device))

<All keys matched successfully>

# Inference Time

In [6]:
# Single-sample (batch_size=1) loader over the test set, in dataset order.
# In the visible cells it is only referenced from commented-out code.
test_data_loader = DataLoader(test_dataset, batch_size=1)

In [7]:
def _sync_device():
    """Wait for queued GPU work so wall-clock timings cover the actual compute.

    CUDA kernels are launched asynchronously, so without this barrier a timer
    can stop before the GPU has finished. Does nothing on a CPU device.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()


# Wall-clock seconds of each timed end-to-end run (`no` forward pass followed
# by 32-step diffusion sampling).
inference_time_ls = []

# Random stand-ins for the two positional inputs of `no`, batch size 1.
# NOTE(review): the second input is presumably a time value — confirm.
inp_x = torch.rand(size=(1,2,128,256), dtype=DTYPE, device=device)
inp_t = torch.rand(size=(1,), dtype=DTYPE, device=device)

no.eval()
model.eval()

for i in range(15):
    # Drain pending GPU work so it is not charged to this iteration.
    _sync_device()
    begin_time = time.perf_counter()
    with torch.no_grad():
        no_pred = no(inp_x, inp_t)
        no_dm   = model.sample(no_pred, num_sample_steps=32)
        # for l_fidel, h_fidel in test_data_loader:
        #     y_pred = model.sample(l_fidel.to(device), num_sample_steps=32)
        #     break

    # Make sure every kernel of this iteration has finished before stopping the clock.
    _sync_device()
    end_time = time.perf_counter()
    inference_time = end_time - begin_time
    print(f"Inference time: {inference_time:.5f}")
    inference_time_ls.append(inference_time)

print()
# The first five iterations are treated as warm-up and left out of the mean.
print(f"mean: {np.mean(inference_time_ls[5:])}")

sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.22it/s]


Inference time: 0.72485


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.60it/s]


Inference time: 0.69844


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.58it/s]


Inference time: 0.69859


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.61it/s]


Inference time: 0.69807


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.65it/s]


Inference time: 0.69763


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.72it/s]


Inference time: 0.69649


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.73it/s]


Inference time: 0.69623


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.72it/s]


Inference time: 0.69655


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.57it/s]


Inference time: 0.69867


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.72it/s]


Inference time: 0.69646


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.69it/s]


Inference time: 0.69680


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.64it/s]


Inference time: 0.69765


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.72it/s]


Inference time: 0.69643


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.67it/s]


Inference time: 0.69719


sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.63it/s]

Inference time: 0.69770

mean: 0.6970182180404663





# Peak VRAM

In [8]:
# Measure peak GPU memory of one end-to-end inference (`no` forward pass
# followed by 32-step diffusion sampling) on the batch-size-1 inputs
# `inp_x` / `inp_t` defined in the timing cell.
torch.backends.cudnn.benchmark = False  # keep runs reproducible

no.eval()
model.eval()


# Warmup
# One untimed pass before the counters are reset, so the measured run below
# is not the first execution.
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_dm   = model.sample(no_pred, num_sample_steps=32)


# Wait for the warm-up kernels to finish before resetting the statistics.
torch.cuda.synchronize()

# Peaks are tracked from this point on. The warm-up outputs are still
# referenced when the measured run starts, so they count toward the peak
# until the names are rebound.
torch.cuda.reset_peak_memory_stats() 
with torch.no_grad():
    no_pred = no(inp_x, inp_t)
    no_dm   = model.sample(no_pred, num_sample_steps=32)

# Ensure the measured run has completed before reading the counters.
torch.cuda.synchronize()


# ---- Read peaks (bytes) and report in GB ----
# Division by 1e9 gives decimal gigabytes (GB), not GiB.
peak_alloc_GB   = torch.cuda.max_memory_allocated()  / 1e9
peak_resvd_GB   = torch.cuda.max_memory_reserved()   / 1e9
print(f"Peak VRAM (allocated): {peak_alloc_GB:.4f} GB")
print(f"Peak VRAM (reserved) : {peak_resvd_GB:.4f} GB")
print("Config: batch=1, dtype=", DTYPE, ", device=", device)

sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.63it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 46.61it/s]

Peak VRAM (allocated): 0.3171 GB
Peak VRAM (reserved) : 0.4048 GB
Config: batch=1, dtype= torch.float32 , device= cuda





In [9]:
# Per-batch inputs, targets and sampled predictions; consumed by the
# post-processing cell that follows.
l_fidel_ls, y_pred_ls, y_true_ls = [], [], []

model.eval()
running_error = 0.0
with torch.no_grad():
    for low_fid, high_fid in test_loader:
        # Sample a prediction conditioned on the low-fidelity batch.
        sampled = model.sample(low_fid.to(device), num_sample_steps=32)
        batch_error = error_metric(sampled, high_fid.to(device), Par)
        running_error += batch_error.item()

        l_fidel_ls.append(low_fid)
        y_true_ls.append(high_fid)
        y_pred_ls.append(sampled)

# Unweighted mean of the per-batch relative errors.
test_loss = running_error / len(test_loader)
print(f'Test Loss: {test_loss:.4e}')

sampling time step: 100%|██████████| 32/32 [00:06<00:00,  5.07it/s]
sampling time step: 100%|██████████| 32/32 [00:06<00:00,  5.18it/s]
sampling time step: 100%|██████████| 32/32 [00:06<00:00,  5.18it/s]
sampling time step: 100%|██████████| 32/32 [00:06<00:00,  5.18it/s]
sampling time step: 100%|██████████| 32/32 [00:05<00:00,  6.27it/s]

Test Loss: 1.9643e-01





In [10]:
def _stack_batches(batches):
    """Concatenate batch tensors into one NumPy array grouped five samples at a time."""
    merged = torch.cat(batches, dim=0).detach().cpu().numpy()
    # NOTE(review): 5 is presumably the number of snapshots per trajectory — confirm.
    return merged.reshape(-1, 5, Par["nx"], Par["ny"])


# NumPy copies of the normalization constants.
_out_scale = Par['out_scale'].detach().cpu().numpy()
_out_shift = Par['out_shift'].detach().cpu().numpy()
_inp_scale = Par['inp_scale'].detach().cpu().numpy()
_inp_shift = Par['inp_shift'].detach().cpu().numpy()

# Undo the min-max normalization: targets and predictions use the output
# statistics, the low-fidelity inputs use the input statistics.
y_true = _stack_batches(y_true_ls) * _out_scale + _out_shift
y_pred = _stack_batches(y_pred_ls) * _out_scale + _out_shift
y_no   = _stack_batches(l_fidel_ls) * _inp_scale + _inp_shift

for _label, _arr in (("y_true", y_true), ("y_pred", y_pred), ("y_no  ", y_no)):
    print(f"{_label}: {_arr.shape}")

np.save("y_pred.npy", y_pred)

y_true: (96, 5, 128, 256)
y_pred: (96, 5, 128, 256)
y_no  : (96, 5, 128, 256)
