In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections.abc import Callable

In [None]:
import glob
from datetime import datetime,date
import os
import sys
import gc
import io
import json
import re
import einops
from copy import deepcopy
from pprint import pprint
from collections import OrderedDict

In [None]:
import xarray as xr

In [None]:
import omegaconf
from omegaconf import OmegaConf
import comet_ml

In [None]:

import pytorch_lightning as pl






In [None]:
import seaborn as sns
from matplotlib.colors import LogNorm, Normalize
sns.set_theme(style="ticks")

In [None]:
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from matplotlib import colors, ticker
from matplotlib.colors import LogNorm

In [None]:
import cartopy.crs as ccrs

In [None]:
from bokeh import palettes

In [None]:
import holoviews as hv

In [None]:
import hvplot.pandas  # noqa

In [None]:
hv.extension('bokeh')

In [None]:
import sys
sys.path.append("../")

In [None]:
from data.data_module import VerticalCloudDataModule, LogTransform,LogTransform2D, get_variable_stat

from data.data_utils import load_patches, sort_overpass_indices, get_overpass_direction

from helpers.comet_helpers import get_patch_ids, load_experiment, get_hparams, get_trained_model, get_trained_experiment, get_dm, get_data_module, get_asset_id, dm_hparams_config, model_hparams_config, dm_overwrite_hparams, model_overwrite_hparams
from helpers.callbacks import *

from model.losses import *

from model.LightningModel import VerticalCloudCubeModel

In [None]:
from data.data_utils import get_horizontal_cloud_coverage, horizontal_cover_by_level, UnNormalize, get_height_level_range
from helpers.misc_helpers import  nested_set
from helpers.comet_helpers import log_image_to_comet

In [None]:
from evaluate.eval_plots import *

In [None]:
from model.DiscriminatorModel import IceCloudNetDisc

In [None]:
from inference.inference import load_discriminator_model, load_model

***

### Configs

In [None]:
# Colours used by run_evaluation: predictions (y_hat) vs. DARDAR ground truth.
y_hat_color, dardar_color = "#30A2DA", "#FD654B"

# Static matplotlib backend by default (run_evaluation switches HoloViews to bokeh itself).
hv.extension('matplotlib')

### User defined functions

In [None]:
# Numeric literal: optional sign, integer/decimal part, optional exponent (e.g. 1, 0.01, 1e-07).
_FLOAT_PATTERN = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"


def _parse_log_transform_args(input_string):
    """Extract the numeric ``constant`` / ``scaler`` arguments from a logged LogTransform string.

    Arguments that do not appear in ``input_string`` are left out of the result.
    """
    args_dict = dict()
    for arg in ("constant", "scaler"):
        match = re.search(fr"{arg}=({_FLOAT_PATTERN})", input_string)
        if match:
            args_dict[arg] = np.float64(match.group(1))
    return args_dict


def get_hparams(experiment, hparams_config, overwrite_params, print_hparams=False):
    """Rebuild typed hyperparameters from the string values logged to a Comet experiment.

    Redefines the ``get_hparams`` imported from ``helpers.comet_helpers`` for this notebook.

    Args:
        experiment: Comet experiment object providing ``get_parameters_summary()``, a list of
            dicts with ``"name"`` and ``"valueCurrent"`` entries.
        hparams_config: maps parameter name -> expected type; logged parameters that are not
            listed here are ignored.
        overwrite_params: parameters set verbatim afterwards, overriding parsed values.
        print_hparams: pretty-print the resulting dict.

    Returns:
        dict of parameter name -> parsed value. Values whose type cannot be parsed are set to
        the placeholder string ``"SetManually"``.
    """
    hparams = dict()
    for param in experiment.get_parameters_summary():
        p_name = param["name"]
        if p_name not in hparams_config:
            continue
        raw_value = param["valueCurrent"]
        p_type = hparams_config[p_name]
        if raw_value == "iwc":
            param_value = "iwc"
        elif raw_value in ["None", "null"]:
            param_value = None
        elif p_type in (list, tuple, dict):
            param_value = json.loads(raw_value)
        elif p_type == bool:
            # only "false" (any capitalisation) is False; every other value is True
            param_value = str(raw_value).lower() != "false"
        elif p_type in (date, datetime.date):
            # `datetime` is the class here (from datetime import datetime), so `datetime.date`
            # is its method, not the type; `date` is the real type. Both kept for compatibility.
            param_value = datetime.strptime(raw_value, "%Y-%m-%d").date()
        elif p_type in [int, float, str, np.float64]:
            param_value = p_type(raw_value)
        elif p_type == LogTransform:
            param_value = LogTransform(**_parse_log_transform_args(raw_value))
        else:
            print(f"No value for {p_name} available/possible to parse → set manually")
            param_value = "SetManually"
        hparams[p_name] = param_value

    # set params to overwrite
    print(overwrite_params)
    for key, val in overwrite_params.items():
        print(f"Set {key} manually with value {val}")
        hparams[key] = val

    if print_hparams:
        print("-------")
        print("hparams")
        print("-------")
        pprint(hparams, depth=1)
    return hparams

