In [1]:
import torch
import gc
import numpy as np
import os
import yaml
import torch
import matplotlib.pyplot as plt
from model import DDPM, UNet_new, VQVAE, GraphModel
from ema_pytorch import EMA
from utils import Config
import numpy as np
import random
import torch.nn.functional as F
from setproctitle import setproctitle
from torch_geometric.data import Data
import random
import argparse
from torchmetrics.functional import structural_similarity_index_measure as ssim
import warnings
warnings.filterwarnings('ignore')

def generate_sparse_observations(encoding_indices, data, M=30):
    B, _ = encoding_indices.shape
    H, W = 128, 128
    input_steps = 10

    X = torch.zeros((B * M, input_steps))
    Coord = torch.full((B * M, 2), -1.0)

    for b in range(B):
        indices = encoding_indices[b]
        batch_data = data[b]
        unique_codes = torch.unique(indices)

        for code in unique_codes:
            mask = indices == code
            selected_idx = torch.nonzero(mask).squeeze()
            if selected_idx.dim() == 0:
                selected_idx = selected_idx.unsqueeze(0)
            chosen_idx = selected_idx[torch.randint(0, len(selected_idx), (1,))]
            y, x = divmod(chosen_idx.item(), W)
            x, y = x / float(H), y / float(W)
            index_counter = b * M + code
            Coord[index_counter] = torch.tensor([x, y])
            X[index_counter] = batch_data[chosen_idx].squeeze()
    
    return X, Coord

def generate_edge_index(encoding_indices, M):
    """
    Generate edge_index and edge_weight for sparse observation points based on codebook adjacency.
    
    Returns:
        edge_index: (2, num_edges * B) torch tensor with codebook indices in [0, codebook_size)
        edge_weight: (num_edges * B, 1) torch tensor of normalized weights (may include 0s)
    """
    B, _ = encoding_indices.shape  # (B, 128*128)
    H, W = 128, 128
    codebook_size = M  # Explicit codebook size
    
    # Initialize the lists to collect all edges and their corresponding weights
    edge_list = []
    edge_weights = []

    # Iterate over each batch
    for b in range(B):
        encoding_grid = encoding_indices[b].view(H, W).numpy()

        cooccur = np.zeros((codebook_size, codebook_size), dtype=np.float32)
        neighbor_offsets = [(-1, -1), (-1, 0), (-1, 1),
                            (0, -1),          (0, 1),
                            (1, -1),  (1, 0), (1, 1)]
        # Count co-occurrence for the current batch
        for i in range(H):
            for j in range(W):
                current_code = encoding_grid[i, j]
                for dx, dy in neighbor_offsets:
                    ni, nj = i + dx, j + dy
                    if 0 <= ni < H and 0 <= nj < W:
                        neighbor_code = encoding_grid[ni, nj]
                        if neighbor_code != current_code:
                            cooccur[current_code, neighbor_code] += 1

        np.fill_diagonal(cooccur, 0)
        # Normalize co-occurrence matrix
        non_zero_counts = (cooccur > 0).sum(axis=1, keepdims=True)
        norm_cooccur = cooccur / (non_zero_counts + 1e-8)
        # Append edges and weights for the current batch
        for i in range(codebook_size):
            related_scores = norm_cooccur[i]
            for j in range(codebook_size):
                if i != j: 
                    edge_list.append((i, j))
                    edge_weights.append(related_scores[j])  
    # Convert lists to tensors
    edge_index = torch.tensor(edge_list, dtype=torch.long).T  # (2, num_edges * B)
    edge_weight = torch.tensor(edge_weights, dtype=torch.float32)  # (num_edges * B,)

    # Normalize edge_weight to [0, 1] if not all 0
    if edge_weight.max() > edge_weight.min():
        edge_weight = (edge_weight - edge_weight.min()) / (edge_weight.max() - edge_weight.min())
    edge_weight = edge_weight.unsqueeze(-1)  # (num_edges * B, 1)

    return edge_index, edge_weight

def normalize_to_neg_one_to_one(img):
    # [0.0, 1.0] -> [-1.0, 1.0]
    return img * 2 - 1
