# Model Classes

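> Abstract model interfaces: predict a KPI from an `xarray.Dataset` and report the contribution of each input to that prediction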

In [None]:
#| default_exp utils.model_classes

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from abc import ABC, abstractmethod
from pathlib import Path
from collections.abc import Sequence
from typing import Callable, Generic, TypeVar, Union, Protocol
import xarray as xr

In [None]:
#| export
class _Model(Protocol):
    """Structural interface a loaded model artifact must satisfy.

    Anything returned by a `model_loader` is expected to expose these two
    methods, each taking and returning an `xr.Dataset`.
    """

    @abstractmethod
    def predict(self, x: xr.Dataset) -> xr.Dataset:
        """Return the model's prediction for input `x`."""
        ...

    @abstractmethod
    def contributions(self, x: xr.Dataset) -> xr.Dataset:
        """Return the contributions of input `x` to the prediction."""
        ...

In [None]:
#| export
class BaseModel(ABC):
    """
    Abstract class for all models

    Subclasses must implement `__init__`, `predict` and `contributions`.
    The base `__init__` stores identifying metadata and loads the underlying
    model artifact through the supplied `model_loader`.
    """
    @abstractmethod
    def __init__(
        self,
        model_name: str, # Name used to identify the model
        model_kpi: str, # Key performance indicator output by the model predict
        model_path: str|Path, # Path to the model artifact
        model_loader: Callable[[str|Path], _Model], # Function to load the model
        model_type: str|None = None, # Kind of model; defaults to the concrete class name
        ):
        self.model_name: str = model_name
        self.model_kpi: str = model_kpi
        # `model_type` is optional so existing callers keep working; fall back to
        # the concrete subclass name so the attribute is always a str.
        self.model_type: str = model_type if model_type is not None else type(self).__name__
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.model_path: Path = Path(model_path)
        # The loader receives the path exactly as the caller passed it.
        self._model: _Model = model_loader(model_path)

    @abstractmethod
    def predict(
        self,
        x: xr.Dataset # Input data
        ) -> xr.Dataset: # Predicted target variable
        """
        Predict the target variable from the input data
        """
        pass

    @abstractmethod
    def contributions(
        self,
        x: xr.Dataset # Input data
        ) -> xr.Dataset: # Contributions of the input data to the target variable
        """
        Get the contributions of the input data to the target variable
        """
        pass

In [None]:
# Render the API reference (signature, docstring, parameter table) for BaseModel.predict
show_doc(BaseModel.predict)

---

### BaseModel.predict

>      BaseModel.predict (x:xarray.core.dataset.Dataset)

*Predict the target variable from the input data*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| x | Dataset | Input data |
| **Returns** | **Dataset** | **Predicted target variable** |

In [None]:
# Render the API reference (signature, docstring, parameter table) for BaseModel.contributions
show_doc(BaseModel.contributions)

---

### BaseModel.contributions

>      BaseModel.contributions (x:xarray.core.dataset.Dataset)

*Get the contributions of the input data to the target variable*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| x | Dataset | Input data |
| **Returns** | **Dataset** | **Contributions of the input data to the target variable** |

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()  # write `#| export` cells out to the library's Python modules