def get_trained_model(experiment, model_hparams,
                      checkpoint_dir="/cluster/work/climate/kjeggle/model_checkpoints"):
    """Build a VerticalCloudCubeModel and load its trained weights.

    Redefines the ``get_trained_model`` imported from ``helpers.comet_helpers``.
    Weights come from the first ``*.ckpt`` file (sorted by name) under
    ``<checkpoint_dir>/<experiment key>``. If there is none, they come from the first
    ``ice_net`` model asset attached to the Comet experiment.

    Args:
        experiment: Comet experiment object (``get_metadata``, ``get_model_asset_list``,
            ``get_asset`` are used).
        model_hparams: keyword arguments for ``VerticalCloudCubeModel``.
        checkpoint_dir: root directory of locally stored checkpoints.

    Returns:
        The model with the pretrained ``state_dict`` loaded.
    """
    vertical_cloud_model = VerticalCloudCubeModel(**model_hparams)
    exp_key = experiment.get_metadata()["experimentKey"]
    # sorted() makes the choice deterministic when several checkpoints exist
    local_checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, exp_key, "*.ckpt")))
    if local_checkpoints:
        # map_location="cpu": GPU-saved checkpoints also load on CPU-only machines;
        # load_state_dict copies the tensors into the (CPU) model anyway
        pretrained_state_dict = torch.load(local_checkpoints[0], map_location="cpu")["state_dict"]
        print("loaded model from disk")
    else:
        asset_id = experiment.get_model_asset_list("ice_net")[0]["assetId"]
        model_binary = experiment.get_asset(asset_id, return_type="binary")
        pretrained_state_dict = torch.load(io.BytesIO(model_binary), map_location="cpu")
        print("loaded model from comet")
    vertical_cloud_model.load_state_dict(pretrained_state_dict)

    return vertical_cloud_model

def get_data_module(experiment, dm_hparams):
    """Instantiate a VerticalCloudDataModule from already-parsed hyperparameters.

    ``experiment`` is accepted only for signature compatibility and is not used.
    """
    data_module = VerticalCloudDataModule(**dm_hparams)
    return data_module

def get_dm(experiment_id, dm_overwrite_params, dm_hparams_config=dm_hparams_config):
    """Fetch a Comet experiment and rebuild its data module, already set up for fitting.

    Args:
        experiment_id: Comet experiment name passed to ``load_experiment``.
        dm_overwrite_params: data-module hyperparameters that override the logged ones.
        dm_hparams_config: expected type per data-module hyperparameter.

    Returns:
        A ``VerticalCloudDataModule`` on which ``setup(stage="fit")`` has been called.
    """
    experiment = load_experiment(experiment_name=experiment_id)
    print("loaded experiment from comet")

    data_module = VerticalCloudDataModule(
        **get_hparams(
            experiment,
            print_hparams=True,
            hparams_config=dm_hparams_config,
            overwrite_params=dm_overwrite_params,
        )
    )
    data_module.setup(stage="fit")
    return data_module

def get_trained_experiment(experiment_id,
                           model_overwrite_params,
                           model_hparams_config=model_hparams_config):
    """Fetch a Comet experiment and return it together with its trained model.

    Model construction and weight loading are delegated to ``get_trained_model`` (defined
    alongside this function) instead of duplicating the checkpoint lookup here.

    Args:
        experiment_id: Comet experiment name passed to ``load_experiment``.
        model_overwrite_params: model hyperparameters that override the logged ones.
        model_hparams_config: expected type per model hyperparameter.

    Returns:
        ``(experiment, model)`` with the pretrained weights loaded into ``model``.
    """
    experiment = load_experiment(experiment_name=experiment_id)
    print("loaded experiment from comet")

    model_hparams = get_hparams(experiment,
                                hparams_config=model_hparams_config,
                                overwrite_params=model_overwrite_params,
                                print_hparams=False)

    model = get_trained_model(experiment, model_hparams)
    return experiment, model

In [None]:
# Debug check from development: `exp` only exists after an eval_pipeline run, so look it
# up defensively to keep "Restart & Run All" from failing with a NameError.
type(globals().get("exp"))

In [None]:
# Debug check from development: fetch the Comet experiment that eval_pipeline uses as its
# default data-module source and show its type (needs Comet access / network).
type(load_experiment("famous_elk_2849"))

#### Create inference data (execute the hidden/collapsed cells)

In [None]:
def _n_level_aggregation(out_channels):
    """Return how many adjacent output levels are merged for the horizontal cloud cover."""
    if out_channels >= 256:
        return 16
    per_channel_count = {64: 4, 55: 5, 16: 1}
    if out_channels in per_channel_count:
        return per_channel_count[out_channels]
    raise ValueError(f"no n_level_aggregation defined for out_channels={out_channels}")


