In [1]:
import h5py 
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid
import matplotlib.gridspec as gridspec

def write_image(batch_gt, batch_pred, state_idx, path, cmap='plasma', divider=1, vmin=None, vmax=None):
    """
    Save a grid of frames from ``batch_gt`` as one image file.

    The sequence is laid out row-major into ``divider`` rows of
    ``t_horizon = seq_len // divider`` *consecutive* frames each (row 0 holds
    frames 0..t_horizon-1, row 1 the next t_horizon, ...). Trailing frames that
    do not fill a complete row are dropped. Each row gets a colorbar on the right.

    ``batch_pred`` is NOT drawn: it only supplies the shared colour scale
    (min/max of ``batch_pred[..., state_idx]``) unless ``vmin``/``vmax`` are
    given explicitly.

    Args:
        batch_gt: array of shape (seq_len, height, width, channels).
        batch_pred: array whose last axis is indexed by ``state_idx``; used only
            for the default colour limits.
        state_idx: channel (last-axis index) to plot.
        path: output image path passed to ``savefig``.
        cmap: matplotlib colormap name.
        divider: number of grid rows; must be >= 1 and <= seq_len.
        vmin, vmax: optional colour limits. Passing precomputed values avoids
            rescanning a large ``batch_pred`` on every call.

    Raises:
        ValueError: if ``divider`` < 1 or ``batch_gt`` has fewer than
            ``divider`` frames (the grid would have zero columns).
    """
    seq_len, height, width, state_c = batch_gt.shape
    if divider < 1:
        raise ValueError(f"divider must be >= 1, got {divider}")
    t_horizon = seq_len // divider
    if t_horizon == 0:
        raise ValueError(f"need at least divider={divider} frames, got {seq_len}")
    # Row-major split: frames[row, t] == batch_gt[row * t_horizon + t].
    frames = batch_gt[:t_horizon * divider].reshape(divider, t_horizon, height, width, state_c)

    # One colour scale shared by every panel so rows are visually comparable.
    if vmax is None:
        vmax = np.max(batch_pred[..., state_idx])
    if vmin is None:
        vmin = np.min(batch_pred[..., state_idx])

    fig = plt.figure(figsize=(t_horizon + 1, divider))
    try:
        grid = ImageGrid(fig, 111,  # similar to subplot(111)
                         nrows_ncols=(divider, t_horizon),  # divider rows x t_horizon columns
                         axes_pad=0.05,  # pad between axes in inch.
                         share_all=True,
                         cbar_location="right",
                         cbar_mode='edge',  # one colorbar slot per row, on the right edge
                         direction='row',
                         cbar_size="10%",
                         cbar_pad=0.15)
        for row in range(divider):
            for t in range(t_horizon):
                ax = grid[row * t_horizon + t]
                im = ax.imshow(frames[row, t, :, :, state_idx], vmax=vmax, vmin=vmin,
                               cmap=cmap, interpolation='none')
                ax.set_axis_off()
            # Attach the row's colorbar via the last panel in that row.
            grid[row * t_horizon + t_horizon - 1].cax.colorbar(im)
        fig.savefig(path, dpi=72, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure: callers render thousands of images in a loop.
        fig.clf()
        plt.close(fig)


In [2]:
data_path = '/cluster/nvme4a/whh/dataset/sst/data_zone_1.h5'
# `[:]` reads the whole dataset into an in-memory array, so the file can be
# closed immediately instead of keeping the handle open across cells.
with h5py.File(data_path, 'r') as h5_data:
    data = h5_data['mygroup']['mydataset'][:]
# Reverse the axis order so the frame/time axis comes first, as write_image expects.
# NOTE(review): assumes the file stores axes in reversed (column-major, e.g. MATLAB)
# order; confirm whether the last two axes are (H, W) or (W, H).
data = data.transpose(2, 1, 0)
print(data.shape)

(4459, 64, 64)


In [3]:
import os
path = '/cluster/nvme4a/whh/dataset/sst/data_zone_1'
# Output directory for zone 1 frames. savefig does not create directories, so make
# sure it exists; exist_ok keeps re-running this cell harmless.
os.makedirs(path, exist_ok=True)

In [4]:
# Each saved image is a 7-row grid (divider=7) of 13 frames per row.
seq = 13*7
data_len = data.shape[0]
print(seq, data_len)

# Add the trailing channel axis write_image expects; the full array is also
# passed as batch_pred so every image shares one global colour scale.
frames = data[..., None]
n_images = data_len // seq
for idx in range(n_images):
    start, stop = idx * seq, (idx + 1) * seq
    out_file = os.path.join(path, f'frame_{idx}.png')
    write_image(frames[start:stop], frames, 0, path=out_file, cmap='twilight_shifted', divider=7)


91 4459


In [5]:
# Release the HDF5 handle; `data` was read into memory with `[:]`, so it remains
# usable after the file is closed.
h5_data.close()

In [6]:
# Render zones 2..19 exactly like zone 1: 91-frame chunks, 7 rows x 13 frames each.
seq = 13*7  # loop-invariant, hoisted out of the per-zone loop
for file_id in range(2, 20):
    path = f'/cluster/nvme4a/whh/dataset/sst/data_zone_{file_id}'
    print(path)
    # exist_ok: plain os.mkdir raised FileExistsError when re-running or resuming
    # after an interrupted run.
    os.makedirs(path, exist_ok=True)
    data_path = f'/cluster/nvme4a/whh/dataset/sst/data_zone_{file_id}.h5'
    # The context manager closes the file even if anything below raises;
    # `[:]` already copied the dataset into memory.
    with h5py.File(data_path, 'r') as h5_data:
        data = h5_data['mygroup']['mydataset'][:]
    data = data.transpose(2, 1, 0)
    print(data.shape)

    data_len = data.shape[0]
    print(seq, data_len)
    for idx in range(data_len // seq):
        data_batch = data[idx*seq:(idx+1)*seq]
        write_image(data_batch[..., None], data[..., None], 0, path=os.path.join(path, f'frame_{idx}.png'), cmap='twilight_shifted', divider=7)

/cluster/nvme4a/whh/dataset/sst/data_zone_2
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_3
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_4
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_5
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_6
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_7
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_8
(4459, 64, 64)
91 4459
/cluster/nvme4a/whh/dataset/sst/data_zone_9
(4459, 64, 64)
91 4459
