# Work with obsSpace

JEDI or GSI generates observation-space diagnostic files, which contain the original observation information, hofx (H(x), i.e. the model counterparts at the observation locations), OMB (observation minus background), and OMA (observation minus analysis), as well as other information.
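
As a minimal illustration of these quantities, the sketch below computes OMB and OMA from made-up numbers. The array names (`obs`, `hofx_bg`, `hofx_an`) and the values are assumptions for illustration only; they are not read from a jdiag file.

```python
import numpy as np

# Synthetic values for illustration only (temperatures in K)
obs = np.array([280.0, 275.5, 290.2])      # observed values
hofx_bg = np.array([279.0, 276.5, 288.2])  # H(x) of the background at the obs locations
hofx_an = np.array([279.6, 275.9, 289.4])  # H(x) of the analysis at the obs locations

omb = obs - hofx_bg  # observation minus background
oma = obs - hofx_an  # observation minus analysis

print("OMB:", omb)
print("OMA:", oma)
```

In this toy case the OMA values are smaller in magnitude than the OMB values, which is what we expect when the analysis has been drawn toward the observations.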

This notebook covers how to work with JEDI diagnostic files (also called output ioda files). We call them `jdiag` files.

## import packages and define variables

In [None]:
%%time 

# autoload external python modules if they changed
%load_ext autoreload
%autoreload 2

import os, sys
pyDAmonitor_ROOT=os.getenv("pyDAmonitor_ROOT")
if pyDAmonitor_ROOT is None:
    print("!!! pyDAmonitor_ROOT is NOT set. Run `source ush/load_pyDAmonitor.sh`")
else:
    print(f"pyDAmonitor_ROOT={pyDAmonitor_ROOT}\n")
    sys.path.insert(0, pyDAmonitor_ROOT)  # only extend sys.path when the root is set

# import modules
import warnings
import math
import numpy as np
import uxarray as ux
import xarray as xr
import pandas as pd
import seaborn as sns  # seaborn handles NaN values automatically
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from DAmonitor import query_dataset, query_data, query_obj, to_dataframe

jdiag_file=f"{pyDAmonitor_ROOT}/data/samples/mpasjedi/jdiag_aircar_t133.nc"

## Use NetCDF4 to read jdiag files

In [None]:
dataset=Dataset(jdiag_file, mode='r')
query_dataset(dataset)

Now we see that a jdiag file (or output ioda file) has a group/variable structure.

This group/variable structure is not convenient to use.    
So we define an `obsSpace` class under `DAmonitor` to make jdiag files easier to work with.

## Use the `obsSpace` class to read jdiag files

In [None]:
from DAmonitor import obsSpace, fit_rate

In [None]:
# create a t133 object using the obsSpace class
t133=obsSpace(jdiag_file)

In [None]:
# check the t133 object
query_obj(t133)

In [None]:
# play with some object attributes
(
t133.filepath,
t133.nlocs,
)

In [None]:
# query dataset
query_dataset(t133.dataset)  # the original jdiag file dataset

In [None]:
# query data
query_data(t133.q)  # the t133 object does not contain q observations, so this displays only meta information

In [None]:
# query data
query_data(t133.t)

The above output shows that the data are reorganized by observation variable, i.e. `t`, `q`, `uv`, `bt`.   
We can access values easily through a familiar Python attribute structure, e.g. `t133.t`, `t133.t.ObsValue`, etc.
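
To make the idea of this reorganization concrete, here is a small sketch that turns a flat group/variable layout into variable-first attribute access. It is not the `obsSpace` implementation: the `raw` dictionary, the `short_names` mapping, and the `regroup` helper are all hypothetical and exist only for illustration.

```python
from types import SimpleNamespace
import numpy as np

# Hypothetical flat "group/variable" layout, loosely mimicking a jdiag file
raw = {
    "ObsValue/airTemperature": np.array([280.0, 275.5]),
    "hofx0/airTemperature": np.array([279.0, 276.5]),
    "ObsValue/specificHumidity": np.array([0.004, 0.006]),
    "hofx0/specificHumidity": np.array([0.005, 0.005]),
}

# Hypothetical mapping from long variable names to short attribute names
short_names = {"airTemperature": "t", "specificHumidity": "q"}

def regroup(raw, short_names):
    """Turn group/variable keys into obj.<variable>.<group> attributes."""
    obs = SimpleNamespace()
    for key, values in raw.items():
        group, variable = key.split("/")
        short = short_names[variable]
        if not hasattr(obs, short):
            setattr(obs, short, SimpleNamespace())
        setattr(getattr(obs, short), group, values)
    return obs

demo = regroup(raw, short_names)
print(demo.t.ObsValue)  # variable first, then group
print(demo.q.hofx0)
```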

In [None]:
# print out array values
np.set_printoptions(threshold=500) # don't print out all array values
print(t133.t.ObsValue)

In [None]:
# this will generate lots of output, so the following two lines are commented out by default. Uncomment them to see the results
# np.set_printoptions(threshold=np.inf)  # print out all array values
# print(t133.t.longitude)

## Convert to Pandas DataFrame
Converting jdiag data into a Pandas DataFrame brings many benefits, since we can then use the mature capabilities that DataFrames provide.    
We can see this from the following cell.
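
Before looking at the real data, the sketch below shows a few of those DataFrame capabilities on a tiny synthetic table. The column names `height`, `ombg`, and `oman` follow the ones used later in this notebook, but the values are made up and `demo_df` is only a stand-in for what `to_dataframe` returns.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for a jdiag DataFrame (values are made up)
demo_df = pd.DataFrame({
    "height": [500.0, 1500.0, 1800.0, np.nan],
    "ombg": [1.0, -2.0, 0.5, 0.3],
    "oman": [0.5, -1.0, np.nan, 0.1],
})

summary = demo_df[["ombg", "oman"]].describe()     # count, mean, std, quartiles
valid = demo_df.dropna(subset=["height", "oman"])  # keep complete rows only
below_2km = valid[valid["height"] < 2000.0]        # boolean filtering

print(summary)
print(below_2km)
```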

In [None]:
df = to_dataframe(t133.t)
df

## Plot histograms

### example 1

In [None]:
plt.figure(figsize=(8, 5))
#sns.histplot(df["oman"], bins=50, kde=True, color="steelblue")
sns.histplot(t133.t.oman, bins=100, kde=True, color="steelblue")
plt.title("Histogram of oman")
plt.xlabel("oman values")
plt.ylabel("Count")  # histplot's default stat is "count", not density
plt.tight_layout()
plt.show()

### example 2

In [None]:
df_long = df[["oman", "ombg"]].melt(var_name="variable", value_name="value")

plt.figure(figsize=(8, 5))
sns.histplot(data=df_long, x="value", hue="variable", bins=50, kde=True, element="step", stat="count")
plt.title("Overlaid Histogram: oman vs ombg")
plt.tight_layout()
plt.show()

### example 3

In [None]:
plt.figure(figsize=(8, 5))
sns.histplot(df["oman"], bins=100, kde=True, color="blue", label="oman", multiple="layer")
sns.histplot(df["ombg"], bins=100, kde=True, color="red", label="ombg", multiple="layer")

plt.title("Overlaid Histogram: oman vs ombg")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.show()

## plot fitting rate
The rate of fitting to observations is an important metric for evaluating data assimilation performance.    
We don't want a small fitting rate, which means the observations have very little impact on the model forecast.   
We don't want a large fitting rate either, which means the analysis fits the observations too closely and the model forecast will crash.   
Usually, a fitting rate of 20% to 30% is expected.
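
The sketch below shows one way such a rate could be computed per height bin. The definition used here, (RMS(ombg) - RMS(oman)) / RMS(ombg), is an assumption made for illustration; the actual `fit_rate` function in `DAmonitor` may define or bin it differently. The helper `fit_rate_sketch` and the data are hypothetical.

```python
import numpy as np
import pandas as pd

def fit_rate_sketch(df, dz=1000):
    """Assumed definition: (RMS(ombg) - RMS(oman)) / RMS(ombg) per height bin."""
    valid = df.dropna(subset=["height", "ombg", "oman"]).copy()
    # Label each row with the lower edge of its dz-metre height bin
    valid["height_bin"] = (np.floor(valid["height"] / dz) * dz).astype(int)
    # Root-mean-square of each departure within each bin
    squared = valid[["ombg", "oman"]] ** 2
    rms = np.sqrt(squared.groupby(valid["height_bin"]).mean())
    rms["fit_rate"] = (rms["ombg"] - rms["oman"]) / rms["ombg"]
    return rms.reset_index()

# Made-up departures for two 1000 m bins
demo = pd.DataFrame({
    "height": [200.0, 800.0, 1200.0, 1900.0],
    "ombg": [2.0, -2.0, 1.0, -1.0],
    "oman": [1.5, -1.5, 0.8, -0.8],
})
result = fit_rate_sketch(demo)
print(result)
```

With these numbers the lower bin has a fit rate of 25% and the upper bin 20%, i.e. both inside the range described above.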

In [None]:
# Filter valid data (both 'oman' and 'ombg' are not NaN)
valid_df = df[df["oman"].notna() & df["ombg"].notna()].copy()
valid_df = valid_df.dropna(subset=["height"])  # removes any rows in valid_df where height is missing (NaN)
# print(valid_df[valid_df["height"] < 0]["height"])   # negative height

In [None]:
dz = 1000
grouped = fit_rate(t133.t, dz=dz)

# Plot vertical profile of fit_rate vs height
plt.figure(figsize=(7, 6))
plt.plot(grouped["fit_rate"], grouped["height_bin"], marker="o", color="blue")
# plt.axvline(x=0, color="gray", linestyle="--")  # ax vertical line

plt.xlabel("Fit Rate (%)")  # label change
plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x*100:.0f}%'))  # format as %
plt.ylabel("Height Bin (m)")
plt.title("Vertical Profile of Fit Rate")

# Fine-tune ticks
plt.xticks(np.arange(0, 0.25, 0.05))  #, fontsize=12)
plt.yticks(np.arange(0, 13000, dz))  #, fontsize=12)
# Add minor ticks
from matplotlib.ticker import AutoMinorLocator
plt.gca().xaxis.set_minor_locator(AutoMinorLocator())
plt.gca().yaxis.set_minor_locator(AutoMinorLocator())
# plt.grid(which='both', linestyle='--', linewidth=0.5)
plt.grid(True)

plt.ylim(0, 13000)  # set y-axis from 0 (bottom) to 13,000 (top)
plt.tight_layout()
plt.show()

In [None]:
print(grouped["height_bin"])

## plot satellite observations

### load satellite data using `obsSpace`

In [None]:
%%time

cris_file = f"{pyDAmonitor_ROOT}/data/samples/mpasjedi/jdiag_cris-fsr_n20.nc"
obsCris = obsSpace(cris_file)

In [None]:
query_data(obsCris.bt, meta_exclude="sensorCentralWavenumber_") # removing the meta_exclude parameter will output all sensorCentralWavenumber_* information

In [None]:
print(obsCris.bt.hofx0)

### assemble target obs array

In [None]:
ch = 61   # channel index to plot
idx = []  # indices of the locations to keep
for n in range(len(obsCris.bt.ombg[:, ch])):
    # if obsCris.bt.CloudDetectMinResidualIR[n, ch] == 1:
    # keep only locations whose OMB lies strictly within (-200, 200)
    if -200 < obsCris.bt.ombg[n, ch] < 200:
        idx.append(n)
ncount = len(idx)

lat = obsCris.bt.latitude[idx]
lon = obsCris.bt.longitude[idx]
obarray = obsCris.bt.DerivedObsValue[idx, ch]
print(lon, lat, obarray)
print(ncount)
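
The selection loop above can also be expressed as a vectorized NumPy mask. The sketch below demonstrates the idea on a small synthetic array; the values and the channel index are made up, and whether the same expression behaves identically on the arrays held by `obsCris.bt` (for example if they are masked arrays) is an assumption that should be checked.

```python
import numpy as np

# Synthetic (nlocs, nchannels) departures; extreme magnitudes stand in for unusable values
ombg = np.array([[0.5, 1.0e10],
                 [-1.2, 0.3],
                 [0.1, -0.7],
                 [2.0, -999.0]])
latitude = np.array([30.0, 35.0, 40.0, 45.0])
ch = 1  # channel index (column) to inspect

# True where the departure lies strictly inside (-200, 200)
keep = (ombg[:, ch] > -200) & (ombg[:, ch] < 200)
idx = np.flatnonzero(keep)  # integer indices of the kept locations

print(idx)
print(latitude[idx])
```

The mask avoids the Python-level loop, which matters when a file holds many locations.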

### prepare color map

In [None]:
datmi = np.nanmin(obarray)  # Min of the data
datma = np.nanmax(obarray)  # Max of the data

# RdBu (diverging) when the data contain negative values, otherwise jet
if datmi < 0:
    cmap = 'RdBu'
else:
    cmap = 'jet'

# Fixed color range for brightness temperature (K); use datmi/datma instead to span the data
cmin = 200.
cmax = 310.
conus_12km = [-150, -50, 15, 55]

color_map = plt.get_cmap(cmap)  # plt.cm.get_cmap is deprecated since Matplotlib 3.7
reversed_color_map = color_map.reversed()
units = 'K'
#units = '%'

### make the plot

In [None]:
import cartopy.crs as ccrs
import matplotlib.ticker as mticker
import matplotlib.colors as colors

# Create the figure in the same cell as the axes so the map is drawn on it
fig = plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=0))

# Plot grid lines
# ----------------
gl = ax.gridlines(crs=ccrs.PlateCarree(central_longitude=0), draw_labels=True,
                  linewidth=1, color='gray', alpha=0.5, linestyle='-')
gl.top_labels = False
gl.xlabel_style = {'size': 10, 'color': 'black'}
gl.ylabel_style = {'size': 10, 'color': 'black'}
gl.xlocator = mticker.FixedLocator(
   [-180, -135, -90, -45, 0, 45, 90, 135, 179.9])
ax.set_ylabel("Latitude",  fontsize=7)
ax.set_xlabel("Longitude", fontsize=7)

# Plot scatter data
# ------------------
sc = ax.scatter(lon, lat,
                c=obarray, s=4, linewidth=0,
                transform=ccrs.PlateCarree(), cmap=cmap, vmin=cmin, vmax = cmax, norm=None, antialiased=True)

# Plot colorbar
# --------------
cbar = plt.colorbar(sc, ax=ax, orientation="horizontal", pad=.1, fraction=0.06,ticks=[200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310])
#cbar = plt.colorbar(sc, ax=ax, orientation="horizontal", pad=.1, fraction=0.06,ticks=[-3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1.0, 1.5, 2.0, 2.5, 3 ])
#cbar = plt.colorbar(sc, ax=ax, orientation="horizontal", pad=.1, fraction=0.06,ticks=[0, 10, 20, 20, 40, 50, 60, 70, 80, 90, 100])
cbar.ax.set_ylabel(units, fontsize=10)
# Set map extent
# --------------
#ax.set_global()
#ax.set_extent(conus)
ax.set_extent(conus_12km)

# Draw coastlines
# ----------------
ax.coastlines()
ax.text(0.45, -0.1, 'Longitude', transform=ax.transAxes, ha='left')
ax.text(-0.08, 0.4, 'Latitude', transform=ax.transAxes,
        rotation='vertical', va='bottom')

#text = f"Total Count:{datcont:0.0f}, Max/Min/Mean/Std: {datma:0.3f}/{datmi:0.3f}/{omean:0.3f}/{stdev:0.3f} {units}"
#print(text)
#ax.text(0.67, -0.1, text, transform=ax.transAxes, va='bottom', fontsize=6.2)

dpi=100  # resolution used if the figure is saved below
plt.tight_layout()

# show plot (uncomment the two lines below to also save it)
# -----------
# pname='test.png'
# plt.savefig(pname, dpi=dpi)
plt.show()