In [2]:
import xarray as xr
import os
from datetime import timedelta
import matplotlib.pyplot as plt

In [3]:
# Root directory holding the model and reanalysis data. Set the MSC25_DATA_DIR
# environment variable to point at a different location on another machine.
path = os.environ.get('MSC25_DATA_DIR', '/Users/anka/Desktop/research/msc25/data')

In [4]:
# Load the CESM standard-resolution model run (CLOUD field; 1920-01..2005-12 per the file name)

file = 'cesm_StdModel/b.e11.B20TRLENS_RCP85.f09_g16.xaer.002.cam.h0.CLOUD.192001-200512.nc'

lens = xr.open_dataset(os.path.join(path, file))

# The stored time stamps do not start in January. Presumably each monthly mean is
# stamped at the end of its averaging period (e.g. 1920-02-01 for January 1920) --
# TODO confirm against the file's time bounds. Shifting every stamp back 16 days
# moves it into the month it actually represents.
lens['time'] = lens.indexes['time'] - timedelta(days=16)

In [None]:
# Low clouds from CESM: keep model levels with lev >= 800 (presumably hPa, i.e. the
# lowest part of the column -- confirm via lens.lev.attrs), collapse them into one
# field per time step, add month/year coordinates, and keep 1940 onwards.
#
# Collapse with max, not sum: CLOUD is presumably a per-level cloud fraction in
# [0, 1] (confirm via lens.CLOUD.attrs). A sum over several levels can exceed 1 and
# is then not comparable with ERA5's low cloud cover fraction. The maximum over the
# low levels (maximum-overlap assumption) stays within [0, 1].
# NOTE(review): CESM's own low-cloud diagnostic (CLDLOW, if present in the output)
# applies the model's overlap scheme and may match ERA5 more closely.
lensCloud = lens.CLOUD.where(lens.CLOUD.lev >= 800, drop=True)
lensCloud = lensCloud.max(dim='lev')
lensCloud = lensCloud.assign_coords(month=("time", lensCloud['time'].dt.month.data))
lensCloud = lensCloud.assign_coords(year=("time", lensCloud['time'].dt.year.data))
lensCloud = lensCloud.where(lensCloud.year >= 1940, drop=True)
lensCloud

In [None]:
# Load the ERA5 observations (presumably the monthly-mean stream, given "moda" in the file name)
file = 'ERA5/data_stream-moda_stepType-avgua.nc'
era5 = xr.open_dataset(filename_or_obj=os.path.join(path, file))
era5

In [None]:
# Prepare ERA5 low cloud cover (lcc) for comparison with the CESM field:
#   * keep 1940-2005, the same window as the CESM selection,
#   * add month/year coordinates along the time axis,
#   * rename valid_time/latitude/longitude to time/lat/lon,
#   * reverse latitude so it is ascending (south -> north).
era5Cloud = era5.lcc.where(
    (era5.lcc.valid_time.dt.year >= 1940) & (era5.lcc.valid_time.dt.year <= 2005),
    drop=True,
)
era5Cloud = era5Cloud.assign_coords(month=("valid_time", era5Cloud.valid_time.dt.month.data))
era5Cloud = era5Cloud.assign_coords(year=("valid_time", era5Cloud.valid_time.dt.year.data))
era5Cloud = era5Cloud.rename({'valid_time': 'time', 'latitude': 'lat', 'longitude': 'lon'})
# ERA5 files typically store latitude north -> south. sortby makes it ascending
# (a no-op if it already is), so index-based plots such as imshow(origin='lower')
# come out upright.
era5Cloud = era5Cloud.sortby('lat')

era5Cloud

In [8]:
def select_season(data, months, season_name):
    """Return the time steps of ``data`` that fall in the given calendar months.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Must have a datetime ``time`` dimension and a ``month`` coordinate
        along it (assigned by the preprocessing cells).
    months : sequence of int
        Calendar months (1-12) forming the season, e.g. ``[12, 1, 2]``.
    season_name : str
        Season label. For ``'DJF'`` each December is grouped with the
        following January and February. Any DJF season that the record does
        not cover completely is dropped, e.g. the trailing December of the
        last year, or January/February of the first year when the record
        starts in January.

    Returns
    -------
    Same type as ``data``, restricted to the selected time steps.
    """
    # Keep only the requested calendar months
    season_data = data.where(data.month.isin(months), drop=True)

    if season_name == 'DJF':
        # December counts towards the following year's DJF season
        time = season_data['time']
        season_year = time.dt.year + (time.dt.month == 12)

        # Record which of the season's months are present for each season year
        present = {}
        for year, month in zip(season_year.values.tolist(), time.dt.month.values.tolist()):
            present.setdefault(year, set()).add(month)
        wanted = set(months)
        complete_years = [year for year, found in present.items() if found >= wanted]

        # Remove only the incomplete edge seasons. Every other December must be
        # kept, otherwise the "DJF" mean degenerates into a Jan-Feb mean.
        season_data = season_data.sel(time=season_year.isin(complete_years))

    return season_data

In [None]:
# Quick look: time-mean CESM low-cloud field
lensCloud.mean('time').plot()
plt.title('CESM LENS: time-mean low cloud (lev >= 800)')
plt.show()

In [None]:
# Quick look: time-mean ERA5 low cloud cover
era5Cloud.mean('time').plot()
plt.title('ERA5: time-mean low cloud cover (lcc)')
plt.show()

In [None]:
# ERA5 seasonal means: average each season's selected months over time
djf, mam, jja, son = (
    select_season(era5Cloud, season_months, season_label).mean("time")
    for season_label, season_months in (
        ("DJF", [12, 1, 2]),
        ("MAM", [3, 4, 5]),
        ("JJA", [6, 7, 8]),
        ("SON", [9, 10, 11]),
    )
)
djf

In [None]:
# Quick look at the DJF mean. DataArray has no .imshow method; the xarray
# plotting accessor does, and it takes the extent and orientation from the
# lat/lon coordinates.
djf.plot.imshow()
plt.title('ERA5: DJF mean low cloud cover')
plt.show()

In [None]:
# ERA5 seasonal means on a 2x2 grid with one shared colour scale
seasons = {"DJF": djf, "MAM": mam, "JJA": jja, "SON": son}

fig, axs = plt.subplots(2, 2, figsize=(12, 8), constrained_layout=True)
fig.suptitle("ERA5")

# Shared colour range across all panels. DataArray.min()/max() skip NaNs by
# default, whereas ndarray.min()/max() would return NaN if any cell were missing.
vmin = min(float(season_mean.min()) for season_mean in seasons.values())
vmax = max(float(season_mean.max()) for season_mean in seasons.values())

for ax, (season_name, season_mean) in zip(axs.flat, seasons.items()):
    # xarray's imshow derives extent and orientation from the lat/lon
    # coordinates, so the map is upright whether latitude is stored ascending
    # or descending. Plotting raw .values with origin='lower' would flip a
    # north-to-south grid.
    im = season_mean.plot.imshow(ax=ax, vmin=vmin, vmax=vmax, add_colorbar=False)
    ax.set_title(season_name)

fig.colorbar(im, ax=axs, orientation='vertical', fraction=0.03, pad=0.02, label='Low cloud cover (mean)')

plt.show()