# Import Required Libraries
Import necessary libraries such as matplotlib and pollution_extraction modules.

In [None]:
import matplotlib.pyplot as plt

from pollution_extraction.core import DataExporter
from pollution_extraction.core.data_reader import PollutionDataReader
from pollution_extraction.core.data_visualizer import DataVisualizer
from pollution_extraction.core.spatial_extractor import SpatialExtractor

# Initialize PollutionDataReader
Initialize PollutionDataReader with the NetCDF file path and pollution type.

In [None]:
# NetCDF input (daily downscaled PM2.5, 2006-01 judging by the file name)
# and the reader that wraps it as a "pm25" dataset.
file_path = (
    "/workspaces/dss-pollution-extraction/"
    "PM2p5_downscaled_daily_lr_2006_01.nc"
)
reader = PollutionDataReader(file_path, pollution_type="pm25")

# Inspect Dataset Information
Print dataset shape, basic information, time range, and spatial bounds.

In [None]:
# Underlying data variable and its shape.
data = reader.data_variable
print("\nData variable shape:", data.shape)

# Metadata summary reported by the reader, one "key: value" per line.
info = reader.get_basic_info()
print("\nBasic Info:")
for key, value in info.items():
    print(f"  {key}: {value}")

# Temporal and spatial extent of the dataset.
print("\nTime range:", reader.time_range)
print("Spatial bounds:", reader.spatial_bounds)

# Subset Data by Time
Select a subset of data for the first 7 days and print its shape.

In [None]:
# Keep only time positions 0..6 (the first seven steps of the time axis).
subset = data.isel({"time": slice(0, 7)})
print("\nSubset shape (first 7 days):", subset.shape)

# Visualize First Time Slice
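Clip negative values to 0, print summary statistics, and plot the first time slice with a fixed color range (vmin=0, vmax=40). xarray's plotting is tried first, with a plain matplotlib `imshow` as a fallback.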

In [None]:
# First time step with negative values clipped to 0; the stats below are
# computed on the clipped values.
first_slice = data.isel(time=0).clip(min=0)
print("\nFirst slice stats:")
print("  min:", float(first_slice.min().values))
print("  max:", float(first_slice.max().values))
print("  mean:", float(first_slice.mean().values))
try:
    first_slice.plot.imshow(vmin=0, vmax=40, cmap="Reds", origin="upper")
    plt.title("First Time Slice (time=0)")
    plt.show()
except Exception as err:  # deliberate best-effort: fall back to plain matplotlib
    # Say why xarray plotting failed instead of silently hiding it.
    print(
        f"xarray plotting failed ({type(err).__name__}: {err}); "
        "falling back to plt.imshow"
    )
    # Discard anything the failed attempt may already have drawn, so the
    # fallback does not stack on top of a half-built plot.
    plt.clf()
    plt.imshow(first_slice.values, origin="upper", vmin=0, vmax=40, cmap="Reds")
    plt.title("First Time Slice (time=0) [imshow fallback]")
    plt.colorbar()
    plt.show()

# Compute Monthly Average
Calculate the mean over all time steps in the file and visualize it. Because the file covers a single month (January 2006), this is the monthly average.

In [None]:
# Collapse the time axis into one mean map, then floor negative values at 0.
dataset = reader.dataset
var_name = reader.variable_info["var_name"]
time_avg = dataset[var_name].mean(dim="time")
time_avg = time_avg.clip(min=0)
print("\nTime-averaged (monthly mean) shape:", time_avg.shape)

time_avg.plot.imshow(cmap="RdYlBu_r", vmin=0, vmax=40, origin="upper")
plt.title("Monthly Mean (Time-Averaged) PM2.5")
plt.show()

# Extract Spatial Point
Extract the value at the center of the spatial domain using SpatialExtractor.

In [None]:
spatial_ext = SpatialExtractor(dataset, var_name)

# Domain midpoint: lower bound plus half the extent along each axis.
bounds = reader.spatial_bounds
x_center = float(bounds["x_min"] + (bounds["x_max"] - bounds["x_min"]) / 2)
y_center = float(bounds["y_min"] + (bounds["y_max"] - bounds["y_min"]) / 2)

try:
    # Nearest-grid-cell lookup at the single midpoint.
    point_result = spatial_ext.extract_points(
        [(x_center, y_center)], method="nearest"
    )
    print(
        f"\nExtracted value at domain center (x={x_center:.1f}, y={y_center:.1f}):\n",
        point_result,
    )
except KeyError as err:
    print(
        f"\n[SpatialExtractor] Extraction failed: {err}\nCheck if the coordinates are within the valid range and match the dataset's CRS."
    )

# Export Data to GeoTIFF
Export the time-averaged map to a GeoTIFF file using DataExporter.

In [None]:
import os

# Write the time-averaged map as a GeoTIFF next to the input NetCDF file
# (derived from `file_path` so the two paths cannot drift apart).
exporter = DataExporter(dataset, var_name)
output_path = os.path.join(os.path.dirname(file_path), "monthly_mean_pm25.tif")
exporter.to_geotiff(
    output_path,
    # Select every time step and aggregate with a mean; presumably this
    # yields a single averaged raster -- confirm in DataExporter.to_geotiff.
    time_index=slice(None),
    aggregation_method="mean",
)
# Report what was actually written (the call above performs the export).
print(f"\n[DataExporter] Exported time-averaged GeoTIFF to: {output_path}")

# Custom Visualization
Use DataVisualizer to create a custom plot for a specific time index.

# Animate Time Slices
After the custom plot below, the remaining cells animate the evolution of PM2.5 over time: first with matplotlib's animation module, then with DataVisualizer's animation helpers.

In [None]:
# Render one day with DataVisualizer's own plotting routine.
visualizer = DataVisualizer(dataset, var_name, reader.pollution_type)
# Position along the time axis, assumed 0-based like `isel` above (so 2 is
# the third day). The title is derived from it so the two cannot disagree.
time_index = 2
fig = visualizer.plot_spatial_map(
    time_index=time_index,
    vmin=0,
    vmax=150,
    title=f"PM2.5 Day {time_index + 1} (Visualizer)",
    cmap="RdYlBu_r",
)
plt.show()
# Do not close the reader here: later cells still read from `dataset` and
# `reader.dataset`. Call `reader.close()` only after the last cell that
# uses the data has run.

In [None]:
import numpy as np
from matplotlib import animation
from IPython.display import HTML

# Prepare data for animation: negative values are clipped to 0, matching the
# static maps above.
var_data = dataset[var_name].clip(min=0)
# Count frames along the named "time" dimension (what `animate` indexes with
# `isel(time=i)`), rather than assuming time is axis 0.
num_times = var_data.sizes["time"]

fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(
    var_data.isel(time=0).values, vmin=0, vmax=40, cmap="RdYlBu_r", origin="upper"
)
ax.set_title("PM2.5 Animation")
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("PM2.5 (ug/m3)")


def animate(i):
    """Show time step ``i`` and label it as 1-based day ``i + 1``."""
    im.set_array(var_data.isel(time=i).values)
    ax.set_title(f"PM2.5 Day {i + 1}")
    return [im]


# blit=False: with blitting, only the returned artists (the image) are redrawn,
# so the per-frame title change would not show during interactive playback.
ani = animation.FuncAnimation(
    fig, animate, frames=num_times, interval=400, blit=False
)
plt.close(fig)  # Prevent duplicate static image
HTML(ani.to_jshtml())

In [None]:
# If data_visualizer.py was edited, restart the kernel (Kernel -> Restart &
# Run All) so the updated class is picked up. Import from the same package
# path as the first cell; a bare `data_visualizer` module is not the one
# the rest of this notebook uses.
from pollution_extraction.core.data_visualizer import DataVisualizer

dataset = reader.dataset
var_name = reader.variable_info["var_name"]
visualizer = DataVisualizer(dataset, var_name, pollution_type="pm25")
html_animation = visualizer.create_jupyter_animation(clip_min=0)
html_animation


In [None]:
# More control over parameters: explicit figure size, color range, frame
# interval, and title template, returned as HTML for inline display.
html_animation = visualizer.create_animation(
    figsize=(8, 6),
    vmin=0,
    vmax=50,
    interval=300,
    return_html=True,
    title_template="{var_title} - {date}",
    clip_min=0,
)
# End the cell with the result so Jupyter actually renders the animation.
html_animation


In [None]:
# Save as GIF: write the animation to "pollution_animation.gif" in the
# current working directory, at fps=2 and dpi=150.
visualizer.create_animation(
    output_path="pollution_animation.gif",
    fps=2,
    dpi=150,
)


In [None]:
# Use the enhanced create_animation method with the same figure size, color
# range, interval, and origin as the hand-built matplotlib animation above,
# returning HTML (no file output) for inline display.
visualizer = DataVisualizer(dataset, var_name, pollution_type="pm25")
animation_options = {
    "output_path": None,
    "figsize": (6, 5),
    "vmin": 0,
    "vmax": 40,
    "interval": 400,
    "return_html": True,
    "origin": "upper",
    "clip_min": 0,
    "title_template": "PM2.5 Day {frame}",
}
html_animation = visualizer.create_animation(**animation_options)
html_animation