def create_inference_data(model, dataloader, store_3d_data=False, load_meta_data=True, split="val", stop_at_step=10):
    """Run ``model`` over ``dataloader`` and collect inputs, predictions and targets for evaluation.

    For every sample, predicted (``y_hat``) and DARDAR values are extracted along the
    overpass using the overpass mask. The horizontal cloud cover per aggregated height
    level is computed on the first target.

    Args:
        model: trained model. For an ``IceCloudNetDisc``, batches are prepared with
            ``model.get_input``; otherwise the batch tuple is used directly. The model must
            expose ``out_channels`` and ``meta_data_embedding``.
        dataloader: iterable of batches unpacking to
            ``(seviri, era5, dardar, overpass_mask, meta_data, patch_idx)``.
        store_3d_data: also keep the full predicted / DARDAR tensors per batch.
        load_meta_data: currently unused; kept for backward compatibility.
        split: split name passed to ``model.get_input`` for ``IceCloudNetDisc`` models.
        stop_at_step: if truthy, stop after processing the batch with this index (i.e.
            ``stop_at_step + 1`` batches); falsy processes the whole dataloader.

    Returns:
        dict with concatenated tensors (``overpass_mask_data``, ``patch_idx_data``,
        ``seviri_data``, ``era5_data``, ``meta_data``, ``dardar_cloud_cover_data``,
        ``y_hat_cloud_cover_data``) and lists: one profile per sample
        (``*_profile_data_list``) and one tensor per batch (``*_cube_data``, empty unless
        ``store_3d_data``).

    Raises:
        ValueError: if ``model.out_channels`` has no level aggregation defined, or if the
            dataloader yields no batches.
    """
    seviri_data = []
    era5_data = []
    overpass_mask_data = []
    patch_idx_data = []
    dardar_cube_data = []
    y_hat_cube_data = []
    y_hat_profile_data_list = []
    dardar_profile_data_list = []
    y_hat_cloud_cover_data = []
    dardar_cloud_cover_data = []
    meta_data = []

    out_channels = model.out_channels
    # manually chosen level thickness for the horizontal cloud cover
    n_level_aggregation = _n_level_aggregation(out_channels)

    is_disc_model = isinstance(model, IceCloudNetDisc)
    n_targets = model.unet.prediction_heads if is_disc_model else model.prediction_heads

    # Send inputs to the model's own device (not to CUDA whenever it exists), so a
    # CPU-resident model also works on a GPU machine.
    device = next(model.parameters()).device

    for step, data in enumerate(dataloader):
        if step % 10 == 0:
            print(step)

        if is_disc_model:
            seviri, era5, dardar, overpass_mask, md, patch_idx = model.get_input(data, split=split)
            seviri = seviri.float()  # float32; .double() would give float64
            overpass_mask = overpass_mask.long()
            dardar = dardar.float()
        else:
            seviri, era5, dardar, overpass_mask, md, patch_idx = data

        seviri = seviri.to(device)
        md = md.to(device)

        with torch.no_grad():
            y_hat = model(seviri, md) if model.meta_data_embedding else model(seviri)

        y_hat = y_hat.cpu()
        dardar = dardar.cpu()

        # for each sample extract the values along the overpass
        for idx in range(y_hat.shape[0]):
            sample_mask = overpass_mask[idx].bool()
            # squeeze(1) drops only the target axis (for single-target models), so a profile
            # stays (overpass, z) even when the overpass has length 1
            y_hat_profile = einops.rearrange(
                torch.masked_select(y_hat[idx], sample_mask),
                '(c z overpass) -> overpass c z', c=n_targets, z=out_channels).squeeze(1)
            dardar_profile = einops.rearrange(
                torch.masked_select(dardar[idx], sample_mask),
                '(c z overpass) -> overpass c z', c=n_targets, z=out_channels).squeeze(1)

            y_hat_profile_data_list.append(y_hat_profile)
            dardar_profile_data_list.append(dardar_profile)

            # horizontal cloud cover per level, computed on the first target only
            if n_targets == 1:
                dardar_cover_input, y_hat_cover_input = dardar_profile, y_hat_profile
            else:
                dardar_cover_input, y_hat_cover_input = dardar_profile[:, 0], y_hat_profile[:, 0]
            dardar_cloud_cover = horizontal_cover_by_level(dardar_cover_input, n_level_aggregation=n_level_aggregation)
            y_hat_cloud_cover = horizontal_cover_by_level(y_hat_cover_input, n_level_aggregation=n_level_aggregation)
            dardar_cloud_cover_data.append(dardar_cloud_cover)
            y_hat_cloud_cover_data.append(y_hat_cloud_cover)

        era5_data.append(era5.cpu())
        overpass_mask_data.append(overpass_mask.cpu())
        patch_idx_data.append(patch_idx.cpu())
        seviri_data.append(seviri.cpu())
        meta_data.append(md.cpu())

        if store_3d_data:
            # both tensors were already moved to the CPU above
            dardar_cube_data.append(dardar)
            y_hat_cube_data.append(y_hat)

        # TODO: per-patch meta data loading (load_meta_data) is disabled; the batch's `md`
        # tensor is collected instead.

        if stop_at_step and step == stop_at_step:
            break

    if not seviri_data:
        raise ValueError("dataloader yielded no batches")

    # concat input data
    overpass_mask_data = torch.concat(overpass_mask_data)
    patch_idx_data = torch.concat(patch_idx_data)
    seviri_data = torch.concat(seviri_data)
    era5_data = torch.concat(era5_data)
    meta_data = torch.concat(meta_data)

    # concat cloud cover
    dardar_cloud_cover_data = torch.concat(dardar_cloud_cover_data)
    y_hat_cloud_cover_data = torch.concat(y_hat_cloud_cover_data)

    # profile and cube data stay lists; everything else is a concatenated tensor
    eval_data = dict(overpass_mask_data=overpass_mask_data,
                     patch_idx_data=patch_idx_data,
                     seviri_data=seviri_data,
                     era5_data=era5_data,
                     dardar_profile_data_list=dardar_profile_data_list,
                     y_hat_profile_data_list=y_hat_profile_data_list,
                     dardar_cloud_cover_data=dardar_cloud_cover_data,
                     y_hat_cloud_cover_data=y_hat_cloud_cover_data,
                     y_hat_cube_data=y_hat_cube_data,
                     dardar_cube_data=dardar_cube_data,
                     meta_data=meta_data)

    return eval_data

