In [1]:
import xarray as xr
import torch
from xbatcher import BatchGenerator
import matplotlib.pyplot as plt
from utils.general import load_config
import pandas as pd

# Load the project configuration; later cells read the Zarr URL and the
# latitude/longitude bounds from config["dataset"].
config = load_config()

In [2]:
# Open the high-resolution Zarr store whose URL is given in the project config.
data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr",
    storage_options={"client_kwargs": {"trust_env": "true"}},
    chunks={},
)

# Time window of interest and the spatial bounding box from the config.
start_date, end_date = "2025-03-01", "2025-03-14"
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])

# Restrict the dataset to that box and window in a single label-based selection.
hr_data = data.sel(
    {
        "latitude": slice(latitude_range[0], latitude_range[1]),
        "longitude": slice(longitude_range[0], longitude_range[1]),
        "time": slice(start_date, end_date),
    }
)

# Names of the variables available in the subset.
data_vars = [var_name for var_name in hr_data.data_vars]

In [3]:
# Show the notebook repr of the subset selected in the previous cell.
hr_data

# Benchmark

In [4]:
import time
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Benchmark: how long does it take to pull batches through xbatcher as the
# number of variables in the dataset grows?
num_trials = 10  # Number of repetitions per variable count
num_vars_list = []
time_avg_list = []
time_std_list = []
size_data = []
batch = 64  # Number of time steps per batch (must stay an int, see inner loop)

# The dataset handle and the space/time subset do not depend on the number of
# variables, so open and slice once instead of on every loop iteration.
data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr", storage_options={"client_kwargs": {"trust_env": "true"}},
    chunks={})
data_vars = list(data.data_vars)
start_date = "2025-03-01"
end_date = "2025-03-05"
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])
data = data.sel(time=slice(start_date, end_date))
data = data.sel(latitude=slice(latitude_range[0], latitude_range[1]),
                longitude=slice(longitude_range[0], longitude_range[1]))

# Never request more variables than the dataset has: data_vars[:n] would
# silently truncate and the result would be recorded under the wrong count.
max_vars = min(9, len(data_vars))

for num_vars in range(1, max_vars + 1):

    selected_vars = data_vars[:num_vars]

    hr_data_subset = data[selected_vars]  # Subset dataset
    batch_generator_hr = BatchGenerator(hr_data_subset, input_dims={
        "time": batch,
        "latitude": hr_data_subset.sizes["latitude"],
        "longitude": hr_data_subset.sizes["longitude"]
    })

    times = []
    size = 0.0  # MiB loaded during the most recent trial

    for _ in range(num_trials):  # Run multiple trials
        size = 0.0
        data_batch = None
        start_time = time.perf_counter()  # monotonic clock for interval timing

        # The loop variable must not be called `batch`: that would overwrite
        # the integer batch size used to build the next BatchGenerator.
        for batch_ds in batch_generator_hr:

            data_batch = batch_ds.load()
            # Accumulate over every batch so the size matches the timed work.
            size += data_batch.nbytes / (1024 * 1024)

        elapsed_time = time.perf_counter() - start_time
        times.append(elapsed_time)
        print(size)

        # Drop the reference to the last loaded batch before the next trial.
        data_batch = None

    avg_time = np.mean(times)
    std_time = np.std(times)

    num_vars_list.append(num_vars)
    time_avg_list.append(avg_time)
    time_std_list.append(std_time)
    size_data.append(size)
    print(f"Num Vars: {num_vars}, Avg Time: {avg_time:.4f} sec, Std Dev: {std_time:.4f} sec")

In [None]:
# Load time against number of variables, drawn with error bars and saved as PNG.
fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(
    num_vars_list,
    time_avg_list,
    yerr=time_std_list,
    fmt='-o',
    capsize=4,
    label="Avg Time ± Std Dev",
)
ax.set_xlabel("Number of Data Variables")
ax.set_ylabel("Time to Load (seconds)")
ax.set_title("Time to Load vs Number of Data Variables")
ax.grid()
ax.legend()
fig.savefig("load_vs_parameters.png")
plt.show()

In [None]:
# Reference link speed (megabits per second) used for the extrapolation column.
max_bandwidth_Mbps = 25000

