In [1]:
#convert

# babilim.core.statefull_object

> An object which can have a state that is trainable or checkpointable. The core unit of babilim.

In [3]:
#export
from typing import Sequence, Any, Sequence, Callable, Dict, Iterable
from collections import defaultdict, OrderedDict
import babilim
from babilim import PYTORCH_BACKEND, TF_BACKEND
from babilim.core.checkpoint import load_state, save_state
from babilim.core.itensor import ITensor
from babilim.core.tensor import Tensor, TensorWrapper
from babilim.core.logging import info, warn, DEBUG_VERBOSITY

In [4]:
#export
# Module-level table, created empty at import time.
# NOTE(review): nothing in the visible part of this file reads or writes it;
# confirm whether it is still needed.
_statefull_object_name_table = {}

In [None]:
#export
class StatefullObject(object):
    """
    Base class for objects whose state consists of variables.

    Variables are discovered by walking the instance `__dict__` (see `named_variables`),
    can be exported/imported as a state dict and saved to or loaded from a checkpoint.
    """

    def __init__(self):
        """
        A statefull object is an object with variables that can be trainable and checkpointable.
        """
        self._wrapper = TensorWrapper()
        self._training = True

    @property
    def training(self) -> bool:
        """
        Property if the object is in training mode.
        
        ```python
        statefull_object.training
        ```
        
        :return: True if the object is in training mode.
        """
        return self._training

    @property
    def variables(self):
        """
        Property with all variables of the object.
        
        ```python
        statefull_object.variables
        ```
        
        :return: A list of the variables in this object.
        """
        return list(self.named_variables.values())

    @property
    def named_variables(self):
        """
        Property with all variables of the object.
        
        ```python
        statefull_object.named_variables
        ```
        
        Names are built as `namespace + "/" + member_name (+ "/" + key_or_index)`.
        If the same name is produced more than once, the last entry wins.
        
        :return: A dictionary of the variables in this object.
        """
        return dict(self.__variables_with_namespace())

    def __collect_element(self, name, element, container, all_vars, extra_vars):
        """
        Collect the variables of one element of a dict or iterable member.

        :param name: The full name (including namespace) of the element.
        :param element: The element of the container that is inspected.
        :param container: The dict or iterable member the element was taken from.
        :param all_vars: List of (name, variable) tuples that is extended in place.
        :param extra_vars: List of fallback (name, variable) tuples that is extended in place.
        """
        # The checks are deliberately independent (no elif): an element may match several.
        if isinstance(element, StatefullObject):
            all_vars.extend(element.__variables_with_namespace(name))
        if isinstance(element, ITensor):
            all_vars.append((name, element))
        if self._wrapper.is_variable(element):
            all_vars.append((name, self._wrapper.wrap_variable(element, name=name)))
        # NOTE(review): the whole container (not the element) is passed here, exactly as the
        # code always did. This looks unintended, but the semantics of `vars_from_object`
        # are not visible from this file, so the call is kept unchanged. Confirm.
        extra_vars.extend(self._wrapper.vars_from_object(container, name))

    def __variables_with_namespace(self, namespace=""):
        """
        Walk all members of this object and collect their variables.

        Members are handled by the first matching case: strings are skipped, dicts and other
        iterables are searched element-wise, statefull objects are searched recursively,
        tensors and native variables are added directly and any other object is searched via
        the tensor wrapper and its `__dict__`.

        :param namespace: (Optional) Prefix that is put in front of every variable name. (default: "")
        :return: A list of (name, variable) tuples. The fallback variables (found via the tensor
            wrapper) are only returned when no other variable was found.
        """
        all_vars = []
        extra_vars = []
        for member_name, v in self.__dict__.items():
            member_namespace = namespace + "/" + member_name
            if isinstance(v, str):
                # Strings are iterable but never contain variables.
                continue
            if isinstance(v, dict):
                for i, (k, x) in enumerate(v.items()):
                    # Non-string keys are replaced by the position of the entry.
                    if not isinstance(k, str):
                        k = "{}".format(i)
                    self.__collect_element(member_namespace + "/" + k, x, v, all_vars, extra_vars)
            elif isinstance(v, Iterable):
                for i, x in enumerate(v):
                    self.__collect_element(member_namespace + "/{}".format(i), x, v, all_vars, extra_vars)
            elif isinstance(v, StatefullObject):
                all_vars.extend(v.__variables_with_namespace(member_namespace))
            elif isinstance(v, ITensor):
                all_vars.append((member_namespace, v))
            elif self._wrapper.is_variable(v):
                all_vars.append((member_namespace, self._wrapper.wrap_variable(v, name=member_namespace)))
            else:
                # Any other object: ask the wrapper and look one level into its attributes.
                extra_vars.extend(self._wrapper.vars_from_object(v, member_namespace))
                for attr_name, attr in getattr(v, '__dict__', {}).items():
                    name = member_namespace + "/" + attr_name
                    if isinstance(attr, StatefullObject):
                        all_vars.extend(attr.__variables_with_namespace(name))
                    if isinstance(attr, ITensor):
                        all_vars.append((name, attr))
                    if self._wrapper.is_variable(attr):
                        # Native variables of foreign objects only count as fallback.
                        extra_vars.append((name, self._wrapper.wrap_variable(attr, name=name)))
        # The fallback variables are only used when nothing else was found.
        if not all_vars:
            all_vars.extend(extra_vars)
        return all_vars

    @property
    def trainable_variables(self):
        """
        Property with trainable variables of the object.
        
        ```python
        statefull_object.trainable_variables
        ```
        
        :return: A list of the trainable variables in this object.
        """
        return [v for v in self.variables if v.trainable]

    @property
    def named_trainable_variables(self):
        """
        Property with trainable variables of the object.
        
        ```python
        statefull_object.named_trainable_variables
        ```
        
        :return: A dictionary of the trainable variables in this object.
        """
        return {k: v for k, v in self.named_variables.items() if v.trainable}

    @property
    def untrainable_variables(self):
        """
        Property with not trainable variables of the object.
        
        ```python
        statefull_object.untrainable_variables
        ```
        
        :return: A list of not trainable variables in this object.
        """
        return [v for v in self.variables if not v.trainable]

    @property
    def named_untrainable_variables(self):
        """
        Property with not trainable variables of the object.
        
        ```python
        statefull_object.named_untrainable_variables
        ```
        
        :return: A dictionary of not trainable variables in this object.
        """
        return {k: v for k, v in self.named_variables.items() if not v.trainable}

    @property
    def trainable_variables_native(self):
        """
        Property with trainable variables of the object in native format.
        
        ```python
        statefull_object.trainable_variables_native
        ```
        
        :return: A list of trainable variables in this object in native format.
        """
        return [v.native for v in self.trainable_variables]

    @property
    def _parameters(self) -> OrderedDict:
        """
        The named trainable variables as an ordered dictionary.

        NOTE(review): presumably mirrors the attribute of the same name on native modules - confirm.

        :return: An ordered dictionary mapping names to trainable variables.
        """
        params = OrderedDict()
        params.update(self.named_trainable_variables)
        return params

    @property
    def _buffers(self) -> OrderedDict:
        """
        The named not trainable variables as an ordered dictionary.

        NOTE(review): presumably mirrors the attribute of the same name on native modules - confirm.

        :return: An ordered dictionary mapping names to not trainable variables.
        """
        params = OrderedDict()
        params.update(self.named_untrainable_variables)
        return params

    def state_dict(self) -> Dict:
        """
        Get the state of the object as a state dict (usable for checkpoints).
        
        With the TF backend the values are stored as returned by `numpy()`, with every other
        backend they are stored transposed (`.T`). `load_state_dict` applies the same rule
        when reading, so both functions are symmetric.
        
        :return: A dictionary containing the state of the object.
        """
        # The backend does not depend on the variable, so it is checked once.
        is_tf = babilim.is_backend(babilim.TF_BACKEND)
        state = {}
        for name, var in self.named_variables.items():
            state[name] = var.numpy() if is_tf else var.numpy().T
        return state

    def load_state_dict(self, state_dict: Dict) -> None:
        """
        Load the state of the object from a state dict.
        
        Handy when loading checkpoints.
        
        Variables that are missing in the state dict are left unchanged and reported with a
        warning. Entries of the state dict that match no variable are ignored.
        
        :param state_dict: A dictionary containing the state of the object.
        """
        for name, var in self.named_variables.items():
            if name in state_dict:
                if babilim.is_backend(babilim.TF_BACKEND):
                    var.assign(state_dict[name])
                else:
                    # Undo the transpose applied by state_dict for non TF backends.
                    var.assign(state_dict[name].T)
                if DEBUG_VERBOSITY:
                    info("  Loaded: {}".format(name))
            else:
                warn("  Variable {} not in checkpoint.".format(name))

    def eval(self):
        """
        Set the object into eval mode.
        
        ```python
        self.train(False)
        ```
        """
        self.train(False)

    def train(self, mode=True):
        """
        Set the objects training mode.
        
        The mode is also passed on to every member that has a callable `train` attribute.
        Members that are sequences or dicts are treated as containers: the mode is passed on
        to their elements (for dicts to their values). A dict that provides its own callable
        `train` is called directly instead.
        
        :param mode: (Optional) If the training mode is enabled or disabled. (default: True)
        """
        self._training = mode
        for member in self.__dict__.values():
            if isinstance(member, Sequence):
                targets = member
            elif isinstance(member, dict) and not callable(getattr(member, "train", None)):
                # Dict members hold sub-objects too (their variables are collected as well),
                # so their values must receive the mode like elements of a sequence do.
                targets = member.values()
            else:
                targets = (member,)
            for target in targets:
                train_fn = getattr(target, "train", None)
                if callable(train_fn):
                    train_fn(mode)

    def load(self, checkpoint_file_path: str) -> None:
        """
        Load the state of the object from a checkpoint.
        
        The state is read from the "model" entry of the checkpoint. If the checkpoint has no
        such entry an error is reported via `babilim.error` and nothing is loaded.
        
        :param checkpoint_file_path: The path to the checkpoint storing the state dict.
        """
        checkpoint = load_state(checkpoint_file_path)
        if "model" in checkpoint:
            self.load_state_dict(checkpoint["model"])
        else:
            babilim.error("Could not find state in checkpoint.")

    def save(self, checkpoint_file_path: str) -> None:
        """
        Save the state of the object to a checkpoint.
        
        The state dict is stored under the "model" entry, which is where `load` expects it.
        
        :param checkpoint_file_path: The path to the checkpoint storing the state dict.
        """
        save_state({"model": self.state_dict()}, checkpoint_file_path)