In [2]:
import os
import sys

import math
import time
import datetime
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from tcunet import Unet2D
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torchinfo import summary
import torchprofile

import pickle

torch.manual_seed(23)

scaler = GradScaler()

DTYPE = torch.float32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "serif"

import scipy.stats as stats

Using device: cuda


In [4]:
class CustomLoss(nn.Module):
    def __init__(self, Par):
        super(CustomLoss, self).__init__()
        self.Par = Par

    def forward(self, y_pred, y_true):
        y_true = (y_true - self.Par["out_shift"])/self.Par["out_scale"]
        y_pred = (y_pred - self.Par["out_shift"])/self.Par["out_scale"]
        loss = torch.norm(y_true-y_pred, p=2)/torch.norm(y_true, p=2)
        return loss

class YourDataset(Dataset):
    def __init__(self, x, t, y, transform=None):
        self.x = x
        self.t = t
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x_sample = self.x[idx]
        t_sample = self.t[idx]
        y_sample = self.y[idx]

        if self.transform:
            x_sample, t_sample, y_sample = self.transform(x_sample, t_sample, y_sample)

        return x_sample, t_sample, y_sample


def preprocess(traj_i, traj_o, Par):
    x = sliding_window_view(traj_i[:,:,:,:], window_shape=Par['lf'], axis=1 ).transpose(0,1,4,2,3).reshape(-1,Par['lf'],Par['nx'], Par['ny'])[:, [0,-1]] # BS, 2, nx, ny
    y = sliding_window_view(traj_o[:,:,:,:], window_shape=Par['lf'], axis=1 ).transpose(0,1,4,2,3).reshape(-1,Par['lf'],Par['nx'], Par['ny'])            # BS, lf, nx, ny
    t = np.linspace(0,1,Par['lf']).reshape(-1,1)

    nt = y.shape[1]
    n_samples = y.shape[0]

    t = np.tile(t, [n_samples,1]).reshape(-1,)                         
    x = np.repeat(x,nt, axis=0)                                  
    y = y.reshape(y.shape[0]*y.shape[1],1,y.shape[2],y.shape[3])  

    print('x: ', x.shape)
    print('y: ', y.shape)
    print('t: ', t.shape)
    print()
    return x,y,t

def combined_scheduler(optimizer, total_epochs, warmup_epochs, last_epoch=-1):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch + 1) / warmup_epochs
        else:
            return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

def make_images(true, pred, epoch):
    sample_id = -2
    t_id = 0

    CMAP = "gray"
    VMIN = 0
    VMAX = 255


    T = true[sample_id, t_id].detach().cpu().numpy()*256
    P = pred[sample_id, t_id].detach().cpu().numpy()*256

    fig, axes = plt.subplots(1,2, figsize=(20,5))
    axes[0].imshow(T, cmap=CMAP, vmin=VMIN, vmax=VMAX)
    axes[0].set_title("True")
    axes[1].imshow(P, cmap=CMAP, vmin=VMIN, vmax=VMAX)
    axes[1].set_title("Pred")

    plt.tight_layout()


    fig.suptitle(f"Epoch: {epoch}", fontsize=22, y=1.2)
    plt.savefig(f"images/{epoch}.png", dpi=150, bbox_inches='tight')
    plt.close()



In [5]:
begin_time = time.time()
traj_i = np.load(f"../data/lr8_data.npy").astype(np.float32)/256 #[nt, nx, ny]
traj_i = np.expand_dims(traj_i, axis=0) #[1, nt, nx, ny]
traj_o = np.load(f"../data/hr_data.npy").astype(np.float32)/256 #[nt, nx, ny]
traj_o = np.expand_dims(traj_o, axis=0) #[1, nt, nx, ny]

print(f"traj_i: {traj_i.shape}")
print(f"traj_o: {traj_o.shape}")

print(f"Data Loading Time: {time.time() - begin_time:.1f}s")


nsamples = traj_i.shape[1]
idx1 = int(0.8*nsamples)
idx2 = int(0.9*nsamples)

print(idx1, idx2)

traj_i_train = traj_i[:, :idx1]
traj_i_val   = traj_i[:, idx1:idx2]
traj_i_test  = traj_i[:, idx2:]

traj_o_train = traj_o[:, :idx1]
traj_o_val   = traj_o[:, idx1:idx2]
traj_o_test  = traj_o[:, idx2:]

Par = {}
# Par['nt'] = 100 
Par['nx'] = traj_i_train.shape[2]
Par['ny'] = traj_i_train.shape[3]
Par['nf'] = 1
Par['d_emb'] = 128

Par['lb'] = 2
Par['lf'] = 4+1
# Par['temp'] = Par['nt'] - Par['lb'] - Par['lf'] + 2

Par['num_epochs'] = 50 #50

begin_time = time.time()
print('\nTrain Dataset')
x_train, y_train, t_train = preprocess(traj_i_train, traj_o_train, Par)
print('\nValidation Dataset')
x_val, y_val, t_val  = preprocess(traj_i_val, traj_o_val, Par)
print('\nTest Dataset')
x_test, y_test, t_test  = preprocess(traj_i_test, traj_o_test, Par)
print(f"Data Preprocess Time: {time.time() - begin_time:.1f}s")