def get_profile_data(eval_data, meta_data_filter=None, target_transform=LogTransform(scaler=1e7), selected_height_levels=None):
    """Concatenate the per-sample overpass profiles and convert them back to the original scale.

    Args:
        eval_data: dict produced by ``create_inference_data``.
        meta_data_filter (Tuple): optional ``(meta data column index, value)``. When given,
            only samples whose ``eval_data["meta_data"]`` entry in that column equals
            ``value`` are kept.
        target_transform: its ``inverse_transform`` maps the data back to the original scale,
            which is what the evaluation mainly works with.
        selected_height_levels: optional index / boolean mask for the height axis (axis 1 of
            2-D data, axis 2 otherwise).

    Returns:
        ``(y_hat, dardar)`` in the original scale.
    """
    y_hat_profiles = eval_data["y_hat_profile_data_list"]
    dardar_profiles = eval_data["dardar_profile_data_list"]

    if meta_data_filter:
        column, wanted = meta_data_filter[0], meta_data_filter[1]
        keep = np.flatnonzero(np.array(eval_data["meta_data"][:, column]) == wanted)
        y_hat_profiles = [y_hat_profiles[i] for i in keep]
        dardar_profiles = [dardar_profiles[i] for i in keep]

    dardar_stacked = torch.concat(dardar_profiles)
    y_hat_stacked = torch.concat(y_hat_profiles)
    dardar = target_transform.inverse_transform(dardar_stacked)
    y_hat = target_transform.inverse_transform(y_hat_stacked)

    if selected_height_levels is None:
        return y_hat, dardar

    # the height axis follows either the sample axis (2-D) or the sample and target axes
    leading_axes = (slice(None),) if len(y_hat.shape) == 2 else (slice(None), slice(None))
    level_index = leading_axes + (selected_height_levels,)
    dardar = dardar[level_index]
    y_hat = y_hat[level_index]
    return y_hat, dardar

def get_cube_data(eval_data, target_transform=LogTransform(scaler=1e7)):
    """Concatenate the stored prediction cubes and convert them back to the original scale.

    Needs ``create_inference_data(..., store_3d_data=True)``; otherwise
    ``eval_data["y_hat_cube_data"]`` is empty.

    TODO: add meta-data filtering and return the DARDAR cubes as well.
    """
    stacked_cubes = torch.concat(eval_data["y_hat_cube_data"])
    return target_transform.inverse_transform(stacked_cubes)