def unnormalize_to_zero_to_one(t):
    # [-1.0, 1.0] -> [0.0, 1.0]
    return (t + 1) * 0.5



def reconstruct_from_sparse_batch(Sparse, encoding_indices, sample_step, n ,grid_size=128,  strategy='first'):   
    B_1, C, total_M = Sparse.shape
    H = W = grid_size
    B = int(B_1 / sample_step)
    Sparse_grouped = Sparse.view(B*sample_step, C, -1, n)

    if strategy == 'first':
        Sparse_selected = Sparse_grouped[:, :, :, 0]  # shape: (B*sample_step, C, M)
    elif strategy == 'mean':
        Sparse_selected = Sparse_grouped.mean(dim=3)  # shape: (B*sample_step, C, M)
    else:
        raise ValueError(f"Unsupported strategy: {strategy}")

    # print(encoding_indices.shape)
    encoding_indices = encoding_indices.unsqueeze(1).expand(-1, sample_step, -1, -1).reshape(B*sample_step, C, -1)
    # print("1", encoding_indices.shape)
    # print(Sparse_selected.shape)
    reconstructed_flat = torch.gather(Sparse_selected, dim=2, index=encoding_indices)  # shape: (B*sample_step, C, H*W)

    reconstructed = reconstructed_flat.view(B_1, C, H, W)
    return reconstructed

system = 'sh'
yaml_path = f'config/{system}.yaml'

with open(yaml_path, 'r') as f:
    base_opt = yaml.full_load(f) 

opt = Config(base_opt)
device = torch.device(f"cuda:5")

opt.save_dir += f'/{opt.dataset}/'

diff = DDPM(
    nn_model=UNet_new(**opt.network),
    **opt.diffusion,
    device=device,
)
diff.to(device)

load_epoch = 999
checkpoint_path = os.path.join(opt.save_dir, f"model_{load_epoch}.pth")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
diff.load_state_dict(checkpoint['MODEL'], strict=False)

ema = EMA(diff, beta=opt.ema, update_after_step=0, update_every=1)
ema.to(device)
ema.load_state_dict(checkpoint['EMA'], strict=False)
ema_model = ema.ema_model
ema_model.eval()

size = (opt.data["channel"], opt.data['img_resolution'], opt.data['img_resolution']) 

T= opt.vqvae["T"]
hidden_dim = opt.vqvae["hidden_dim"]   
embedding_dim = opt.vqvae["embedding_dim"]  
num_embeddings =  opt.vqvae["num_embeddings"]
vqvae = VQVAE(input_dim=T, hidden_dim=hidden_dim, embedding_dim=embedding_dim, num_embeddings=num_embeddings)
vqvae.load_state_dict(torch.load(f"./log/{system}/vqvae_T_{T}_ae_pretrain_{num_embeddings}_{hidden_dim}_{embedding_dim}.pth"))
vqvae.eval()


latent_feat = opt.grand["latent_feat"]
input_steps = opt.grand["input_steps"]

predictor = GraphModel(input_steps=input_steps, d_in = latent_feat, codebook_size=num_embeddings).to(device)
predictor.load_state_dict(torch.load(f"./log/{system}/grand_input_{T}_{latent_feat}_1.pth"))

# ===================
# load data
# ===================
x_test = np.load(f'./data/{system}/uv_test.npy')  # Shape: (num_tra, steps, channels, 128, 128)
x_test = torch.tensor(x_test, dtype=torch.float32)  

x_train = np.load(f'./data/{system}/uv.npy')  # Shape: (num_tra, steps, channels, 128, 128)
x_train = torch.tensor(x_train, dtype=torch.float32)  

# Normalize data (Min-Max Scaling)
xmin = x_train.amin(dim=(0, 1, 3, 4), keepdim=True)  # Min over num_tra, steps, spatial dimensions
xmax = x_train.amax(dim=(0, 1, 3, 4), keepdim=True)  # Max over num_tra, steps, spatial dimensions
del x_train
test_data = (x_test - xmin) / (xmax - xmin)  # Normalized data

