In [1]:
import os
import xarray as xr
import numpy as np
import imageio
import matplotlib.pyplot as plt 

In [None]:
# Root directory of the datasets; visualize_variable reads files from
# <root_dir>/<resolution>/<variable>/<variable>_<year>_<resolution>.nc
root_dir = '/datadrive/datasets'

def visualize_variable(variable, resolution, year, level=None):
    """Plot the first time step of one variable and save it as a PNG.

    The NetCDF file is read from
    ``<root_dir>/<resolution>/<variable>/<variable>_<year>_<resolution>.nc``.
    The image is written to the current working directory as
    ``<variable>.png``, or ``<variable>_<level>.png`` when ``level`` is given.
    The figure is not closed here, so an inline notebook backend can still
    display it after the call.

    Args:
        variable: Variable name; used for the directory, the input file name
            and the output file name.
        resolution: Resolution string exactly as it appears in the path,
            e.g. ``'1.40625deg'``.
        year: Year of the file to open.
        level: Value of the ``level`` coordinate to plot. Required when the
            stacked data is 5-D (pressure-level variables).

    Raises:
        ValueError: If the stacked data is 5-D but ``level`` is None.
    """
    file_name = '_'.join([variable, str(year), resolution]) + '.nc'
    path = os.path.join(root_dir, resolution, variable, file_name)

    # Use the dataset as a context manager so the file handle is released as
    # soon as the slice we need has been copied into a NumPy array.
    with xr.open_dataset(path) as ds:
        # Stack the data variables once and reuse the result for both the
        # dimensionality check and the data extraction.
        stacked = ds.to_array()
        if stacked.ndim == 5:  # pressure-level variables
            if level is None:
                raise ValueError(
                    "'level' must be given for pressure-level variable "
                    + repr(variable)
                )
            stacked = stacked.sel(level=[level])
            # assumes axis order (variable, time, level, lat, lon) - TODO confirm
            np_data = stacked.to_numpy()[0, 0, 0]  # get the first hour of the year
        else:
            # assumes axis order (variable, time, lat, lon) - TODO confirm
            np_data = stacked.to_numpy()[0, 0]  # get the first hour of the year

    fig, ax = plt.subplots(1, 1)
    ax.imshow(np_data, cmap=plt.cm.RdBu)
    fig.tight_layout()
    # Hide every tick mark and tick label on both axes so only the field shows.
    ax.tick_params(
        axis='both',
        which='both',       # both major and minor ticks are affected
        bottom=False,       # x-axis: no ticks on the bottom edge
        top=False,          # x-axis: no ticks on the top edge
        labelbottom=False,  # x-axis: no labels on the bottom edge
        left=False,         # y-axis: no ticks on the left edge
        right=False,        # y-axis: no ticks on the right edge
        labelleft=False,    # y-axis: no labels on the left edge
    )

    # The output name carries the level whenever one was passed in.
    stem = variable if level is None else variable + '_' + str(level)
    fig.savefig(stem + '.png', bbox_inches='tight')
   
# variables = ['']
# Example run: plot specific humidity at level 925 for 2015 on the 1.40625deg
# grid; the image is written to ./specific_humidity_925.png
visualize_variable('specific_humidity', '1.40625deg', 2015, level=925)

In [26]:
# Dataset root used by visualize_tokens to build input file paths
# (<root_dir>/<resolution>/<variable>/...); same value as in the earlier cell.
root_dir = '/datadrive/datasets'

def visualize_tokens(variable, resolution, year, level=None, patch_size=32, num_tokens=16):
    """Save square patches ("tokens") of a variable's first time step as PNGs.

    The NetCDF file is read from
    ``<root_dir>/<resolution>/<variable>/<variable>_<year>_<resolution>.nc``.
    The 2-D field is cut into non-overlapping ``patch_size`` x ``patch_size``
    tiles in row-major order (left to right, then top to bottom); rows and
    columns that do not fill a whole tile are dropped. Tile ``k`` (1-based) is
    written to ``<variable>/token_<k>.png``, or
    ``<variable>_<level>/token_<k>.png`` when ``level`` is given. The output
    directory is created if needed, relative to the working directory.

    Args:
        variable: Variable name; used for the input path and output directory.
        resolution: Resolution string exactly as it appears in the path,
            e.g. ``'1.40625deg'``.
        year: Year of the file to open.
        level: Value of the ``level`` coordinate to use. Required when the
            stacked data is 5-D (pressure-level variables).
        patch_size: Side length of each tile in grid cells; must be positive.
        num_tokens: Maximum number of tiles to save. ``None`` saves every
            tile; a non-positive value saves nothing.

    Raises:
        ValueError: If ``patch_size`` is not positive, or if the stacked data
            is 5-D but ``level`` is None.
    """
    if patch_size <= 0:
        raise ValueError('patch_size must be positive, got ' + repr(patch_size))

    file_name = '_'.join([variable, str(year), resolution]) + '.nc'
    path = os.path.join(root_dir, resolution, variable, file_name)

    # Use the dataset as a context manager so the file handle is released as
    # soon as the slice we need has been copied into a NumPy array.
    with xr.open_dataset(path) as ds:
        # Stack the data variables once and reuse the result for both the
        # dimensionality check and the data extraction.
        stacked = ds.to_array()
        if stacked.ndim == 5:  # pressure-level variables
            if level is None:
                raise ValueError(
                    "'level' must be given for pressure-level variable "
                    + repr(variable)
                )
            stacked = stacked.sel(level=[level])
            # assumes axis order (variable, time, level, lat, lon) - TODO confirm
            np_data = stacked.to_numpy()[0, 0, 0]  # get the first hour of the year
        else:
            # assumes axis order (variable, time, lat, lon) - TODO confirm
            np_data = stacked.to_numpy()[0, 0]  # get the first hour of the year

    h, w = np_data.shape
    n_rows, n_cols = h // patch_size, w // patch_size
    total = n_rows * n_cols
    if num_tokens is not None:
        total = min(total, num_tokens)
    if total <= 0:
        # Nothing to save; do not create an empty output directory.
        return

    # The output directory does not depend on the tile, so create it once.
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    out_dir = variable if level is None else variable + '_' + str(level)
    os.makedirs(out_dir, exist_ok=True)

    for index in range(total):
        # Row-major tile order: index = row * n_cols + col.
        row, col = divmod(index, n_cols)
        start_h = row * patch_size
        start_w = col * patch_size
        patch = np_data[start_h:start_h + patch_size, start_w:start_w + patch_size]

        fig, ax = plt.subplots(1, 1)
        try:
            ax.imshow(patch, cmap=plt.cm.RdBu)
            fig.tight_layout()
            # Hide every tick mark and tick label on both axes.
            ax.tick_params(
                axis='both',
                which='both',       # both major and minor ticks are affected
                bottom=False,       # x-axis: no ticks on the bottom edge
                top=False,          # x-axis: no ticks on the top edge
                labelbottom=False,  # x-axis: no labels on the bottom edge
                left=False,         # y-axis: no ticks on the left edge
                right=False,        # y-axis: no ticks on the right edge
                labelleft=False,    # y-axis: no labels on the left edge
            )
            save_name = 'token_' + str(index + 1) + '.png'
            fig.savefig(os.path.join(out_dir, save_name), bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails, so a long run
            # does not accumulate open figures.
            plt.close(fig)
            

In [28]:
# Save the first 16 tiles (32x32) of 2015 specific humidity at level 850 as
# token_1.png ... token_16.png under ./specific_humidity_850/
visualize_tokens('specific_humidity', '1.40625deg', 2015, 850, 32, 16)