def run_evaluation(y_hat,
                   dardar,
                   eval_data,
                   exp,
                   target_variable="iwc",
                   target_transform=LogTransform(scaler=1e7),
                   height_levels: np.ndarray = get_height_level_range(1680, 16980, step=60),
                   cloud_thres=0,
                   log_image_kwargs=None,
                   suffix=""):
    """Produce the evaluation plots and print the metrics for one set of profiles.

    Every plot is routed through ``log_image``: inline display, Overleaf PNG and/or Comet,
    depending on ``log_image_kwargs``.

    Args:
        y_hat, dardar: predicted and DARDAR profiles in the original scale. Their height axis
            must match ``height_levels``.
        eval_data: dict from ``create_inference_data`` (used for the cloud cover / zonal mean plots).
        exp: Comet experiment object; only needed when logging to Comet.
        target_variable: ``"iwc"`` or ``"nice"``; ``"icnc_5um"`` is treated as ``"nice"``.
        target_transform: forwarded to the metric and zonal mean helpers.
        height_levels: height of each profile level.
        cloud_thres: threshold forwarded to the metric helpers.
        log_image_kwargs: keyword arguments for ``log_image`` (default: none).
        suffix: appended as ``_<suffix>`` to image and metric names. Horizontal cloud cover
            plots are only made for the empty suffix. The zonal mean is only made for the
            empty suffix and ``"nice"``.
    """
    if log_image_kwargs is None:
        log_image_kwargs = {}

    if suffix != "":
        suffix = f"_{suffix}"

    if target_variable == "icnc_5um":
        target_variable = "nice"

    # cloud occurrence per height level (switches the global HoloViews backend to bokeh)
    hv.extension("bokeh")
    df, occurance_p = cloud_occurance_per_height_level(y_hat, dardar, height_levels=height_levels)
    log_image(occurance_p, f"cloud_occurance{suffix}", **log_image_kwargs, exp_obj=exp)

    # metrics per level, aggregated into ~16 groups but never fewer than one level per group
    # TODO: better heuristic for n_level_aggregation
    n_level_aggregation = max(1, height_levels.shape[0] // 16)
    df, metric_p = metrics_per_level(y_hat, dardar, cloud_thres=cloud_thres, height_levels=height_levels,
                                     target_transform=target_transform, n_level_aggregation=n_level_aggregation)
    log_image(metric_p, f"performance_metrics_levels{suffix}", **log_image_kwargs, exp_obj=exp)

    # iwc per height level
    y_hat_height_iwc_df, y_hat_height_iwc_plt = iwc_per_height_df(y_hat, color=y_hat_color, plt_q10=False,
                                                                  height_levels=height_levels, target_variable=target_variable)
    dardar_height_iwc_df, dardar_height_iwc_plt = iwc_per_height_df(dardar, color=dardar_color, plt_q10=False,
                                                                    height_levels=height_levels, target_variable=target_variable)
    p = (y_hat_height_iwc_plt * dardar_height_iwc_plt).opts(fontscale=1.5)
    log_image(p, f"iwc_height{suffix}", **log_image_kwargs, exp_obj=exp)

    # iwc vs iwc
    df, g = iwc_vs_iwc_plt(y_hat, dardar, target_variable=target_variable)
    log_image(g, f"iwc_vs_iwc{suffix}", **log_image_kwargs, exp_obj=exp)

    # TODO: per-height plots only showed the last plot (same problem for zonal mean count)

    # metrics
    metric_dict = get_metrics(y_hat, dardar, cloud_thres=cloud_thres, target_transform=target_transform)
    if suffix != "":
        metric_dict = {f"{key}{suffix}": value for key, value in metric_dict.items()}
    pprint(metric_dict)

    if suffix == "":
        # horizontal cloud cover (TODO: implement the day/night filter here)
        all_data_plt, summary_line_plt, hor_cloud_cover_metrics = horizontal_cloud_cover(
            eval_data["dardar_cloud_cover_data"], eval_data["y_hat_cloud_cover_data"], height_levels=height_levels)
        log_image(all_data_plt, f"horizontal_cloud_cover{suffix}", **log_image_kwargs, exp_obj=exp)
        log_image(summary_line_plt, f"horizonal_cloud_cover_level_stats{suffix}", **log_image_kwargs, exp_obj=exp)

    if suffix in ["", "_nice"]:
        print("calc zonal/iwp means")
        # separate, coarser grid for the zonal mean; kept apart from the `height_levels` argument
        zonal_height_levels = get_height_level_range(1680, 16980, step=240)[:-1]
        fig_zonal = zonal_mean(eval_data["y_hat_profile_data_list"], eval_data["dardar_profile_data_list"],
                               eval_data["meta_data"][:, 0], height_levels=zonal_height_levels, min_count=250,
                               target_variable=target_variable, target_transform=target_transform)
        log_image(fig_zonal, f"zonal_mean{suffix}", **log_image_kwargs, exp_obj=exp)
        # TODO: iwp_regional_mean (regional IWP means) is currently disabled
    

def eval_pipeline(exp_config,
                  eval_kwargs,
                  meta_data_filter=None,
                  dm_experiment_id="famous_elk_2849",
                  dm_overwrite_hparams=dm_overwrite_hparams,
                  log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
                  run_eval=True):
    """Evaluate one trained experiment end to end on the validation dataloader.

    Loads the data module from ``dm_experiment_id`` and the model described by ``exp_config``,
    runs inference, and optionally produces evaluation plots for all samples and for the
    day / night subsets.

    Args:
        exp_config: dict with ``experiment_id``, ``type`` (``"unet_only"`` or ``"unet_disc"``)
            and, for ``"unet_only"``, ``model_overwrite_hparams``.
        eval_kwargs: keyword arguments forwarded to ``create_inference_data``.
        meta_data_filter: optional ``(meta data column, value)`` filter for the main profiles.
        dm_experiment_id: Comet experiment whose hyperparameters define the data module.
        dm_overwrite_hparams: data-module hyperparameters to override.
        log_image_kwargs: forwarded to ``log_image`` for every plot.
        run_eval: if False, skip the plotting step.

    Returns:
        ``(exp, dm, model, eval_data, y_hat, dardar, y_hat_cube)``. ``y_hat_cube`` is None
        unless ``eval_kwargs["store_3d_data"]`` is truthy.

    Raises:
        ValueError: for an unknown ``exp_config["type"]``.
    """
    import traceback

    experiment_id = exp_config["experiment_id"]
    print(f"start eval pipeline for {experiment_id}")

    dm = get_dm(dm_experiment_id, dm_overwrite_hparams)
    val_dataloader = dm.val_dataloader()
    print("loaded dm")

    model_type = exp_config["type"]
    if model_type == "unet_only":
        exp, model = get_trained_experiment(experiment_id, exp_config["model_overwrite_hparams"])
    elif model_type == "unet_disc":
        exp, model, _conf = load_discriminator_model(experiment_id)
    else:
        raise ValueError(f"unknown experiment type {model_type!r}; expected 'unet_only' or 'unet_disc'")

    # fall back to the CPU so the pipeline also runs on machines without a GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    eval_data = create_inference_data(model, val_dataloader, **eval_kwargs)
    # TODO: implement a loop over meta data filters
    transform = dm.target_transform
    y_hat, dardar = get_profile_data(eval_data, meta_data_filter=meta_data_filter, target_transform=transform)
    # meta data column 3: value 1 -> "night" subset, value 0 -> "day" subset (per the variable names)
    y_hat_night, dardar_night = get_profile_data(eval_data, meta_data_filter=(3, 1), target_transform=transform)
    y_hat_day, dardar_day = get_profile_data(eval_data, meta_data_filter=(3, 0), target_transform=transform)

    # create_inference_data defaults store_3d_data to False, so mirror that here
    if eval_kwargs.get("store_3d_data", False):
        y_hat_cube = get_cube_data(eval_data, target_transform=transform)
    else:
        y_hat_cube = None
    print("generated eval data")

    # height levels for the evaluation
    fold = dm.fold_to_level_thickness
    level_thickness = 60 * fold if fold else 60
    height_levels = get_height_level_range(dm.height_levels[1], dm.height_levels[0], level_thickness)
    # Height levels must match the predictions' height axis, which is the LAST axis (profiles are
    # (n, z) or (n, targets, z)). Cut off the highest heights if necessary.
    height_levels = height_levels[-y_hat.shape[-1]:]

    target_variable = "iwc" if dm.target_variable == "iwc" else "nice"

    # run evaluation plot pipeline
    if run_eval:
        common_kwargs = dict(eval_data=eval_data, exp=exp, target_variable=target_variable,
                             target_transform=transform, height_levels=height_levels,
                             log_image_kwargs=log_image_kwargs)
        try:
            run_evaluation(y_hat=y_hat, dardar=dardar, **common_kwargs)
            print("saved plots to comet")
            run_evaluation(y_hat=y_hat_day, dardar=dardar_day, suffix="day", **common_kwargs)
            print("saved day plots to comet")
            run_evaluation(y_hat=y_hat_night, dardar=dardar_night, suffix="night", **common_kwargs)
            print("saved night plots to comet")
        except Exception:
            # Plotting is best effort: report the full traceback but still return the data.
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            traceback.print_exc()

    return exp, dm, model, eval_data, y_hat, dardar, y_hat_cube

def log_image(p,
              image_name,
              display_inline=True,
              log_to_overleaf=False,
              overleaf_dir="",
              log_to_comet=False,
              exp_obj=None,
              comet_log_kwargs=None):
    """Send a figure to any combination of Comet, a PNG file and the notebook output.

    Args:
        p: a holoviews object, a matplotlib ``Figure``, or a seaborn grid (saved via
            ``savefig``, closed via its ``.fig``).
        image_name: name used on Comet and as the PNG file stem.
        display_inline: show the figure in the notebook output.
        log_to_overleaf: save ``<overleaf_dir>/<image_name>.png`` at 600 dpi.
        overleaf_dir: output directory for the PNG (default: current directory).
        log_to_comet: upload through ``log_image_to_comet``; requires ``exp_obj``.
        exp_obj: Comet experiment object.
        comet_log_kwargs: extra kwargs for ``log_image_to_comet``; defaults to a fresh
            ``{"overwrite": True}`` per call.
    """
    if comet_log_kwargs is None:
        comet_log_kwargs = dict(overwrite=True)
    is_holoviews = "holoviews" in str(type(p))

    if log_to_comet:
        assert exp_obj is not None, "provide comet experiment object"
        log_image_to_comet(exp_obj, p, image_name, log_kwargs=comet_log_kwargs)

    # Display BEFORE saving: the matplotlib save branch closes the figure, after which
    # plt.show() would no longer render it.
    if display_inline:
        if is_holoviews:
            display(p)
        else:
            plt.show()

    if log_to_overleaf:
        fpath = os.path.join(overleaf_dir, f"{image_name}.png")
        if is_holoviews:
            hv.save(p, fpath, dpi=600)
        else:
            p.savefig(fpath, dpi=600)
            # a matplotlib Figure is closed directly; a seaborn grid wraps its figure in `.fig`
            plt.close(p if isinstance(p, plt.Figure) else p.fig)


In [None]:
# Shared plot settings per target variable: linear value range, its log10 counterpart, axis title.
plt_kwargs_commons = dict(
    var_range=dict(iwc=(1e-07, 0.001), nice=(1e2, 1e6)),
    var_range_log=dict(iwc=(-7, -3), nice=(2, 6)),
    axis_title=dict(
        iwc=r'IWC [kg m$^{-3}$]',
        nice=r'Nice [m$^{-3}$]',
    ),
)

In [None]:
# Global random seed, applied through Lightning's seed_everything for reproducible runs.
SEED = 13
pl.seed_everything(SEED)

### Get patch_ids

In [None]:
# Directory holding the per-split patch-id lists.
data_dir = "../helper_files"
# Split whose ids are handed to `load_model` as *validation* ids.
# NOTE(review): the original cell loaded test_pids.json into `val_patch_ids`, silently
# discarding the validation ids. This keeps that behaviour but makes the choice explicit;
# set EVAL_SPLIT = "val" to evaluate on the validation split instead.
EVAL_SPLIT = "test"

patch_ids = {}
for split_name in ("train", "val", "test"):
    with open(os.path.join(data_dir, f"{split_name}_pids.json"), 'r') as file:
        patch_ids[split_name] = json.load(file)

train_patch_ids = patch_ids["train"]
test_patch_ids = patch_ids["test"]
val_patch_ids = patch_ids[EVAL_SPLIT]

### Load model & data module

In [None]:
# Directory containing the TrainingData. Override it with the ICE_CLOUD_NET_DATA_DIR
# environment variable instead of editing the hard-coded cluster path.
TRAINING_DATA_DIR = os.environ.get("ICE_CLOUD_NET_DATA_DIR",
                                   "/net/n2o/wolke_scratch2/kjeggle/VerticalCloud/Nice128")
model, dm = load_model(data_dir=TRAINING_DATA_DIR,
                       train_patch_ids=[],
                       val_patch_ids=val_patch_ids,
                       model_conf_filepath="../model_configs/ice_cloud_net_conf.yaml",
                       model_checkpoint_dir="../model_checkpoints")

### Run evaluation pipeline

#### Get eval data 

In [None]:
# Keyword arguments for create_inference_data. stop_at_step=250 limits inference to the first
# 251 batches (use None / 0 for the whole dataloader).
eval_kwargs = dict(store_3d_data=False, load_meta_data=True, split="val", stop_at_step=250)
# Optional (meta data column, value) filter for get_profile_data; None keeps every sample.
meta_data_filter = None

In [None]:
%%time
# Put the model in evaluation mode and collect inputs, targets and predictions from the
# data module's validation dataloader (settings in `eval_kwargs`).
model.eval()

eval_data = create_inference_data(model, dm.val_dataloader(), **eval_kwargs)
# TODO: implement a loop over meta data filters

In [None]:
# Overpass pixel count per sample: the overpass mask summed over axes 1 and 2.
overpass_length = eval_data["overpass_mask_data"].sum(dim=(1, 2))

In [None]:
# Keep only overpasses with more than 256 pixels (NOTE(review): rationale for 256 unconfirmed).
overpass_length = overpass_length[overpass_length.gt(256)]

In [None]:
# Height grid of the model output: base spacing 60, multiplied by the fold factor when set.
level_thickness = 60 * (dm.fold_to_level_thickness or 1)
height_levels = get_height_level_range(dm.height_levels[1], dm.height_levels[0], level_thickness)
# keep the last `out_channels` levels so there is one height per predicted level
height_levels = height_levels[-model.out_channels:]

In [None]:
# Boolean mask over `height_levels`: only levels above 3800 are evaluated.
selected_height_levels = 3800 < height_levels

In [None]:
# IWC-scaled profiles for all samples, restricted to the selected height levels.
y_hat_iwc, dardar_iwc = get_profile_data(
    eval_data,
    meta_data_filter=None,
    target_transform=LogTransform(scaler=1e7),
    selected_height_levels=selected_height_levels,
)

In [None]:
# Select the IWC target (index 0 of the target axis). The ndim guard keeps the cell
# idempotent: re-running it must not slice the height axis of the already-selected data.
if y_hat_iwc.ndim == 3:
    y_hat_iwc = y_hat_iwc[:, 0]
if dardar_iwc.ndim == 3:
    dardar_iwc = dardar_iwc[:, 0]

In [None]:
# Night (meta data column 3 == 1) and day (== 0) subsets of the IWC-scaled profiles.
y_hat_night_iwc, dardar_night_iwc = get_profile_data(
    eval_data, meta_data_filter=(3, 1), target_transform=LogTransform(scaler=1e7),
    selected_height_levels=selected_height_levels)
y_hat_day_iwc, dardar_day_iwc = get_profile_data(
    eval_data, meta_data_filter=(3, 0), target_transform=LogTransform(scaler=1e7),
    selected_height_levels=selected_height_levels)

In [None]:
# Select the IWC target (index 0) for the night / day subsets; the ndim guards keep the
# cell idempotent on re-runs.
if y_hat_night_iwc.ndim == 3:
    y_hat_night_iwc = y_hat_night_iwc[:, 0]
if dardar_night_iwc.ndim == 3:
    dardar_night_iwc = dardar_night_iwc[:, 0]

if y_hat_day_iwc.ndim == 3:
    y_hat_day_iwc = y_hat_day_iwc[:, 0]
if dardar_day_iwc.ndim == 3:
    dardar_day_iwc = dardar_day_iwc[:, 0]

In [None]:
# Full predicted cubes in the original scale; only available when store_3d_data=True.
# (Author's note: layout n, z, y, x — NOTE(review): not verified here.)
if eval_kwargs["store_3d_data"]:
    y_hat_cube_iwc = dm.target_transform.inverse_transform(
        torch.concat(eval_data["y_hat_cube_data"]))

In [None]:
# Log transform used for the Nice target.
nice_logtrans = LogTransform(scaler=0.01)

In [None]:
# Nice-scaled profiles (filtered by the global `meta_data_filter`), restricted to the selected levels.
y_hat_nice, dardar_nice = get_profile_data(
    eval_data,
    meta_data_filter=meta_data_filter,
    target_transform=nice_logtrans,
    selected_height_levels=selected_height_levels,
)

In [None]:
# Select the Nice target (index 1 of the target axis). The ndim guard keeps the cell
# idempotent: re-running it must not slice the height axis of the already-selected data.
if y_hat_nice.ndim == 3:
    y_hat_nice = y_hat_nice[:, 1]
if dardar_nice.ndim == 3:
    dardar_nice = dardar_nice[:, 1]

In [None]:
# Night (meta data column 3 == 1) and day (== 0) subsets of the Nice-scaled profiles.
y_hat_night_nice, dardar_night_nice = get_profile_data(
    eval_data, meta_data_filter=(3, 1), target_transform=nice_logtrans,
    selected_height_levels=selected_height_levels)
y_hat_day_nice, dardar_day_nice = get_profile_data(
    eval_data, meta_data_filter=(3, 0), target_transform=nice_logtrans,
    selected_height_levels=selected_height_levels)

In [None]:
# Select the Nice target (index 1) for the night / day subsets; the ndim guards keep the
# cell idempotent on re-runs.
if y_hat_night_nice.ndim == 3:
    y_hat_night_nice = y_hat_night_nice[:, 1]
if dardar_night_nice.ndim == 3:
    dardar_night_nice = dardar_night_nice[:, 1]

if y_hat_day_nice.ndim == 3:
    y_hat_day_nice = y_hat_day_nice[:, 1]
if dardar_day_nice.ndim == 3:
    dardar_day_nice = dardar_day_nice[:, 1]

In [None]:
# Full predicted cubes for the Nice run; only available when store_3d_data=True.
# NOTE(review): this is computed exactly like y_hat_cube_iwc (same cube list, same
# dm.target_transform). Confirm whether nice_logtrans / a target channel was intended.
if eval_kwargs["store_3d_data"]:
    y_hat_cube_nice = dm.target_transform.inverse_transform(
        torch.concat(eval_data["y_hat_cube_data"]))

In [None]:
# Generic names used by the cells below point at the IWC profiles.
y_hat, dardar = y_hat_iwc, dardar_iwc

In [None]:
# Generic night / day names point at the IWC subsets.
y_hat_night, dardar_night = y_hat_night_iwc, dardar_night_iwc
y_hat_day, dardar_day = y_hat_day_iwc, dardar_day_iwc

In [None]:
# Restrict the height grid to the evaluated levels. The mask is applied only while the grid
# still has the mask's length, so re-running this cell neither raises nor filters twice.
if height_levels.shape[0] == selected_height_levels.shape[0]:
    height_levels = height_levels[selected_height_levels]

In [None]:
# Share of DARDAR values that are cloudy (> 0) among all non-negative values. NaNs fail both
# comparisons, so they are excluded from numerator and denominator.
cloud_percentage = (dardar > 0).sum() / (dardar >= 0).sum()
# the value is a fraction, so format it as a percentage instead of labelling a fraction "%"
print(f"cloudy pixels: {cloud_percentage:.2%}")

In [None]:
# (Removed: exact duplicate of the cloudy-pixel fraction cell above.)

### Run plotting

In [None]:
%matplotlib inline

In [None]:
# IWC, all samples (empty suffix: also horizontal cloud cover and zonal mean plots)
run_evaluation(
    y_hat_iwc,
    dardar_iwc,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=LogTransform(scaler=1e7),
    target_variable="iwc",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)

In [None]:
# IWC, night subset
run_evaluation(
    y_hat_night_iwc,
    dardar_night_iwc,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=LogTransform(scaler=1e7),
    target_variable="iwc",
    suffix="night",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)

In [None]:
# IWC, day subset
run_evaluation(
    y_hat_day_iwc,
    dardar_day_iwc,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=LogTransform(scaler=1e7),
    target_variable="iwc",
    suffix="day",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)

In [None]:
# ICNC (Nice), all samples (suffix "nice": also the zonal mean plot)
run_evaluation(
    y_hat_nice,
    dardar_nice,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=nice_logtrans,
    target_variable="nice",
    suffix="nice",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)

In [None]:
# ICNC (Nice), night subset (the original comment wrongly said "day")
run_evaluation(
    y_hat_night_nice,
    dardar_night_nice,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=nice_logtrans,
    target_variable="nice",
    suffix="nice_night",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)

In [None]:
# ICNC (Nice), day subset (the original comment wrongly said "night")
run_evaluation(
    y_hat_day_nice,
    dardar_day_nice,
    eval_data,
    exp=None,
    height_levels=height_levels,
    target_transform=nice_logtrans,
    target_variable="nice",
    suffix="nice_day",
    log_image_kwargs=dict(display_inline=True, log_to_overleaf=False, log_to_comet=False),
)