# sys.exit()

t_min = np.min(t_train)
t_max = np.max(t_train)

Par['inp_scale'] = np.max(x_train) - np.min(x_train)
Par['inp_shift'] = np.min(x_train)
Par['out_scale'] = np.max(y_train) - np.min(y_train)
Par['out_shift'] = np.min(y_train)
Par['t_shift']   = t_min
Par['t_scale']   = t_max - t_min

with open('Par.pkl', 'wb') as f:
    pickle.dump(Par, f)

# sys.exit()
#########################

# Create custom datasets
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
t_train_tensor = torch.tensor(t_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)

x_val_tensor   = torch.tensor(x_val,   dtype=torch.float32)
t_val_tensor   = torch.tensor(t_val,   dtype=torch.float32)
y_val_tensor   = torch.tensor(y_val,   dtype=torch.float32)

x_test_tensor  = torch.tensor(x_test,  dtype=torch.float32)
t_test_tensor  = torch.tensor(t_test,  dtype=torch.float32)
y_test_tensor  = torch.tensor(y_test,  dtype=torch.float32)

train_dataset = YourDataset(x_train_tensor, t_train_tensor, y_train_tensor)
val_dataset = YourDataset(x_val_tensor, t_val_tensor, y_val_tensor)
test_dataset = YourDataset(x_test_tensor, t_test_tensor, y_test_tensor)

# Define data loaders
train_batch_size = 20
val_batch_size   = 20
test_batch_size  = 20
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)

print(Par)

traj_i: (1, 100, 128, 256)
traj_o: (1, 100, 128, 256)
Data Loading Time: 0.0s
80 90

Train Dataset
x:  (380, 2, 128, 256)
y:  (380, 1, 128, 256)
t:  (380,)


Validation Dataset
x:  (30, 2, 128, 256)
y:  (30, 1, 128, 256)
t:  (30,)


Test Dataset
x:  (30, 2, 128, 256)
y:  (30, 1, 128, 256)
t:  (30,)

Data Preprocess Time: 0.0s
{'nx': 128, 'ny': 256, 'nf': 1, 'd_emb': 128, 'lb': 2, 'lf': 5, 'num_epochs': 50, 'inp_scale': 0.7561752, 'inp_shift': -0.0007444823, 'out_scale': 0.7265625, 'out_shift': 0.05078125, 't_shift': 0.0, 't_scale': 1.0}


In [6]:
model = Unet2D(dim=16, Par=Par, dim_mults=(1, 2, 4, 8)).to(device).to(torch.float32)

path_model = 'models/best_model.pt'
model.load_state_dict(torch.load(path_model))

print(summary(model, input_size=((1,)+x_train.shape[1:], (1,)) ) )

# Adjust the dimensions as per your model's input size
dummy_x = x_train_tensor[0:1].to(device)
dummy_t = t_train_tensor[0:1].to(device)
dummy_input = (dummy_x, dummy_t)

# Profile the model
model.eval()
flops = 2 * torchprofile.profile_macs(model, dummy_input)
print(f"FLOPs: {flops:.4e}")

# Define loss function and optimizer
criterion = CustomLoss(Par)

Layer (type:depth-idx)                                  Output Shape              Param #
Unet2D                                                  [1, 1, 128, 256]          --
├─Conv2d: 1-1                                           [1, 16, 128, 256]         3,152
├─Sequential: 1-2                                       [1, 64]                   --
│    └─SinusoidalPosEmb: 2-1                            [1, 16]                   --
│    └─Linear: 2-2                                      [1, 64]                   1,088
│    └─GELU: 2-3                                        [1, 64]                   --
│    └─Linear: 2-4                                      [1, 64]                   4,160
├─ModuleList: 1-3                                       --                        --
│    └─ModuleList: 2-5                                  --                        --
│    │    └─ResnetBlock: 3-1                            [1, 16, 128, 256]         6,784
│    │    └─ResnetBlock: 3-2                    



# Inference Time

In [7]:
test_data_loader = DataLoader(test_dataset, batch_size=1)

In [8]:
inference_time_ls = []

model.eval()

for i in range(15):
    begin_time = time.time()
    with torch.no_grad():
        for x, t, y_true in test_data_loader:
            y_pred = model(x.to(device), t.to(device))
            break
    
    end_time = time.time()
    inference_time = end_time - begin_time
    print(f"Inference time: {inference_time:.5f}")
    inference_time_ls.append(inference_time)

print()
print(f"mean: {np.mean(inference_time_ls[5:])}")

Inference time: 0.01328
Inference time: 0.01048
Inference time: 0.01040
Inference time: 0.01022
Inference time: 0.01019
Inference time: 0.01029
Inference time: 0.01025
Inference time: 0.01024
Inference time: 0.01027
Inference time: 0.01029
Inference time: 0.01017
Inference time: 0.01020
Inference time: 0.01025
Inference time: 0.01020
Inference time: 0.01031