batch_size = 1   
sample_step = opt.sample["sample_step"]
channel = opt.data["channel"]
test_tra = 1
pred_steps = 80
start_step = 2*T   # >= T
t_input = torch.tensor([1.0])
change_step = opt.sample["change_step"]

truth = test_data[:test_tra , start_step:start_step + pred_steps, :,:,:]   # (batch_size, T, 2, 128, 128)
predictions = torch.zeros_like(truth)
full_reconstructions = torch.zeros_like(truth)

with torch.no_grad():
    for i in range(0, test_tra, batch_size):    
        for code_step in range(start_step, start_step + pred_steps, change_step):  
            if(code_step == start_step):
                for_codebook_data = test_data[i : i + batch_size, code_step-T:code_step, :,:,:]   # (B, T, 2, 128, 128)
            
            elif(code_step-start_step-T<0): 
               
                for_codebook_data = torch.cat([test_data[i : i + batch_size, - T + code_step:start_step, :,:,:], predictions[i:i+batch_size, :code_step - start_step,:,:,:]],dim = 1)
            else:
                for_codebook_data = predictions[i:i+batch_size, code_step - start_step - T: code_step - start_step,:,:,:]

            single_tra_data_1 = for_codebook_data[:, :,0,:,:].reshape(batch_size, T, 128*128).permute(0,2,1)   # (B, 128*128, T)
            _, _, encoding_indices_1 = vqvae(single_tra_data_1)   # (B, 128*128)
        
            encoding_indices = encoding_indices_1.unsqueeze(1)

            X_1, Coord_1 = generate_sparse_observations(encoding_indices_1, single_tra_data_1, num_embeddings)
            edge_index_1, edge_weight_1 = generate_edge_index(encoding_indices_1, num_embeddings)
        
            
            initial_data_1 = Data(    
                    x = X_1,   # (batch_size * M, T)
                    edge_index = edge_index_1,  # (2, E*batch_size)
                    edge_weight=edge_weight_1,  # (E*batch_size, 1)
                    Coord = Coord_1   #   (M*batch_size, 2)
                )
        
            
            initial_data_1 = initial_data_1.to(device)
        
            
            
        
            sparse_predictions = []
            for step in range(0, change_step, sample_step):
                
                
                if(step != 0):   
                    ddpm_samples = ddpm_samples.reshape(batch_size*sample_step, channel, -1).reshape(batch_size, sample_step, channel, -1).detach().cpu()
                    ddpm_result_1 = ddpm_samples[:, :, 0,:].permute(0,2,1)
                    update_single_tra_data_1 = single_tra_data_1[:, :, sample_step:]
                    print(update_single_tra_data_1.shape)
                    update_single_tra_data_1 = torch.cat((update_single_tra_data_1, ddpm_result_1), dim=2)
                    
                
                    
                    X_1, Coord_1 = generate_sparse_observations(encoding_indices_1, update_single_tra_data_1)  # (B, 128*128, T)

                    initial_data_1.x = X_1.to(device)

                    initial_data_1.Coord = Coord_1.to(device)
        
                
                
                for t in range(sample_step):
                    # print(initial_data_1.x)
                    pred_1 = predictor(initial_data_1, t_input) 
                    sparse_predictions.append(pred_1.detach().cpu())
        

                    initial_data_1.x = initial_data_1.x[:, 1:]    
                    initial_data_1.x = torch.cat((initial_data_1.x, pred_1), dim=1)
                    
                sparse_1 = initial_data_1.x[:,-sample_step:].detach().cpu().reshape(batch_size, -1, sample_step).permute(0,2,1).reshape(batch_size*sample_step, -1)
            
                Sparse = sparse_1.unsqueeze(1)   # (B*sample_step, C, M)
                # print(Sparse.shape)
                reconstructed = reconstruct_from_sparse_batch(Sparse = Sparse, encoding_indices=encoding_indices, sample_step = sample_step, n = 1)   #(B*sample_step, C, H, w)
                full_reconstructions[i:i+batch_size, code_step - start_step + step:code_step - start_step + step + sample_step] = reconstructed.reshape(batch_size, sample_step, channel, 128, 128).detach().cpu()
                
                with torch.enable_grad():
                    # ddpm_samples = ema_model.ddim_guided_sample_full_sh(n_sample = batch_size*sample_step, size = size, steps = opt.sample["ddim_step"], eta = opt.sample["eta"], zeta_obs = opt.sample["zeta_obs"], zeta_pde = opt.sample["zeta_pde"], ratio = opt.sample["ratio"], reconstructed = reconstructed, data_opt = opt.data, notqdm=False)  # shape: (B*sample_step, C, H, W)
                    ddpm_samples = ema_model.ddim_sample_from_reconstructed_sh(n_sample = batch_size*sample_step, size = size, steps = 50, eta = opt.sample["eta"], zeta_obs = opt.sample["zeta_obs"], zeta_pde = opt.sample["zeta_pde"], ratio = opt.sample["ratio"], reconstructed = reconstructed, data_opt = opt.data, notqdm=False)  # shape: (B*sample_step, C, H, W)
                    predictions[i:i+batch_size, code_step - start_step + step:code_step - start_step + step + sample_step] = ddpm_samples.reshape(batch_size, sample_step, channel, 128, 128).detach().cpu()
            print(F.mse_loss(predictions, truth, reduction='none').mean().numpy())
                
    def compute_nmse(predictions, targets):
        nmse = torch.mean((predictions - targets) ** 2, dim=(0, 2, 3, 4)) / torch.mean((targets) ** 2, dim=(0, 2, 3, 4))
        return nmse  

    def compute_rmse(predictions, targets):
        rmse = torch.sqrt(torch.mean((predictions - targets) ** 2, dim=(0, 2, 3, 4)))
        return rmse  

    def compute_ssim(predictions, targets):
        ssim_values = []
        for t in range(predictions.shape[1]):  
            ssim_value = ssim(predictions[:,t], targets[:,t], data_range=1.0).item()
            ssim_values.append(ssim_value)
        return torch.tensor(ssim_values)

    nmse = compute_nmse(predictions, truth)    # (pred_steps, )
    my_ssim = compute_ssim(predictions, truth)
    rmse = compute_rmse(predictions, truth)
    print(f"nmse: {nmse.mean()};{nmse.std()}\n; ssim: {my_ssim.mean()};{my_ssim.std()}\n, rmse: {rmse.mean()};{rmse.std()}")
    


