## CNN inference example in PyTorch

In [None]:
import os

import matplotlib.pyplot as plt
import xarray as xr

import pycomlink as pycml
from pycomlink.processing.pytorch_util import run_inference

In [None]:
# Load the example CML dataset that ships with pycomlink
data_path = pycml.io.examples.get_example_data_path()
cmls = xr.open_dataset(os.path.join(data_path, 'example_cml_data.nc'))

# Keep three CMLs to study (selected by position along cml_id)
cmls = cmls.isel(cml_id=[0, 10, 370])

# Mask the placeholder values 255.0 (tsl) and -99.9 (rsl) -- presumably fill values
# marking missing data -- so they become NaN instead of entering the loss calculation.
# NOTE(review): exact float equality against -99.9 only matches if the stored dtype
# represents it identically (e.g. float64); confirm for float32 data.
cmls['tsl'] = cmls.tsl.where(cmls.tsl != 255.0)
cmls['rsl'] = cmls.rsl.where(cmls.rsl != -99.9)

# Total loss (previously called TRSL): tsl - rsl
cmls['tl'] = cmls.tsl - cmls.rsl
# Fill short gaps linearly along time; gaps longer than 5 minutes remain NaN
cmls['tl'] = cmls.tl.interpolate_na(dim='time', method='linear', max_gap='5min')

In [None]:
# The inference step expects the total loss ordered as (time, channel_id, cml_id)
tl = cmls['tl'].transpose('time', 'channel_id', 'cml_id')

# Remove each series' baseline: subtract the median over time, computed separately
# for every (channel_id, cml_id) pair
tl_normed = tl - tl.median(dim='time')

# Standardization (scaling) is intentionally not done here: it belongs to the specific
# model, since each model is trained with its own preprocessing

### Option 1: Load the DL model from a local .pt file


In [None]:
# TODO: loading from file could support a .pt file plus an optional config.yml,
# falling back to a hard-coded reflength of 60 when no config is available.

# Path to a locally stored model file; adjust if your copy lives elsewhere.
# expanduser keeps the default portable instead of pointing at one user's home folder.
weights_path = os.path.expanduser('~/.cml_wd_pytorch/models/best_model_jit.pt')

if os.path.isfile(weights_path):
    result = run_inference.cnn_wd(model_path_or_url=weights_path, data=tl_normed)
else:
    # Don't abort the notebook: Option 2 below downloads a model instead.
    print(f'No local model found at {weights_path}; skipping Option 1.')

### Option 2: Load the model from a URL and cache it

In [None]:
# TODO: Provide a named model class (e.g. polz_2025_cnn) that loads a specific model
# from its URL, so the user chooses a model by name instead of passing a raw string
# and the rest of the complexity stays hidden.
# Alternatively, the argument could accept a local path, a URL, a code name mapped to
# a hard-coded URL, or None to load a default model.
# A config should be optional, e.g. via a signature like
# `cnn_wd(model_url, config_url_or_path, data=tl_normed)`; if no config is provided,
# set reflength: 60 (length of the radar data to consider for rainfall rate or wet
# label calculation).

# The URL pins a specific commit hash, so the same weights are fetched every time.
# NOTE(review): the path points at a "dummy_model" -- presumably placeholder weights
# for demonstration; confirm before interpreting the predictions.
model_URL = "https://github.com/jpolz/cml_wd_pytorch/raw/be2b15fa987838ea1f709dd0180917eebf66271a/data/dummy_model/best_model_jit.pt"
result = run_inference.cnn_wd(model_path_or_url=model_URL, data=tl_normed)

In [None]:
# Prediction value above which a time step is flagged as wet in the bottom panel
WET_THRESHOLD = 0.8

for cml_id in cmls.cml_id.data:
    # Share the time axis so the three panels line up
    fig, axs = plt.subplots(3, 1, figsize=(14, 6), sharex=True)
    predictions = result.predictions.sel(cml_id=cml_id)

    # Top: TL returned with the inference result, first channel only
    result.TL.sel(cml_id=cml_id).isel(channel_id=0).plot.line(x='time', ax=axs[0])
    # Middle: model predictions
    predictions.plot.line(x='time', ax=axs[1])
    # Bottom: thresholded wet flag (1 = above threshold, 0 = below). A comparison with
    # NaN yields False, so time steps without a prediction are masked rather than
    # being drawn as dry.
    wet = (predictions > WET_THRESHOLD).where(predictions.notnull())
    wet.plot.line(x='time', ax=axs[2])
    