# One row per benchmarked variable count, with derived throughput columns.
# Columns are added in order, so later lambdas may read earlier ones.
df = pd.DataFrame(
    {
        "climate_variables": num_vars_list,
        "time_avg": time_avg_list,
        "time_std": time_std_list,
        "size_data": size_data,
    }
).assign(
    batch=64,
    # frames per second = (variables * time steps per batch) / mean load time
    fps=lambda t: t["climate_variables"] * t["batch"] / t["time_avg"],
    # size_data divided by mean load time, then the same value times 8
    bandwidth_MBps=lambda t: t["size_data"] / t["time_avg"],
    bandwidth_Mbps=lambda t: t["size_data"] / t["time_avg"] * 8,
    max_bandwidth_Mbps=max_bandwidth_Mbps,
)

In [None]:
# Scale the measured fps by the ratio of the reference bandwidth to the
# measured bandwidth (same operation order: multiply first, then divide).
df["max_fps"] = df["max_bandwidth_Mbps"].mul(df["fps"]).div(df["bandwidth_Mbps"])

In [None]:
# Show the benchmark summary table.
df

## Generate Animation at Netflix speed

In [17]:
# Re-open the high-resolution Zarr store for the animation section.
data = xr.open_dataset(
    config["dataset"]["hr_zarr_url"],
    engine="zarr",
    storage_options={"client_kwargs": {"trust_env": "true"}},
    chunks={},
)

# Time window to animate and the spatial bounding box from the config.
start_date, end_date = "2025-03-01", "2025-03-15"
latitude_range = tuple(config["dataset"]["latitude_range"])
longitude_range = tuple(config["dataset"]["longitude_range"])

# Restrict the dataset to that box and window in a single label-based selection.
hr_data = data.sel(
    {
        "latitude": slice(latitude_range[0], latitude_range[1]),
        "longitude": slice(longitude_range[0], longitude_range[1]),
        "time": slice(start_date, end_date),
    }
)

# Names of the variables available in the subset.
data_vars = [var_name for var_name in hr_data.data_vars]

hr_data

## Basic Visualization

In [18]:
import os
import yaml
from IPython.display import HTML
from loguru import logger
from matplotlib.animation import FuncAnimation
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import numpy as np
import re

In [19]:
def find_coord_name(coord_names, pattern):
    """
    Return the first coordinate name matched by a compiled regex.

    Names are tried in iteration order using ``pattern.search``; ``None`` is
    returned when no name matches.
    """
    matching = (candidate for candidate in coord_names if pattern.search(candidate))
    return next(matching, None)