100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


0.30087787


100%|██████████| 50/50 [00:04<00:00, 12.39it/s]


0.28862965


100%|██████████| 50/50 [00:04<00:00, 12.46it/s]


0.27476692


100%|██████████| 50/50 [00:04<00:00, 12.42it/s]


0.25902575


100%|██████████| 50/50 [00:04<00:00, 12.48it/s]


0.24171248


100%|██████████| 50/50 [00:04<00:00, 12.48it/s]


0.22328863


100%|██████████| 50/50 [00:04<00:00, 12.46it/s]


0.2040813


100%|██████████| 50/50 [00:04<00:00, 12.46it/s]


0.18425618


100%|██████████| 50/50 [00:04<00:00, 12.45it/s]


0.16387951


100%|██████████| 50/50 [00:04<00:00, 12.45it/s]


0.14295293


100%|██████████| 50/50 [00:03<00:00, 12.53it/s]


0.1214237


100%|██████████| 50/50 [00:04<00:00, 12.47it/s]


0.09922056


100%|██████████| 50/50 [00:04<00:00, 12.47it/s]


0.076278485


100%|██████████| 50/50 [00:04<00:00, 12.44it/s]


0.05257197


100%|██████████| 50/50 [00:04<00:00, 12.49it/s]


0.028146643


100%|██████████| 50/50 [00:04<00:00, 12.47it/s]


0.0031127082
nmse: 0.009317385032773018;0.0033525219187140465
; ssim: 0.9790511131286621;0.006674688309431076
, rmse: 0.05371803790330887;0.015164285898208618
