In [1]:
import h5py 
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid
import matplotlib.gridspec as gridspec
import torch 

def write_image(batch_gt, batch_pred, state_idx, path, cmap='plasma', divider=1):
    """
    Save a grid of frames taken from ``batch_gt`` to the image file ``path``.

    The sequence is cut to its first ``t_horizon * divider`` frames, where
    ``t_horizon = seq_len // divider``. Trailing frames that do not fill a
    complete row are dropped. Those frames are laid out row by row: row ``r``
    shows frames ``r * t_horizon`` .. ``(r + 1) * t_horizon - 1``. Only channel
    ``state_idx`` of the last axis is drawn, and one colour bar is attached to
    the right end of each row.

    ``batch_pred`` is NOT drawn. Its min/max over channel ``state_idx`` only
    sets the shared colour limits of every image (presumably so ground-truth
    and prediction figures can be compared on one scale).

    Args:
        batch_gt: array of shape (seq_len, height, width, state_c).
        batch_pred: array whose last axis is indexed by ``state_idx``; used
            only for ``vmin`` / ``vmax``.
        state_idx: channel index on the last axis to plot.
        path: output image file path (its directory must already exist).
        cmap: matplotlib colormap name.
        divider: number of rows the sequence is split into (must be >= 1).

    Raises:
        ValueError: if ``divider < 1`` or ``seq_len < divider``. In the second
            case the grid would have zero columns.
    """
    seq_len, height, width, state_c = batch_gt.shape  # e.g. [20, 64, 64, 1]
    if divider < 1:
        raise ValueError(f"divider must be >= 1, got {divider}")
    t_horizon = seq_len // divider
    if t_horizon == 0:
        raise ValueError(
            f"seq_len ({seq_len}) must be >= divider ({divider}) to fill at least one column"
        )
    new_seq_len = t_horizon * divider
    # Drop the frames that do not fill a whole row, then lay out (row, column, ...).
    frames = batch_gt[:new_seq_len].reshape(divider, t_horizon, height, width, state_c)
    # Colour limits come from the prediction so all images share one scale.
    vmax = np.max(batch_pred[..., state_idx])
    vmin = np.min(batch_pred[..., state_idx])

    fig = plt.figure(figsize=(t_horizon + 1, divider))
    try:
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(divider, t_horizon),  # divider rows x t_horizon columns
                         axes_pad=0.05,  # pad between axes in inches
                         share_all=True,
                         cbar_location="right",
                         cbar_mode='edge',  # one colour bar per row (right edge)
                         direction='row',  # grid[i] is filled row by row
                         cbar_size="10%",
                         cbar_pad=0.15)
        for traj in range(divider):
            for t in range(t_horizon):
                ax = grid[traj * t_horizon + t]
                im = ax.imshow(frames[traj, t, :, :, state_idx], vmax=vmax, vmin=vmin,
                               cmap=cmap, interpolation='none')
                ax.set_axis_off()
            # Attach this row's colour bar once, from the row's last image.
            ax.cax.colorbar(im)
        fig.savefig(path, dpi=72, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if plotting or saving fails, so that
        # repeated calls do not accumulate open figures.
        fig.clf()
        plt.close(fig)


In [2]:
data_path = '/cluster/nvme4a/whh/dataset/sst/data_zone_9.h5'
# Copy the whole dataset into memory inside a context manager so the HDF5
# file handle is closed as soon as the array has been read out.
with h5py.File(data_path, 'r') as h5_data:
    data = h5_data['mygroup']['mydataset'][:]
# Reverse the axis order. The printed shape afterwards is (4459, 64, 64), so
# the on-disk array is (64, 64, 4459). NOTE(review): presumably the leading
# axis after transposing is time / sample index -- confirm against the
# dataset's provenance.
data = data.transpose(2, 1, 0)
print(data.shape)

(4459, 64, 64)


In [3]:
import os
path = '/cluster/home1/whh/new_repo/coral/coral/utils/data'
# Output directory for the figures saved below. Create it if missing;
# exist_ok=True makes this cell safe to re-run (a bare os.mkdir would raise
# FileExistsError the second time).
os.makedirs(path, exist_ok=True)

In [5]:
# Slice (rather than index) so a length-1 leading axis is kept: shape (1, H, W).
img = data[0:1]
print(img.shape)
# Append a channel axis -> (1, H, W, 1); the same frames serve as gt and colour reference.
frame = img[..., None]
write_image(frame, frame, 0, path=os.path.join(path, 'data.png'), cmap='twilight_shifted', divider=1)


(1, 64, 64)


In [10]:
img_tensor = torch.tensor(img)
# 2-D FFT over the last two axes (torch.fft.fft2 default dim=(-2, -1)); the
# leading length-1 axis is carried through untouched. The output is not
# fftshift-ed, so the zero-frequency term sits at index [..., 0, 0].
fimg = torch.fft.fft2(img_tensor)
print(fimg.shape)

torch.Size([1, 64, 64])


In [15]:
imag_part = fimg.imag  # imaginary component of the complex spectrum
print(imag_part.shape)

torch.Size([1, 64, 64])


In [18]:
# Real part of the spectrum as a (1, H, W, 1) numpy array; it also sets the colour limits.
real_frames = fimg.real[..., None].numpy()
write_image(real_frames, real_frames, 0,
            path=os.path.join(path, 'fdata_real.png'),
            cmap='twilight_shifted', divider=1)

In [None]:
# Imaginary part of the spectrum as a (1, H, W, 1) numpy array; it also sets the colour limits.
imag_frames = fimg.imag[..., None].numpy()
write_image(imag_frames, imag_frames, 0,
            path=os.path.join(path, 'fdata_imag.png'),
            cmap='twilight_shifted', divider=1)