class Visualizer:
    """Static helpers for plotting a gridded, time-indexed xarray DataArray on a map."""

    @staticmethod
    def get_metadata(data):
        """
        Extract relevant metadata (title, colorbar label) from dataset attributes.

        Parameters:
        - data: object exposing an ``attrs`` mapping (e.g. an xarray DataArray)

        Returns:
        - (title, colorbar_label): ``title`` is built from ``long_name`` and
          ``colorbar_label`` from ``standard_name`` (falling back to
          ``long_name``); the units are appended to both when present.
        """
        long_name = data.attrs.get("long_name", "Unknown Variable")
        standard_name = data.attrs.get("standard_name", long_name)
        units = data.attrs.get("units", "")
        title = f"{long_name} ({units})" if units else long_name
        colorbar_label = f"{standard_name} [{units}]" if units else standard_name
        return title, colorbar_label


    @staticmethod
    def generate_animation(data, cmap="YlOrRd", show_coastlines=True, fps=30):
        """
        Generate an animation with a customizable FPS.

        Parameters:
        - data: xarray DataArray to animate; must have a 'time' dimension and
          latitude/longitude coordinates (matched by name, case-insensitive)
        - cmap: Colormap for visualization (default: 'YlOrRd')
        - show_coastlines: Show/hide coastlines (default: True)
        - fps: Frames per second for the animation (default: 30)

        Returns:
        - (animation, html): the FuncAnimation object and an IPython HTML
          object wrapping the rendered HTML5 video.

        Raises:
        - ValueError: if fps is not positive, if the data has no 'time'
          dimension, or if latitude/longitude coordinates cannot be found.
        """
        # Validate up front: the frame interval below divides by fps.
        if fps <= 0:
            raise ValueError("fps must be a positive number.")

        if "time" not in data.dims:
            raise ValueError("Data must have a 'time' dimension to animate.")

        lat_pattern = re.compile(r'lat(itude)?', re.IGNORECASE)
        lon_pattern = re.compile(r'lon(gitude)?', re.IGNORECASE)
        coord_names = data.coords.keys()

        lat_name = find_coord_name(coord_names, lat_pattern)
        lon_name = find_coord_name(coord_names, lon_pattern)

        if lat_name is None or lon_name is None:
            raise ValueError("Latitude and/or Longitude coordinates not found.")

        title, colorbar_label = Visualizer.get_metadata(data)

        # Use one colour scale for the whole series. Without explicit limits
        # the scale is taken from frame 0 only and later frames can saturate.
        vmin, vmax = float(data.min()), float(data.max())
        if not (np.isfinite(vmin) and np.isfinite(vmax)):
            vmin = vmax = None  # fall back to matplotlib's autoscaling

        fig = plt.figure(figsize=(19.2, 10.8), dpi=100)  # 1920x1080 pixels
        ax = fig.add_subplot(projection=ccrs.PlateCarree())

        if show_coastlines:
            ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
            ax.add_feature(cfeature.BORDERS, linewidth=0.5)
            ax.add_feature(cfeature.LAND)

        # A single mesh is drawn and then updated per frame. A second static
        # copy of frame 0 would stay underneath and show through wherever a
        # later frame is masked.
        mesh = ax.pcolormesh(data[lon_name], data[lat_name], data.isel(time=0),
                             cmap=cmap, vmin=vmin, vmax=vmax,
                             transform=ccrs.PlateCarree())

        # Set axis limits to match data
        ax.set_xlim([float(data[lon_name].min()), float(data[lon_name].max())])
        ax.set_ylim([float(data[lat_name].min()), float(data[lat_name].max())])

        # Align colorbar width with heatmap
        pos = ax.get_position()
        cbar_ax = fig.add_axes([pos.x0, pos.y0 - 0.05, pos.width, 0.02])
        cbar = fig.colorbar(mesh, cax=cbar_ax, orientation='horizontal')
        cbar.set_label(colorbar_label)
        ax.set_title(title)

        def update(frame):
            """Swap in the values of the given time step and refresh the title."""
            time_str = np.datetime_as_string(data.time[frame].values, unit='h')
            mesh.set_array(data.isel(time=frame).values.flatten())
            ax.set_title(f"{title} - {time_str}")
            return mesh,

        interval = 1000 / fps  # Convert FPS to milliseconds
        # blit=False: the title changes every frame but is not among the
        # artists returned by update(), so blitting would leave it stale.
        animation = FuncAnimation(fig, update, frames=len(data.time), interval=interval, blit=False)

        # Close this specific figure so the notebook does not also show a static copy.
        plt.close(fig)
        return animation, HTML(animation.to_html5_video())

In [20]:
# Load the 'v10' variable of hr_data into memory and keep a reference as `data`.
data = hr_data['v10'].load()

In [None]:
# Build the v10 animation at 20 fps with a blue colormap and coastlines;
# `video` is the HTML-wrapped rendering returned alongside the animation.
animation, video = Visualizer.generate_animation(
    data=hr_data['v10'].load(),
    cmap="Blues",
    show_coastlines=True,
    fps=20,
)

In [15]:
# Show the HTML video object returned by generate_animation.
video

In [13]:
# Write the animation to 'v10_20fps.mp4' with the ffmpeg writer at 20 fps.
animation.save('v10_20fps.mp4', writer='ffmpeg', fps=20)

## Transform `*.mp4` to `*.gif`

In [None]:
import os

# Import only the name that is used instead of star-importing the package.
from moviepy import VideoFileClip

# NOTE(review): absolute, machine-specific path; adjust for your environment.
directory_path = '/home/ubuntu/project/destine-super-resolution/'

# List all MP4 files in the directory
mp4_files = [f for f in os.listdir(directory_path) if f.endswith('.mp4')]
print("MP4 files in the directory:")
for mp4_file in mp4_files:
    print(f"  {mp4_file}")

for mp4_file in mp4_files:
    # os.listdir returns bare names, so build the full path explicitly;
    # otherwise the clip only opens when the kernel's cwd is directory_path.
    source_path = os.path.join(directory_path, mp4_file)

    # Swap only the extension (str.replace would touch every '.mp4' in the name)
    # and write the GIF next to its source video.
    gif_file = os.path.splitext(source_path)[0] + '.gif'

    # Named `clip` so it does not overwrite the notebook's `video` HTML object.
    clip = VideoFileClip(source_path)
    try:
        # Write the video to a GIF file
        clip.write_gif(gif_file, fps=30)  # You can adjust the fps as needed
    finally:
        # Release the reader resources held by the clip.
        clip.close()