mean: 0.010246634483337402


# PeakVRAM

In [9]:
torch.backends.cudnn.benchmark = False  # keep runs reproducible

model.eval()


# Warmup
with torch.no_grad():
    for x, t, y_true in test_data_loader:
        _ = model(x.to(device), t.to(device))
        break


torch.cuda.synchronize()

torch.cuda.reset_peak_memory_stats() 
with torch.no_grad():
    for x, t, y_true in test_data_loader:
        _ = model(x.to(device), t.to(device))
        break

torch.cuda.synchronize()


# ---- Read peaks (bytes) and report in GB ----
peak_alloc_GB   = torch.cuda.max_memory_allocated()  / 1e9
peak_resvd_GB   = torch.cuda.max_memory_reserved()   / 1e9
print(f"Peak VRAM (allocated): {peak_alloc_GB:.4f} GB")
print(f"Peak VRAM (reserved) : {peak_resvd_GB:.4f} GB")
print("Config: batch=1, dtype=", DTYPE, ", device=", device)

Peak VRAM (allocated): 0.2898 GB
Peak VRAM (reserved) : 0.7319 GB
Config: batch=1, dtype= torch.float32 , device= cuda


In [10]:
# Sanity Check

y_true_ls = []
y_pred_ls = []

model.eval()
train_loss = 0.0
with torch.no_grad():
    for x, t, y_true in train_loader:
        with autocast():
            y_pred = model(x.to(device), t.to(device))
            loss   = criterion(y_pred, y_true.to(device))
        train_loss += loss.item()
        y_true_ls.append(y_true.detach().cpu().numpy())
        y_pred_ls.append(y_pred.detach().cpu().numpy())

train_loss /= len(train_loader)
print(f"Train Loss: {train_loss:.4e}")

TRAIN_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)
TRAIN_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)

print(f"TRAIN_TRUE: {TRAIN_TRUE.shape}, DTYPE: {TRAIN_TRUE.dtype}")
print(f"TRAIN_PRED: {TRAIN_PRED.shape}, DTYPE: {TRAIN_PRED.dtype}")



y_true_ls = []
y_pred_ls = []

model.eval()
val_loss = 0.0
with torch.no_grad():
    for x, t, y_true in val_loader:
        with autocast():
            y_pred = model(x.to(device), t.to(device))
            loss   = criterion(y_pred, y_true.to(device))
        val_loss += loss.item()
        y_true_ls.append(y_true.detach().cpu().numpy())
        y_pred_ls.append(y_pred.detach().cpu().numpy())

val_loss /= len(val_loader)
print(f"Val Loss: {val_loss:.4e}")

VAL_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)
VAL_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)

print(f"VAL_TRUE: {VAL_TRUE.shape}, DTYPE: {VAL_TRUE.dtype}")
print(f"VAL_PRED: {VAL_PRED.shape}, DTYPE: {VAL_PRED.dtype}")



y_true_ls = []
y_pred_ls = []

model.eval()
test_loss = 0.0
with torch.no_grad():
    for x, t, y_true in test_loader:
        with autocast():
            y_pred = model(x.to(device), t.to(device))
            loss   = criterion(y_pred, y_true.to(device))
        test_loss += loss.item()
        y_true_ls.append(y_true.detach().cpu().numpy())
        y_pred_ls.append(y_pred.detach().cpu().numpy())

test_loss /= len(test_loader)
print(f"Test Loss: {test_loss:.4e}")

TEST_TRUE = np.concatenate(y_true_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)
TEST_PRED = np.concatenate(y_pred_ls, axis=0).reshape(-1, Par['lf'], Par['nx'], Par['ny']).astype(np.float32)

print(f"TEST_TRUE: {TEST_TRUE.shape}, DTYPE: {TEST_TRUE.dtype}")
print(f"TEST_PRED: {TEST_PRED.shape}, DTYPE: {TEST_PRED.dtype}")

Train Loss: 1.4515e-01
TRAIN_TRUE: (76, 5, 128, 256), DTYPE: float32
TRAIN_PRED: (76, 5, 128, 256), DTYPE: float32
Val Loss: 1.4995e-01
VAL_TRUE: (6, 5, 128, 256), DTYPE: float32
VAL_PRED: (6, 5, 128, 256), DTYPE: float32
Test Loss: 1.4307e-01
TEST_TRUE: (6, 5, 128, 256), DTYPE: float32
TEST_PRED: (6, 5, 128, 256), DTYPE: float32


In [11]:
np.save("TRAIN_TRUE.npy", TRAIN_TRUE)
np.save("TRAIN_PRED.npy", TRAIN_PRED)

np.save("VAL_TRUE.npy", VAL_TRUE)
np.save("VAL_PRED.npy", VAL_PRED)

np.save("TEST_TRUE.npy", TEST_TRUE)
np.save("TEST_PRED.npy", TEST_PRED)