In [None]:
#| default_exp encoder

In [None]:
#| hide
#%load_ext autoreload --> Not working TODO:REVISAR
# %autoreload 2 

In [1]:
#| export
import warnings
import math
from dvats.memory import *
import dvats.utils as ut
from dvats.config import show_attrdict
from copy import deepcopy
## -- Classes & types
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Callable, Union, Any

# Fastai
#| export
from fastai.learner import Learner
from tsai.data.core import TSDataLoaders
# Moirai
import uni2ts.model.moirai.module as moirai
import uni2ts.model.moirai.forecast as moirai_forecast



# Encoder

> Architectures and functions for building the encoders that compute the embeddings.

In [2]:
#| export
import pandas as pd
import numpy as np
from fastcore.all import *
from tsai.callback.MVP import *
from tsai.imports import *
from tsai.models.InceptionTimePlus import InceptionTimePlus
from tsai.models.explainability import get_acts_and_grads
from tsai.models.layers import *
from tsai.data.validation import combine_split_data
from tsai.basics import *
from fastai.callback.hook import hook_outputs
from momentfm import MOMENTPipeline
from gluonts.dataset.pandas import PandasDataset
from tsai.data.validation import TimeSplitter
from fastai.callback.wandb import WandbCallback
from fastai.callback.progress import ShowGraphCallback
from fastai.callback.schedule import *
from fastai.callback.tracker import EarlyStoppingCallback
from fastai.callback.tracker import SaveModelCallback
import time
import einops
import traceback

In [3]:
#| hide
from tsai.all import *

## Encoder class

In [None]:
#| export
@dataclass
class EncoderInput:
    """Data given to an encoder together with its windowing / splitting options.

    `size`, `shape` and `shapes` are computed lazily from `_data` and cached. The
    `_update_size` / `_update_shape` flags mark those caches as stale; they are set
    again whenever data is assigned through the `data` property.
    """
    # Data
    _data               : Union [ pd.DataFrame, List [ List [ List [ float ]]] ] = None
    _size               : int                               = None
    _shape              : Optional [ Tuple [ int, ... ] ]   = None
    _shapes             : List [ Tuple [ int, ...]]         = None
    stride              : int                               = None
    batch_size          : int                               = None
    _update_size        : bool                              = True
    _update_shape       : bool                              = True
    # Windows                   
    n_windows           : int                               = None
    n_windows_percent   : float                             = None
    validation_percent  : float                             = None
    training_percent    : float                             = None
    window_mask_percent : float                             = None
    # Time                  
    time_flag           : bool                              = None

    def __post_init__(self):
        # Caches always start stale so that the first access computes them
        self._update_size       = True
        self._update_shape      = True
        #Todo: check how to validate the input dataset allowing both windowed or not
        # --- Not working
        ###self._data              = ut._check_value(self.data, None, "_data", pd.DataFrame, allow_none = True )
        ###if self._data is None: 
        ###    self._data = ut._validate_nested_list(self._data, None, "_data", [float, int], 3, False, False, False)
        self.stride,_               = ut._check_value(self.stride, 1, "stride", int, positive = True)
        # A batch must contain at least one sample, same rule as for `stride`
        self.batch_size,_           = ut._check_value(self.batch_size, 32, "batch_size", int, positive = True)
        self.validation_percent,_   = ut._check_value(self.validation_percent, 0.2, "validation_percent", percent = True)
        self.training_percent,_     = ut._check_value(self.training_percent, 0.2, "training_percent", percent = True)
        self.window_mask_percent,_  = ut._check_value(self.window_mask_percent, 0.3, "window_mask_percent", percent = True)
        self.time_flag,_            = ut._check_value(self.time_flag, False, "time_flag", bool)

    @property
    def size(self):
        """Number of elements in `_data` (`len(_data)`), cached; 0 when there is no data."""
        if self._data is not None and ( self._update_size or self._size is None or self._size == 0):
            self._size          = len(self._data)
            self._update_size   = False
            self._size,_ = ut._check_value(self._size, 0, "_size", int)
            self._size = max(self._size, 0)
        elif self._update_size: 
            self._size = 0
            self._update_size = True
        return self._size

    def _compute_shapes_(self) -> None:
        """Recompute `_shape` / `_shapes` from a non-None `_data` and clear the stale flag."""
        try:
            # A single array-like (DataFrame, ndarray, tensor, ...) exposes `.shape` itself
            self._shape     = self._data.shape
            self._shapes    = [ self._shape ]
        except AttributeError:
            # Otherwise treat `_data` as an indexable sequence of array-likes
            self._shape     = self._data[0].shape
            self._shapes    = [ self._data[i].shape for i in range(len(self._data)) ]
        self._update_shape  = False

    def _reset_shapes_(self) -> None:
        """Shape placeholders used while there is no data; caches stay marked stale."""
        self._shape         = 0,
        self._shapes        = []
        self._update_shape  = True

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of `_data`, or of its first element when `_data` is a sequence of arrays."""
        if (
                self._data is not None and 
                ( self._update_shape or self._shape is None or self._shape == 0 )
        ):
            self._compute_shapes_()
        elif self._update_shape: 
            self._reset_shapes_()
        return self._shape

    @property
    def shapes(self) -> List [ Tuple [ int, ... ]]:
        """Shapes of every element of `_data` (a one-item list when `_data` has `.shape`)."""
        if (
            self._data is not None and 
            ( self._update_shape or self._shapes is None or self._shapes ==[])
        ):
            self._compute_shapes_()
        elif self._update_shape: 
            self._reset_shapes_()
        return self._shapes
            
    @property
    def data(self):
        """The wrapped input data."""
        return self._data
    
    @data.setter
    def data(self, value):
        # New data invalidates the cached size and shapes
        self._data          = value
        self._update_size   = True
        self._update_shape  = True


In [5]:
#| export
@dataclass
class LRScheduler:
    """Learning-rate scheduling options.

    Fields left as `None` silently take their defaults:
    lr=1e-5, flag=False, name="OneCycleR", num_warmup_steps=0.
    Invalid explicit values also fall back to the default, with a warning.
    """
    lr              : float = None
    flag            : bool  = None
    name            : str   = None
    num_warmup_steps: int   = None

    def __post_init__(self):
        self.lr                 = self._check_lr(self.lr, 1e-5)
        self.flag               = self._check_flag(self.flag, False)
        self.name               = self._check_name(self.name, "OneCycleR")
        self.num_warmup_steps   = self._check_steps(self.num_warmup_steps, 0)

    # Validation methods
    # `None` means "not provided": the default is used without warning.
    def _check_lr(self, value, default):
        """Return `value` as a float if it is a finite positive number, else `default`."""
        if value is None:
            return default
        # bool is a subclass of int, but True/False is never a meaningful learning rate
        if isinstance(value, bool) or not isinstance(value, (float, int)) or not math.isfinite(value) or value <= 0:
            warnings.warn(f"Invalid learning rate 'lr' ({value}). Using default: {default}")
            return default
        return float(value)

    def _check_flag(self, value, default):
        """Return `value` if it is a bool, else `default`."""
        if value is None:
            return default
        if not isinstance(value, bool):
            warnings.warn(f"Invalid type for 'flag' ({type(value)}). Using default: {default}")
            return default
        return value

    def _check_name(self, value, default):
        """Return `value` if it is a string, else `default`."""
        if value is None:
            return default
        if not isinstance(value, str):
            warnings.warn(f"Invalid type for 'name' ({type(value)}). Using default: {default}")
            return default
        return value

    def _check_steps(self, value, default):
        """Return `value` if it is a non-negative int (bool excluded), else `default`."""
        if value is None:
            return default
        if isinstance(value, bool) or not isinstance(value, int) or value < 0:
            warnings.warn(f"Invalid type or negative value for 'num_warmup_steps' ({value}). Using default: {default}")
            return default
        return value

In [None]:
#| export
@dataclass
class EncoderOptimizer():
    """Optimisation settings for an encoder: loss criterion, optimizer and learning rate.

    `criterion` defaults to the `torch.nn.MSELoss` class itself, not an instance
    (NOTE(review): presumably instantiated by the caller, confirm).
    `lr` is either a plain positive number or an `LRScheduler`.
    """
    criterion   : Optional   [ torch.nn.Module ]          = torch.nn.MSELoss
    optimizer   : Optional   [ torch.optim.Optimizer ]    = None
    lr          : Union      [ float, LRScheduler ]       = 1e-5

    def __post_init__(self):
        # The hook was previously spelled `_post__init__`, so dataclasses never ran it
        # and `lr` was never validated.
        # `LRScheduler` objects validate themselves in their own `__post_init__`;
        # only plain numeric learning rates are checked here.
        if not isinstance(self.lr, LRScheduler):
            self.lr,_ = ut._check_value( self.lr, 1e-5, "lr", [ int, float ], False, True, False )

    # Backward-compatible alias for code that called the old, misspelled hook explicitly
    _post__init__ = __post_init__

In [None]:
#| export
@dataclass
class Encoder():
    model               : Tuple [ 
                            MOMENTPipeline,
                            Learner,
                            moirai.MoiraiModule
                        ]                   = None
    input               : EncoderInput      = EncoderInput()
    mssg                : ut.Mssg           = ut.Mssg()
    cpu                 : bool              = False
    to_numpy            : bool              = False
    num_epochs          : int               = 1
    optim               : EncoderOptimizer  = EncoderOptimizer()
    mask_stateful       : bool              = False
    mask_future         : bool              = False
    mask_sync           : bool              = False
    eval_stats_pre      : AttrDict          = None
    eval_stats_post     : AttrDict          = None
    use_moment_masks    : bool              = False
    model_class         : str               = None
    time_flag           : bool              = False
    use_wandb           : bool              = False
    analysis_mode       : str               = 'online'
    splits              : Tuple             = None
    show_plot           : bool              = False
    norm_by_sample      : bool              = True
    norm_use_single_batch : bool            = True
    metrics             : List [ Callable ] = None
    #mvp_ws              : Tuple [ int, int ]= 0,0
    def __post_init__(self):
        self.model          , _ = ut._check_value(self.model, None, "model", [ MOMENTPipeline, Learner, moirai.MoiraiModule ], True, False, False, mssg = self.mssg)
        self.model              = self.set_model_(self.model)
        ## TODO: check how to do this check
        #self.input          , _ = ut._check_value(self.input, EncoderInput(), "input", EncoderInput, True)
        self.mssg           , _ = ut._check_value(self.mssg, ut.Mssg(), "mssg", ut.Mssg, mssg = self.mssg)
        self.cpu            , _ = ut._check_value(self.cpu, False, "cpu", bool, mssg = self.mssg)
        self.to_numpy       , _ = ut._check_value(self.to_numpy, False, "to_numpy", bool,  mssg = self.mssg)
        self.num_epochs     , _ = ut._check_value(self.num_epochs, 1, "num_epochs", int, False, True,  mssg = self.mssg)
        ## TODO: check how to do this check
        #self.optim          , _ = ut._check_value(self.optim, EncoderOptimizer(), "optim", EncoderOptimizer)
        self.mask_stateful  , _ = ut._check_value(self.mask_stateful, False, "mask_statefull", bool,  mssg = self.mssg)
        self.mask_future    , _ = ut._check_value(self.mask_future, False, "mask_future", bool,  mssg = self.mssg)
        self.mask_sync      , _ = ut._check_value(self.mask_sync, False, "mask_sync", bool,  mssg = self.mssg)
        self.eval_stats_pre , _ = ut._check_value(self.eval_stats_pre, None, "eval_stats_pre", AttrDict, True,  mssg = self.mssg)
        self.eval_stats_post, _ = ut._check_value(self.eval_stats_post, None, "eval_stats_post", AttrDict, True,  mssg = self.mssg)
        self.use_moment_masks, _= ut._check_value(self.use_moment_masks, False, "use_moment_masks", bool,  mssg = self.mssg)
        self.model_class        = None # Must be computed through get_model_class to avoid errors
        self.time_flag      , _ = ut._check_value(self.time_flag, False, "time_flag", bool,  mssg = self.mssg)
        self.show_plot      , _ = ut._check_value(self.show_plot, False, "show_plot", bool, mssg = self.mssg)
    
    def print(self, **kwargs):
        self.mssg.print(**kwargs)

    def get_model_class(self, force : bool = False): 
        if force or self.model_class is None:
            self.model_class = str(self.model.__class__)[8:-2]
        return self.model_class
    def set_model_(self, model):
        if model is not None:
            self.model          = model
            self.model_class    = self.get_model_class() 
            try: # Initially it may not be defined and that would result in an execution error
                self.fine_tune_     = self.set_fine_tune_()
            except:
                self.fine_tune_ = None
        return self.model
    
    def get_splits_(self, n_sample: int = None):
        self.mssg.initial_(ut.funcname())
        #TODO: add checks for datatype to ensure the dataset is not already windowed
        assert self.analysis_mode in [ 'ofline', 'online'], 'Invalid analysis mode'
        X = self.input.data if n_sample is None else self.input.data[n_sample]
        self.mssg.print(f"len(X)={len(X)}")
        match self.analysis_mode:
            case 'online':
                self.mssg.print("Online analysis", verbose_level = self.mssg.level+1)
                self.splits = TimeSplitter(valid_size = 0.2, show_plot = self.show_plot)(X)
            case 'offline':
                self.mssg.print("Offline analysis", verbose_level = self.mssg.level+1)
                self.splits = get_splits(np.arange(len(X)), valid_size=self.valid_size, show_plot = self.show_plot)
            case _:
                raise NotImplementedError(f"Encoderl{ut.funcname()} | Case {self.analysis_mode} not implemented. Use one of the following options: <online|offline>.")
        self.mssg.print(f"X~{X.shape}")
        self.mssg.print(f"Train: {len(self.splits[0])} | Test { len(self.splits[1])}")
        self.mssg.final()
        return X

    #TODO: poner los equivalentes para train, eval, get_embeddings, get_acts, etc.
    
    # Fine_tune_single_
    def fine_tune_moment_single_(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_mvp_single_(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_moirai_single_(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_single_(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")

    # Fine_tune_
    def fine_tune_moment_(self, eval_pre = False, eval_post = False, shot = True, time_flag = False, use_moment_masks = False): 
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_mvp_(self, eval_pre = False, eval_post = False, shot = True, time_flag = False): 
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_moirai_(self, eval_pre = False, eval_post = False, shot = True, time_flag = False): 
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def fine_tune_(self, eval_pre = False, eval_post = False, shot = True, time_flag = False):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def set_fine_tune_(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")
    def show_eval_stats(self):
        raise NotImplementedError(f"Encoder.{ut.funcname()} not yet implemented")

In [None]:
#| export
def set_fine_tune_single_(
    self: Encoder
) -> Callable:
    """Bind `self.fine_tune_single_` to the single-shot fine-tuning method for the model.

    Returns:
        The bound method now stored in `self.fine_tune_single_`.

    Raises:
        NotImplementedError: if the model class has no single-shot implementation.
    """
    # Fully-qualified model class -> Encoder method implementing single-shot fine-tuning
    implementations = {
        "momentfm.models.moment.MOMENTPipeline"   : "fine_tune_moment_single_",
        "fastai.learner.Learner"                  : "fine_tune_mvp_single_",
        "uni2ts.model.moirai.module.MoiraiModule" : "fine_tune_moirai_single_",
    }
    self.mssg.initial_(ut.funcname())
    model_class = self.get_model_class()
    self.mssg.print(f"Model class: {model_class}")
    method_name = implementations.get(model_class)
    if method_name is None:
        self.mssg.print(f"Fine-tune single shot implementation is not yet implemented for {self.model_class}.", verbose_level = self.mssg.level+1)
        raise NotImplementedError(f"fine_tune_single_ | Not yet implemented for {self.model_class}")
    self.fine_tune_single_ = getattr(self, method_name)
    self.mssg.final(ut.funcname())
    return self.fine_tune_single_
Encoder.set_fine_tune_single_ = set_fine_tune_single_

In [None]:
#| export
def set_fine_tune_(
    self: Encoder
) -> Callable:
    """Bind `self.fine_tune_` to the fine-tuning method matching the wrapped model's class.

    Returns:
        The bound method now stored in `self.fine_tune_`.

    Raises:
        NotImplementedError: if the model class has no fine-tuning implementation.
    """
    # Fully-qualified model class -> (log label, Encoder method implementing fine-tuning)
    dispatch = {
        "momentfm.models.moment.MOMENTPipeline"   : ("Moment", "fine_tune_moment_"),
        "fastai.learner.Learner"                  : ("MVP",    "fine_tune_mvp_"),
        "uni2ts.model.moirai.module.MoiraiModule" : ("Moirai", "fine_tune_moirai_"),
    }
    self.mssg.initial_("set_fine_tune_")
    model_class = self.get_model_class()
    self.mssg.print(f"Model class: {model_class}")
    if model_class not in dispatch:
        self.mssg.print(f"Fine-tune implementation is not yet implemented for {self.model_class}.", verbose_level = self.mssg.level+1)
        raise NotImplementedError(f"fine_tune | Not yet implemented for {self.model_class}")
    label, method_name = dispatch[model_class]
    self.mssg.print(label)
    self.fine_tune_ = getattr(self, method_name)
    self.mssg.final(ut.funcname())
    return self.fine_tune_
Encoder.set_fine_tune_ = set_fine_tune_

In [10]:
#| export
def show_eval_stats(
    self            : Encoder, 
    print_to_path   : bool      = None, 
    print_path      : str       = None, 
    print_mode      : str       = None,
    eval_pre        : bool = False,
    eval_post       : bool = False,
    eval_stats_pre  : AttrDict = None,
    eval_stats_post : AttrDict = None,
    func_name       : str = ""
):
    """Print the pre- and/or post-fine-tuning evaluation statistics.

    Explicitly passed statistics and printing options replace the ones stored on the
    encoder / its `mssg`; `None` keeps the stored value. Only the sections enabled by
    `eval_pre` / `eval_post` are printed, through `show_attrdict`.
    """
    self.mssg.print(f"{func_name} | Evaluation summary")
    self.eval_stats_pre  = eval_stats_pre  if eval_stats_pre  is not None else self.eval_stats_pre
    self.eval_stats_post = eval_stats_post if eval_stats_post is not None else self.eval_stats_post
    self.mssg.to_path    = print_to_path   if print_to_path   is not None else self.mssg.to_path
    self.mssg.path       = print_path      if print_path      is not None else self.mssg.path
    self.mssg.mode       = print_mode      if print_mode      is not None else self.mssg.mode
    # (enabled, heading, attribute holding the statistics)
    sections = (
        (eval_pre,  "Eval pre: ",  "eval_stats_pre"),
        (eval_post, "Eval post: ", "eval_stats_post"),
    )
    for enabled, heading, stats_attr in sections:
        if not enabled:
            continue
        self.mssg.print(heading)
        show_attrdict(
            getattr(self, stats_attr),
            print_to_path   = self.mssg.to_path,
            print_path      = self.mssg.path,
            print_mode      = self.mssg.mode
        )
Encoder.show_eval_stats = show_eval_stats

### Architectures

In [11]:
#|export 
class DCAE_torch(Module):
    def __init__(self, c_in, seq_len, delta, nfs=(64, 32, 12), kss=(10, 5, 5),
                 pool_szs=(2, 2, 3), output_fsz=10):
        """
        Deep Convolutional Autoencoder for multivariate time series.

        Input is expected as (batch, c_in, seq_len): `c_in` variables sliced into
        windows of length `seq_len`. The encoder stacks Conv1d + MaxPool1d layers; the
        bottleneck is a Dense layer with `delta` latent features; the decoder mirrors
        the encoder with Conv1d + Upsample layers and a final Conv1d back to `c_in` channels.

        Args:
            c_in: number of input channels (variables).
            seq_len: window length.
            delta: number of latent features in the bottleneck Dense layer.
            nfs: number of feature maps (filters) of each conv layer.
            kss: kernel size of each conv layer.
            pool_szs: max-pool size of each encoder layer (reused as upsampling factor).
            output_fsz: kernel size of the final conv layer.

        Constraints (asserted):
        - `nfs`, `kss` and `pool_szs` have the same length;
        - `prod(pool_szs) == nfs[-1]`;
        - `seq_len` is divisible by `prod(pool_szs)`.
        """
        assert len(nfs) == len(kss) == len(pool_szs), \
            'nfs, kss, and pool_szs must have the same length'
        assert np.prod(pool_szs) == nfs[-1], \
            'The number of filters in the last conv layer must be equal to the product of pool sizes'
        assert seq_len % np.prod(pool_szs) == 0, \
            'The product of pool sizes must be a divisor of the window size'
        layers = []
        for i in range(len(kss)):
            layers += [Conv1d(ni=nfs[i-1] if i>0 else c_in, nf=nfs[i], ks=kss[i]),
                       nn.MaxPool1d(kernel_size=pool_szs[i])]
        self.downsample = nn.Sequential(*layers)
        # After downsampling: nfs[-1] channels x (seq_len / prod(pool_szs)) steps. Since
        # nfs[-1] == prod(pool_szs), the flattened size is exactly seq_len.
        self.bottleneck = nn.Sequential(OrderedDict([
            ('flatten', nn.Flatten()),
            ('latent_in', nn.Linear(seq_len, delta)),
            ('latent_out', nn.Linear(delta, seq_len)),
            ('reshape', Reshape(nfs[-1], seq_len // np.prod(pool_szs)))
        ]))
        layers = []
        for i in reversed(range(len(kss))):
            layers += [Conv1d(ni=nfs[i+1] if i != (len(nfs)-1) else nfs[-1],
                              nf=nfs[i], ks=kss[i]),
                       nn.Upsample(scale_factor=pool_szs[i])]
        layers += [Conv1d(ni=nfs[0], nf=c_in, kernel_size=output_fsz)]
        self.upsample = nn.Sequential(*layers)


    def forward(self, x):
        """Encode `x`, pass it through the latent bottleneck, and decode it back."""
        x = self.downsample(x)
        x = self.bottleneck(x)
        x = self.upsample(x)
        return x

In [12]:
#| hide
foo = torch.rand(3, 1, 48)
m = DCAE_torch(c_in=foo.shape[1], seq_len=foo.shape[2], delta=12)
# The autoencoder must reconstruct tensors with exactly the input's shape
assert m(foo).shape == foo.shape, "DCAE_torch output shape differs from its input shape"
m(foo).shape

torch.Size([3, 1, 48])

### Default modules from which the embeddings are extracted, per architecture

In [13]:
#| export
# Per architecture class: name of the sub-module whose output is used as the
# embedding by default (e.g. 'bottleneck.latent_in' is DCAE_torch's latent Dense layer).
ENCODER_EMBS_MODULE_NAME = dict([
    (InceptionTimePlus, 'backbone'),              # for mvp based models
    (DCAE_torch,        'bottleneck.latent_in'),
    # (MoiraiForecast,  'mask_encoding'),         # TODO: check
])

## Get activations

In [14]:
#| export 
def kwargs_to_gpu_(**kwargs):
    """Move every keyword argument that supports `.to("cuda")` onto the GPU.

    `**kwargs` packs the arguments into a new dict, so the moved values only reach
    the caller through the returned dict: use `kw = kwargs_to_gpu_(**kw)`.

    Returns:
        dict: the keyword arguments, with movable values replaced by their CUDA copies.
        Values that cannot be moved (no `.to`, or no CUDA available) are kept unchanged.
    """
    for key, value in list(kwargs.items()):
        try: # best effort: if not able to be moved, just do not move it
            kwargs[key] = value.to("cuda")
        except Exception:
            continue
    return kwargs
    
def kwargs_to_cpu_(**kwargs):
    """Move every keyword argument that supports `.cpu()` to the CPU.

    `**kwargs` packs the arguments into a new dict, so the moved values only reach
    the caller through the returned dict: use `kw = kwargs_to_cpu_(**kw)`.

    Returns:
        dict: the keyword arguments, with movable values replaced by their CPU copies.
        Values without a working `.cpu()` are kept unchanged.
    """
    for key, value in list(kwargs.items()):
        try: # best effort: if not able to be moved, just do not move it
            kwargs[key] = value.cpu()
        except Exception:
            continue
    return kwargs
   

In [15]:
#| export
def get_acts(
    model : torch.nn.Module, 
    module: torch.nn.Module, 
    cpu   : bool, 
    verbose : int = 0,
    retry: bool = False,
    acts_indices: List [ int ] = None,
    #- Printing options for debugging
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a',
    continue_if_fail: bool          = False,
    **model_kwargs #Parameters of the model
):
    """Run `model(**model_kwargs)` without gradients and return the output captured at `module`.

    Args:
        model: network to evaluate; it is put in eval mode and moved to CPU or "cuda".
        module: sub-module of `model` whose forward output is captured with a fastai hook.
        cpu: run on CPU if True, else on "cuda". Keyword arguments supporting
            `.cpu()` / `.to()` are moved as well; other values are left untouched.
        verbose: > 0 logs progress; > 2 also logs the model keyword arguments.
        retry: if the forward pass fails on GPU, move everything to CPU and try once more.
        acts_indices: positions in the list of captured outputs (one entry per hooked
            module) to return. None returns the whole list; a single index returns
            that entry alone.
            NOTE(review): only `[module]` is hooked, so only index 0 is valid here.
        print_to_path, print_path, print_mode: logging options forwarded to `ut.print_flush`.
        continue_if_fail: when `retry` is set and the forward pass fails while already
            on CPU, True logs the error and returns whatever was captured; False re-raises.
        **model_kwargs: keyword arguments for the model's forward call.

    Returns:
        The captured activation(s) selected by `acts_indices`.
    """
    def _log(msg):
        # Callers gate on `verbose` themselves so that expensive f-strings are only built when needed
        ut.print_flush(msg, print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)

    def _move_to_(device):
        # Best effort: values that cannot be moved stay where they are
        for key in list(model_kwargs):
            try:
                model_kwargs[key] = model_kwargs[key].cpu() if device == "cpu" else model_kwargs[key].to(device)
            except Exception:
                continue
        model.to(device)

    if verbose > 0: _log(f"--> get acts | acts indices: {acts_indices}")
    if cpu:
        if verbose > 0: _log(f"get acts | Moving to cpu")
        _move_to_("cpu")
    else:
        if verbose > 0: _log(f"get acts | Moving to gpu")
        _move_to_("cuda")
    if verbose > 0: _log(f"get acts | Add hooks")
    # The context manager removes the hook on exit. A hook left registered keeps firing
    # and retaining outputs on every later forward pass through `module` (memory leak).
    with hook_outputs([module], detach = True, cpu = cpu, grad = False) as h_act:
        with torch.no_grad():
            if verbose > 0: _log(f"get acts | --> Run forward")
            if retry:
                if verbose > 0: _log(f"get acts | Retry")
                try: 
                    model.eval()(**model_kwargs)
                except Exception as e:
                    _log(f"get acts | Retry | Error: {e}")
                    _log(f"get acts | Retry | Kwargs: {model_kwargs}")
                    if cpu:
                        # Already on CPU: no fallback device left. Returning the (missing or
                        # stale) hook output silently would hide the failure from callers.
                        if not continue_if_fail:
                            raise
                    else:
                        _log(f"get acts | Retry | Moving to cpu")
                        _move_to_("cpu")
                        if verbose > 0: _log(f"get acts | Retry | cpu")
                        _log(f"get acts | Retry | Get acts")
                        model.eval()(**model_kwargs)
            else:
                if verbose > 2: _log(f"get acts | No Retry | Get acts | model kwargs: {model_kwargs}")
                model.eval()(**model_kwargs)
        stored = [o.stored for o in h_act]
    if acts_indices is None:
        res = stored
    else: 
        res = [stored[i] for i in acts_indices]
        if len(acts_indices) == 1:
            res = res[0]
    if verbose > 0: _log(f"get acts | Run forward -->")
    if verbose > 0: _log(f"get acts -->")
    return res

In [16]:
#| export
def get_acts_moment(
    enc_learn, 
    cpu             : bool          = False, 
    verbose         : int           = 0, 
    y               : List [ float ]= [], 
    mask                            = None, 
    padd_step       : int           = 100, 
    # Parameters for avoiding errors
    retry           : bool          = False, 
    max_trials      : int           = 5,
    # Activation selector (various vectors in the acts)
    acts_indices    : List [ int ]  = [0],
    #- Printing options for debugging
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a',
    continue_if_fail: bool          = False
):
    """Get MOMENT embeddings: the output of `enc_learn.head.dropout` for input `y`.

    When the forward pass fails with a "tensor a (A) must match the size of tensor b (B)"
    error, the last dimension of `y` is padded (or trimmed, if the mismatch grew
    since the previous trial) and the pass is retried, up to `max_trials` attempts in total.

    Args:
        enc_learn: MOMENT model exposing `head.dropout` (presumably a `MOMENTPipeline`).
        cpu: run on CPU instead of CUDA.
        verbose: > 0 enables progress logging.
        y: encoder input passed to the model as `x_enc`. It must support `.shape` and
            `torch.nn.functional.pad`, i.e. be a tensor.
        mask: forwarded to the model as `mask`.
        padd_step: right padding added when the mismatch does not tell how much is needed.
        retry: forwarded to `get_acts` (CPU fallback on failure).
        max_trials: maximum number of forward attempts.
        acts_indices: which captured activations to return (see `get_acts`).
        print_to_path, print_path, print_mode: logging options forwarded to `ut.print_flush`.
        continue_if_fail: currently unused; kept for interface compatibility.

    Returns:
        The selected activations, or None if `max_trials` < 1.

    Raises:
        The forward-pass exception when the last trial fails, or as soon as a failure
        is not a tensor-size mismatch.
    """
    def _log(msg):
        # Callers gate on `verbose` so the f-strings (which read `y.shape`) are only built when logging
        ut.print_flush(msg, print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)

    success = False 
    trial = 0
    embs = None
    tensor_a_size_old = None
    while not success and trial < max_trials:
        trial += 1
        try:
            if verbose > 0: _log(f"get_acts_moment | Trial {trial} | x_enc ~ {y.shape}")
            # Logging options are passed explicitly: anything else would be forwarded
            # to the model's forward call as an extra keyword argument.
            embs = get_acts(
                model = enc_learn,
                module = enc_learn.head.dropout, # alternative hook point: enc_learn.encoder.dropout
                cpu = cpu,
                verbose = 0,
                x_enc = y,
                retry = retry,
                acts_indices = acts_indices,
                mask = mask,
                print_to_path = print_to_path, print_path = print_path, print_mode = print_mode
            )
            success = True
            if verbose > 0 and acts_indices == [0] : _log(f"get_acts_moment | Trial {trial} | embs ~ {embs.shape}")
        except Exception as e:
            # No attempts left: propagate the real error instead of returning None
            if trial >= max_trials : raise
            if verbose > 0:
                _log(f"get_acts_moment | Trial {trial} | About to pad X (encoder input) | exception {e} | padd step: {padd_step}")
                _log(f"get_acts_moment | Trial {trial} | y ~ {y.shape}")
            size_mismatch = re.search(r'tensor a \((\d+)\) must match the size of tensor b \((\d+)\)', str(e))
            if size_mismatch is None:
                if verbose > 0: _log("Not the usual error. No padding, just fail")
                raise
            tensor_a_size = int(size_mismatch.group(1))
            tensor_b_size = int(size_mismatch.group(2))
            padd = True
            if trial > 1 and tensor_a_size_old is not None: 
                if verbose > 0: _log(f"------------------- Trial {trial}  -----------------")
                if tensor_a_size > tensor_a_size_old:
                    # The previous padding made the mismatch worse: trim instead of padding again
                    if verbose > 0: _log(f"------------------- Trial {trial} | a > a_old -----------------")
                    padd = False
                    y = y [ ..., : tensor_a_size - tensor_b_size]
                    if verbose > 0: _log(f"------------------- Trial {trial} |a > a_old | Reduced |  y ~ {y.shape} -----------------")
            if padd:
                if verbose > 0: _log(f"------------------- Trial {trial} | Padd -----------------")
                if tensor_a_size > tensor_b_size: 
                    if verbose > 0: _log(f"------------------- Trial {trial} | Padd | a > b -----------------")
                    padd_step = tensor_a_size - tensor_b_size
                # Right-pad the last dimension of y
                y = torch.nn.functional.pad(y,(0,padd_step))
            tensor_a_size_old = tensor_a_size
                
    return embs

In [17]:
#| export
def sure_eval_moment(
    enc_learn, 
    cpu, 
    verbose, 
    y, 
    input_mask = None, 
    mask = None, 
    padd_step = 100, 
    retry = False, 
    max_trials = 5, 
    acts_indices = [0],
    #- Printing options for debugging
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a',
    continue_if_fail: bool          = False
):
    """
    Run a MOMENT model on `y`, retrying with a resized input when the forward pass
    fails with a tensor size mismatch.

    On each failed trial the exception message is parsed for
    "tensor a (A) must match the size of tensor b (B) at non-singleton dimension D":
    - D in {1, 2}: the last dimension of `y` is zero-padded, by `A - B` when A > B and
      by `padd_step` otherwise. If A grew with respect to the previous mismatch, `y`
      is instead truncated to its first `A - B` positions.
    - D == 0: `mask` / `input_mask` (when given) are truncated or zero-extended along
      dim 0 so that their first dimension matches the batch size of `y`.
    Any other exception is re-raised, unless `continue_if_fail` is True (then the
    next trial is attempted with unchanged inputs).

    Parameters
    - enc_learn: MOMENT model, called as `enc_learn(x_enc=..., input_mask=..., mask=...)`
      after `enc_learn.to(device)`.
    - cpu: run on CPU if True, otherwise on the current CUDA device.
    - verbose: debug printing level (> 0 prints).
    - y: encoder input tensor.
    - input_mask, mask: optional masks forwarded to the model.
    - padd_step: padding added to the last dimension when the error gives no better hint.
    - retry: unused; kept for backward compatibility.
    - max_trials: maximum number of forward attempts.
    - acts_indices: the output type is only printed when it equals [0].
    - print_to_path, print_path, print_mode: forwarded to `ut.print_flush`.
    - continue_if_fail: keep retrying instead of raising on unexpected errors.

    Returns
    - (output, enc_learn): the model output (None if every trial failed) and the
      model after being moved to the selected device.
    """
    import re  # explicit import: the parser below must not rely on star imports

    def _log(msg, mode = 'a'):
        # Shared debug printer; silent unless verbose > 0.
        if verbose > 0:
            ut.print_flush(msg, print_to_path = print_to_path, print_path = print_path, print_mode = mode, verbose = verbose, print_time = print_to_path)

    def _shape(t):
        # Shape of an optional tensor (the masks may be None).
        return t.shape if t is not None else 'None'

    def _device(t):
        # Device of an optional tensor (the masks may be None).
        return t.device if t is not None else 'None'

    def _match_batch(m, batch_size):
        # Truncate or zero-extend the optional tensor `m` along dim 0 to `batch_size` rows.
        if m is None or m.shape[0] == batch_size:
            return m
        if m.shape[0] > batch_size:
            return m[:batch_size]
        extra_rows_shape = (batch_size - m.shape[0], *m.shape[1:])
        _log(f"sure_eval_moment | Trial {trial} | Mask lower than batch | rows to add: {extra_rows_shape }")
        # Create the padding on the mask's own device/dtype so torch.cat works on GPU too.
        extra_rows = torch.zeros(extra_rows_shape, dtype = m.dtype, device = m.device)
        return torch.cat((m, extra_rows), dim = 0)

    size_error = re.compile(r'tensor a \((\d+)\) must match the size of tensor b \((\d+)\) at non-singleton dimension (\d+)')

    _log("---> sure_eval_moment", mode = print_mode)
    device = "cpu" if cpu else torch.cuda.current_device()
    _log(f"sure_eval_moment | cpu | {cpu} | device | {device}")
    success = False
    trial = 0
    output = None
    tensor_a_size_old = None  # "tensor a" size reported by the previous length mismatch

    while not success and trial < max_trials:
        trial += 1
        try:
            _log(f"sure_eval_moment | Trial {trial} | x_enc ~ {y.shape}")
            if input_mask is not None: input_mask = input_mask.to(device)
            if mask is not None: mask = mask.to(device)
            y = y.to(device)
            enc_learn = enc_learn.to(device)
            _log(f"sure_eval_moment | Trial {trial} | device {device} | input_mask~{_shape(input_mask)} device: {_device(input_mask)}")
            _log(f"sure_eval_moment | Trial {trial} | device {device} | mask device~{_shape(mask)}: {_device(mask)}")
            _log(f"sure_eval_moment | Trial {trial} | device {device} | y~{y.shape} device: {y.device}")
            output = enc_learn(x_enc = y, input_mask = input_mask, mask = mask)
            success = True
        except Exception as e:
            if verbose > 0:
                _log(f"sure_eval_moment | Trial {trial} | About to pad X (encoder input) | exception {e} | padd step: {padd_step}")
                _log(f"sure_eval_moment | Trial {trial} | y ~ {y.shape}")
                traceback.print_exc()
            found = size_error.search(str(e))
            if found is None:
                _log("Not the usual error. No padding, just fail")
                if not continue_if_fail: raise
            else:
                tensor_a_size, tensor_b_size, dimension = (int(g) for g in found.groups())
                if dimension in (1, 2):
                    padd = True
                    if trial > 1:
                        _log(f"------------------- Trial {trial}  -----------------")
                        if tensor_a_size_old is not None and tensor_a_size > tensor_a_size_old:
                            # Padding made the expected size grow: shrink the input instead.
                            _log(f"------------------- Trial {trial} | a > a_old -----------------")
                            padd = False
                            y = y [ ..., : tensor_a_size - tensor_b_size]
                            _log(f"------------------- Trial {trial} |a > a_old | Reduced |  y ~ {y.shape} -----------------")
                    if padd:
                        _log(f"------------------- Trial {trial} | Padd -----------------")
                        if tensor_a_size > tensor_b_size:
                            _log(f"------------------- Trial {trial} | Padd | a > b -----------------")
                            padd_step = tensor_a_size - tensor_b_size
                        y = torch.nn.functional.pad(y, (0, padd_step))
                    tensor_a_size_old = tensor_a_size
                elif dimension == 0:
                    _log(f"sure_eval_moment | Trial {trial} | Error dimension 0 | mask ~ {_shape(mask)} | mask_input ~ {_shape(input_mask)} | batch ~ {y.shape}")
                    mask = _match_batch(mask, y.shape[0])
                    input_mask = _match_batch(input_mask, y.shape[0])
        else:
            if acts_indices == [0]:
                _log(f"sure_eval_moment | Trial {trial} | output ~ {type(output).__name__}")
        _log(f"sure_eval_moment | output {output.__class__} -->")

    return output, enc_learn

## Getting the embeddings (activations) from the encoder

In [19]:
#| export
def get_enc_embs_ensure_batch_size_(
    dls        : "TSDataLoaders",
    batch_size : int = None,
    verbose    : int = 0
) -> None:
    """
    Set the batch size of `dls` in place.

    - If `batch_size` is given, it is assigned to `dls.bs`.
    - Otherwise the current `dls.bs` is kept, unless it is not a valid batch size
      (None or <= 0), in which case it is replaced by 64.

    Input
    - `dls`: object exposing a writable `bs` attribute (data loaders or a single data loader).
    - `batch_size`: batch size to force, or None to keep the current one.
    - `verbose`: messages are printed when `verbose > 1`.
    """
    if batch_size is not None:
        dls.bs = batch_size
        if verbose > 1: 
            ut.print_flush(f"[ Get Encoder Embeddings Ensure Batch Size ] Batch size proposed. Using {dls.bs}", verbose = verbose)
        return

    if verbose > 1: 
        ut.print_flush(f"[ Get Encoder Embeddings Ensure Batch Size ] No batch size proposed", verbose = verbose)
    if dls.bs is None or dls.bs <= 0:
        if verbose > 1: 
            ut.print_flush(f"[ Get Encoder Embeddings Ensure Batch Size ] Using value 64 as {dls.bs} is not a valid value.", verbose = verbose)
        # Assign on the loader object that was passed in (no learner is in scope here).
        dls.bs = 64
    elif verbose > 1: 
        ut.print_flush(f"[ Get Encoder Embeddings Ensure Batch Size ] Using the original value: {dls.bs}", verbose = verbose)

In [20]:
#| export
def get_enc_embs_MVP(
    X               : List [ List [ List [ float ] ] ], 
    enc_learn       : Learner, 
    module          : str  = None, 
    cpu             : bool = False, 
    average_seq_dim : bool = True, 
    to_numpy        : bool = True,
    batch_size      : int  = None,
    verbose         : int  = 0
):
    """
        Get the embeddings of X from an encoder, passed in `enc_learn` as a fastai
        learner. By default, the embeddings are obtained from the module registered
        for the model class in `ENCODER_EMBS_MODULE_NAME`, although any layer can be
        passed to `module`.
        Input
        - `X`: encoder input, used to build a new validation dataloader.
        - `enc_learn`: fastai learner holding the encoder (`enc_learn.model`).
        - `module`: layer whose activations are returned (None: default module).
        - `cpu`: Whether to do the model inference in cpu or gpu (GPU recommended).
          Falls back to CPU when CUDA is not available.
        - `average_seq_dim`: Whether to average 3D embeddings over dimension 2
        - `to_numpy`: Whether to return the result as a numpy array (if false returns a tensor)
        - `batch_size`: force data loader to use the input batch size
        - `verbose`: print flag. More big, more information.
    """
    if not cpu and not torch.cuda.is_available():
        # Moving the learner to CUDA would fail: fall back to CPU.
        if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] GPU | CUDA is not available")
        cpu = True

    if cpu:
        if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] CPU")
        enc_learn.dls.cpu()
        enc_learn.cpu()
    else:
        if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] --> GPU")
        if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] GPU | Ensure empty cache")
        torch.cuda.empty_cache()
        if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] GPU | Move & exec into CUDA")
        enc_learn.dls.cuda()
        enc_learn.cuda()
        if verbose > 1: 
            device_id = torch.cuda.current_device()
            ut.print_flush("[ Get Encoder Embeddings ] GPU | CUDA is available")
            ut.print_flush(f"[ Get Encoder Embeddings ] GPU | CUDA is available | current device id {device_id}")
            ut.print_flush(f"[ Get Encoder Embeddings ] GPU | CUDA is available | current device name {torch.cuda.get_device_name(device_id)}")
        if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] GPU -->")

    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] Set dataloader from X (enc_learn does not contain dls)")
    aux_dl = enc_learn.dls.valid.new_dl(X=X)
    get_enc_embs_ensure_batch_size_(aux_dl, batch_size, verbose)
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] Get module")
    module = nested_attr(enc_learn.model,ENCODER_EMBS_MODULE_NAME[type(enc_learn.model)]) if module is None else module
    
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads ")
    if verbose > 1: ut.print_flush(f"[ Get Encoder Embeddings ] get_acts_and_grads bs = {aux_dl.bs}")
    
    # One activation tensor per batch of the auxiliary dataloader.
    embs = [
        get_acts_and_grads(
            model   = enc_learn.model,
            modules = module,
            x       = xb[0], 
            cpu     = cpu
        )[0] 
        for xb in aux_dl
    ]
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | --> Concat")
    if not cpu:
        if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | Concat | Check neccesary & free memory")
        # Concatenate on the GPU only if all batch embeddings fit in the memory still free
        # on the device currently in use.
        device_id = torch.cuda.current_device()
        total_emb_size = sum(emb.element_size() * emb.nelement() for emb in embs)
        free_memory = torch.cuda.get_device_properties(device_id).total_memory - torch.cuda.memory_allocated(device_id)
        if total_emb_size < free_memory:
            if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | Concat | Check neccesary & free memory | Fits in GPU -> Computing in GPU")
            embs = [emb.cuda() for emb in embs]
        else:
            if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | Concat | Check neccesary & free memory | Does not fit in GPU -> Computing in CPU")
            embs = [emb.cpu() for emb in embs]
    if verbose > 1: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | Concat | to_concat")
    embs = to_concat(embs)
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] get_acts_and_grads | Concat -->")
    
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] Reduce to 2 dimensions.")
    if embs.ndim == 3 and average_seq_dim: embs = embs.mean(axis=2)
    if verbose > 0: ut.print_flush("[ Get Encoder Embeddings ] Ensure CPU saving & numpy format")
    if to_numpy: embs = embs.numpy() if cpu else embs.cpu().numpy()
    return embs

In [21]:
#| export

def get_enc_embs_MVP_set_stride_set_batch_size(
    X                  : List [ List [ List [ float ] ] ], 
    enc_learn          : Learner, 
    stride             : int, 
    batch_size         : int, 
    module             : str  = None, 
    cpu                : bool = False, 
    average_seq_dim    : bool = True, 
    to_numpy           : bool = True, 
    verbose            : int  = 0, 
    time_flag          : bool = False, 
    chunk_size         : int  = 0, 
    check_memory_usage : bool = False
):
    """
        Get the embeddings of X from an encoder, passed in `enc_learn` as a fastai
        learner. By default, the embeddings are obtained from the module registered
        for the model class in `ENCODER_EMBS_MODULE_NAME`, although any layer can be
        passed to `module`.
        Input
        - `X`: encoder input. Only every `stride`-th element (`X[::stride]`) is encoded.
        - `enc_learn`: trained encoder
        - `stride`: stride used for the training. Neccesary for adjusting the encoder input
        - `batch_size`: value to force the dataloader to use (None keeps the current one).
        - `module`: for geting the embeddings of an specific layer.
        - `cpu`: Whether to do the model inference in cpu or gpu (GPU recommended).
          Falls back to CPU when CUDA is not available.
        - `average_seq_dim`: Whether to average 3D embeddings over dimension 2
        - `to_numpy`: Whether to return the result as a numpy array (if false returns a tensor)
        - `verbose`: For printing messages. More big, more messages.
        - `time_flag`: To take note of the execution time required by this function
        - `chunk_size`: number of batches processed between GPU cache releases
          (0 or negative: no chunking; ignored on CPU).
        - `check_memory_usage`: For showing messages of the current state of the memory.
    """
    if time_flag:
        t_start = time.time()
    if verbose > 0:
        ut.print_flush("--> get_enc_embs_MVP_set_stride_set_batch_size", verbose = verbose)
    if check_memory_usage: gpu_memory_status()
    X = X[::stride]
    # Single place where the batch size is decided (None keeps the current dls.bs).
    get_enc_embs_ensure_batch_size_(enc_learn.dls, batch_size, verbose)

    if verbose > 0: ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Check CUDA | X ~ {len(X)}", verbose = verbose)
    if not cpu and not torch.cuda.is_available():
        if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | No cuda available. Set CPU = true")
        cpu = True
    if cpu:
        if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Get enc embs CPU")
        enc_learn.dls.cpu()
        enc_learn.cpu()
    else:
        if verbose > 0: 
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | CUDA device id: {torch.cuda.current_device()}", verbose = verbose)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | CUDA device name: {torch.cuda.get_device_name(torch.cuda.current_device())}", verbose = verbose)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Ensure empty cache & move 2 GPU", verbose = verbose)
        torch.cuda.empty_cache()
        enc_learn.dls.cuda()
        enc_learn.cuda()

    if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Set dataset from X (enc_learn does not contain dls)", verbose = verbose)
    aux_dl = enc_learn.dls.valid.new_dl(X=X)
    aux_dl.bs = enc_learn.dls.bs if (enc_learn.dls.bs or 0) > 0 else 64
    if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Get module", verbose = verbose)
    module = nested_attr(enc_learn.model,ENCODER_EMBS_MODULE_NAME[type(enc_learn.model)]) if module is None else module
    
    if verbose > 0: 
        ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | aux_dl len {len(aux_dl)}", verbose = verbose)
        ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | aux_dl.batch_len {len(next(iter(aux_dl)))}", verbose = verbose)
        ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | aux_dl.bs {aux_dl.bs}", verbose = verbose)
        if not cpu:
            device_id = torch.cuda.current_device()
            total = torch.cuda.get_device_properties(device_id).total_memory
            used = torch.cuda.memory_allocated(device_id)
            reserved = torch.cuda.memory_reserved(device_id)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | total_mem {total}", verbose = verbose)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | used_mem {used}", verbose = verbose)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | reserved_mem {reserved}" ,verbose = verbose)
            ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | available_mem {total-reserved}", verbose = verbose)
            sys.stdout.flush()

    def _embed(xb):
        # Activations of `module` for one batch.
        return get_acts_and_grads(
            model   = enc_learn.model,
            modules = module,
            x       = xb[0],
            cpu     = cpu
        )[0]

    chunked = (not cpu) and chunk_size > 0
    if not chunked:
        embs = [_embed(xb) for xb in aux_dl]
        if not cpu: embs = [emb.cpu() for emb in embs]
    else:
        # Single pass over the dataloader: every batch is processed exactly once and
        # its embeddings are moved to the CPU; the CUDA cache is released after every
        # `chunk_size` batches.
        embs = []
        n_batches = len(aux_dl)
        total_chunks = max(1, math.ceil(n_batches / chunk_size))
        if verbose > 0: ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | aux_dl len | {str(n_batches)}  chunk size: {str(chunk_size) } => { str(total_chunks) }  chunks", verbose = verbose)
        for n, xb in enumerate(aux_dl):
            if verbose > 0 and n % chunk_size == 0:
                i = n // chunk_size
                ut.print_flush(f"get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | Chunk [ {str(i)}/{str(total_chunks)}] => {str(round(i*100/total_chunks))}%", verbose = verbose)
                sys.stdout.flush()
            embs.append(_embed(xb).cpu())
            if (n + 1) % chunk_size == 0:
                torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        if verbose > 0: 
            ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Get acts and grads | 100%", verbose = verbose)
            sys.stdout.flush()
    
    if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | concat embeddings", verbose = verbose)
    embs = to_concat(embs)
    
    if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Reduce", verbose = verbose)
    if embs.ndim == 3 and average_seq_dim: embs = embs.mean(axis=2)
    
    if verbose > 0: ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size | Convert to numpy", verbose = verbose)
    if to_numpy: 
        # Both inference paths already moved the embeddings to the CPU.
        embs = embs.cpu().numpy()
        if not cpu: torch.cuda.empty_cache()
    if time_flag:
        t = time.time()-t_start
        if verbose > 0:
            ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size " + str(t) + " seconds -->", verbose = verbose)
        else:
            ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size " + str(t) + " seconds", verbose = verbose)
    if check_memory_usage: gpu_memory_status()
    if verbose > 0: 
        ut.print_flush("get_enc_embs_MVP_set_stride_set_batch_size -->", verbose = verbose)
    return embs

In [22]:
#| export
def get_enc_embs_moment(
    X               : List [ List [ List [ float ] ] ], 
    enc_learn       : "Learner", 
    cpu             : bool = False, 
    to_numpy        : bool = True,
    verbose         : int  = 0,
    average_seq_dim : bool = True
):
    """
    Get the embeddings of X from a MOMENT model whose task is "embedding".

    Input
    - `X`: encoder input. Any array-like accepted by `torch.as_tensor` (numpy array,
      nested lists or tensor); it is converted to float32.
    - `enc_learn`: MOMENT model, called as `enc_learn(x_enc = ...)`; its output must
      expose an `embeddings` attribute.
    - `cpu`: run on CPU. Also used when CUDA is not available.
    - `to_numpy`: return a numpy array instead of a (CPU) tensor.
    - `verbose`: print progress messages when > 0.
    - `average_seq_dim`: average `embeddings` over dimension 1.
    Output
    - the embeddings on the CPU (numpy array if `to_numpy`, tensor otherwise).
    """
    if verbose > 0: 
        ut.print_flush("--> get_enc_embs_moment", verbose = verbose)
    cpu = cpu or not torch.cuda.is_available()
    device = "cpu" if cpu else "cuda"
    if cpu:
        if verbose > 0: 
            ut.print_flush("get_enc_embs_moment | Using CPU (maybe no cuda available)", verbose = verbose)
        enc_learn.cpu()
    else:
        if verbose > 0: 
            ut.print_flush("get_enc_embs_moment | Using CUDA", verbose = verbose)
        enc_learn.to("cuda")
    if verbose > 0: ut.print_flush("get_enc_embs_moment | Convert y", verbose = verbose)
    enc_learn.eval()
    y = torch.as_tensor(X, dtype = torch.float32, device = device)
    with torch.no_grad():
        if verbose > 0: 
            ut.print_flush("get_enc_embs_moment | Get outputs", verbose = verbose)
        # Pass the input by keyword: it works whether MOMENT's forward declares
        # `x_enc` positional-or-keyword or keyword-only.
        outputs = enc_learn(x_enc = y)
        if verbose > 0:
            ut.print_flush(f"get_enc_embs_moment | Final shape: X ~ {y.shape}", verbose = verbose)

    if verbose > 0: 
        ut.print_flush("get_enc_embs_moment | Get Embeddings", verbose = verbose)
    embeddings = outputs.embeddings.detach().cpu()
    if average_seq_dim: 
        embeddings = embeddings.mean(dim = 1)
    if to_numpy:
        embeddings = embeddings.numpy()
    if verbose > 0: 
        ut.print_flush("get_enc_embs_moment -->", verbose = verbose)
    return embeddings

In [23]:
#| export
def get_enc_embs_moment_reconstruction(
    X               : List [ List [ List [ float ] ] ], 
    enc_learn       : Learner, 
    cpu             : bool          = False, 
    to_numpy        : bool          = True,
    verbose         : int           = 0,
    average_seq_dim : bool          = True,
    padd_step       : int           = 2,
    #- Printing options for debugging
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a',
    continue_if_fail: bool          = False
):
    """
    Get embeddings from a MOMENT model whose task is "reconstruction".

    For reconstruction the mask sometimes gets invalid sizes. To avoid this, the
    last dimension (sequence length) is padded with 0's until the error is skipped
    (done inside `get_acts_moment`). It should only take one iteration, as the error
    seems to come from MOMENT's internal patch configuration.

    Input
    - `X`: encoder input; any array-like accepted by `torch.as_tensor`, converted to float32.
    - `enc_learn`: MOMENT model.
    - `cpu`: run on CPU. Also used when CUDA is not available.
    - `to_numpy`: return a numpy array instead of a tensor.
    - `verbose`: debug printing level.
    - `average_seq_dim`: average the activations over dimensions 1 and 2.
    - `padd_step`: padding step forwarded to `get_acts_moment`.
    - `print_to_path`, `print_path`, `print_mode`: printing options forwarded to `get_acts_moment`.
    - `continue_if_fail`: NOTE(review): accepted but not forwarded to `get_acts_moment`;
      confirm whether that function supports it before wiring it through.
    """
    if not cpu and not torch.cuda.is_available():
        # Same fallback policy as get_enc_embs_moment.
        cpu = True
    device = "cpu" if cpu else "cuda"
    if cpu:
        enc_learn.cpu()
    else:
        enc_learn.to("cuda")
    y = torch.as_tensor(X, dtype = torch.float32, device = device)
    embs = get_acts_moment(
        enc_learn       = enc_learn, 
        cpu             = cpu, 
        verbose         = verbose, 
        y               = y, 
        mask            = None,
        padd_step       = padd_step,
        retry           = False ,
        max_trials      = 5,
        print_to_path   = print_to_path, print_path = print_path, print_mode = print_mode
    )
    if average_seq_dim: 
        # Two successive means over dim 1 = mean over the original dims 1 and 2.
        embs = embs.mean(dim = 1).mean(dim = 1)
    if to_numpy:
        embs = embs.cpu().numpy()
    return embs

---> TODO: find out which module the MOMENT embeddings actually come from, and use get_acts_and_grads as in MVP <---

In [25]:
#| export
import torch.profiler as profiler

In [26]:
#| export
def watch_gpu(func, **kwargs):
    """
    Run `func(**kwargs)` under the PyTorch profiler and print a summary table.

    The whole call is profiled: CPU activity always, and CUDA activity when CUDA is
    available. The trace is written in TensorBoard format to './log_dir', and the
    10 operators with the highest memory usage are printed.

    Parameters: 
    - func: function to monitor
    - kwargs: keyword arguments forwarded to `func`
    Returns:
    - the result of `func(**kwargs)`.
    """
    cuda_available = torch.cuda.is_available()
    activities = [profiler.ProfilerActivity.CPU]
    if cuda_available:
        activities.append(profiler.ProfilerActivity.CUDA)
    # No `schedule`: a schedule only advances when `prof.step()` is called, and this
    # wrapper profiles a single call without stepping, so a schedule starting with a
    # wait phase would never record anything.
    with profiler.profile(
        activities     = activities,
        on_trace_ready = profiler.tensorboard_trace_handler('./log_dir'),  # TensorBoard trace
        record_shapes  = True,   # record tensor shapes
        profile_memory = True,   # record memory allocations
        with_stack     = True    # include stack information
    ) as prof:
        result = func(**kwargs)
    
    sort_key = "cuda_memory_usage" if cuda_available else "cpu_memory_usage"
    ut.print_flush(prof.key_averages().table(sort_by = sort_key, row_limit = 10))
    return result

In [27]:
#| export
def get_enc_embs_moirai(
    enc_input       : List [ List [ List [ float ] ] ], 
    enc_model       : moirai.MoiraiModule, 
    cpu             : bool = False,
    average_seq_dim : bool = True, 
    verbose         : int  = 0,
    to_numpy        : bool = True,
    patch_size      : int  = 8,
    time            : bool = False
):
    """
    Get the embeddings of `enc_input` from a Moirai module.

    The input is rearranged with the pattern
    "n_windows n_vars window_size -> n_windows window_size n_vars", converted to
    Moirai's input format through an auxiliary `MoiraiForecast` wrapper (`_convert`),
    and the activations of `enc_model.encoder.norm` are collected with `get_acts`.

    Input
    - `enc_input`: encoder input, assumed (n_windows, n_vars, window_size) per the rearrange pattern.
    - `enc_model`: Moirai module.
    - `cpu`: run on CPU. Also used when CUDA is not available.
    - `average_seq_dim`: average the activations over dimension 1.
    - `verbose`: printing level.
    - `to_numpy`: return a numpy array instead of a tensor.
    - `patch_size`: patch size used for the conversion and passed to the model.
    - `time`: start a `ut.Time` timer.
    Output
    - the embeddings (numpy array if `to_numpy`, tensor otherwise).
    """
    mssg = ut.Mssg()
    if time: 
        # NOTE(review): the timer is started but never stopped or reported in this function.
        timer = ut.Time(mssg = mssg)
        timer.start()
    if verbose > 0: 
        ut.print_flush("--> get_enc_embs_moirai", verbose = verbose)
    past_target = einops.rearrange(
        torch.as_tensor(enc_input, dtype = torch.float32),
        "n_windows n_vars window_size -> n_windows window_size n_vars"
    )
    # Move tensor and model to the selected device. Tensor.to/.cpu are not in-place,
    # so the result must be reassigned or the input stays on the CPU.
    if cpu or not torch.cuda.is_available():
        if verbose > 0: ut.print_flush("get_enc_embs_moirai | Using CPU (maybe no cuda available)", verbose = verbose)
        cpu = True
        enc_model.cpu()
        past_target = past_target.cpu()
    else:
        if verbose > 0: ut.print_flush("get_enc_embs_moirai | Using CUDA", verbose = verbose)
        enc_model.to("cuda")
        past_target = past_target.to("cuda")
        
    if verbose > 0: ut.print_flush("get_enc_embs_moirai | Get Outputs", verbose = verbose)

    # Every value is marked as observed and no time step as padding.
    past_observed_target = torch.ones_like(past_target, dtype=torch.bool)
    past_is_pad = torch.zeros_like(past_target, dtype=torch.bool)[...,:,-1] # Drop the last (variable) dimension

    if verbose > 1:
        ut.print_flush(f"--> get_enc_embs_moirai | past_target ~ {past_target.shape}")
        ut.print_flush(f"--> get_enc_embs_moirai | past_observed_target ~ {past_observed_target.shape}")
        ut.print_flush(f"--> get_enc_embs_moirai | past_is_pad ~ {past_is_pad.shape}")
        ut.print_flush(f"--> get_enc_embs_moirai | Auxiliar model")
        ut.print_flush(f"--> get_enc_embs_moirai | Auxiliar model | Before Memory:")
        gpu_memory_status()
    
    # Auxiliary forecast wrapper around the same module, used only for its `_convert`
    # helper that builds the model inputs with consistent sizes.
    forecast_model =  moirai_forecast.MoiraiForecast(
        module=enc_model,
        prediction_length=past_target.shape[2], # arbitrary: only needed to build the wrapper
        context_length=past_target.shape[1],
        patch_size=patch_size,
        num_samples=100, # arbitrary: number of forecast samples, not used here
        target_dim=past_target.shape[2],
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
    
    if verbose > 0:
        ut.print_flush(f"--> get_enc_embs_moirai | Auxiliar model | After Memory:")
        gpu_memory_status()
        ut.print_flush(f"--> get_enc_embs_moirai | Convert sizes")
    (
    target,
    observed_mask,
    sample_id,
    time_id,
    variate_id,
    prediction_mask,
    ) = forecast_model._convert(
        patch_size,
        past_target,
        past_observed_target,
        past_is_pad
    )
    if verbose > 1:
        ut.print_flush(f"get_enc_embs_moirai | target ~ {target.shape}")
        ut.print_flush(f"get_enc_embs_moirai | observed_mask ~ {observed_mask.shape}")
        ut.print_flush(f"get_enc_embs_moirai | sample_id ~ {sample_id.shape}")
        ut.print_flush(f"get_enc_embs_moirai | time_id ~ {time_id.shape}")
        ut.print_flush(f"get_enc_embs_moirai | variate_id ~ {variate_id.shape}")
        ut.print_flush(f"get_enc_embs_moirai | prediction_mask ~ {prediction_mask.shape}")
        gpu_memory_status()
    forecast_model = None
    torch.cuda.empty_cache()
    if verbose > 0:
        ut.print_flush(f"--> get_enc_embs_moirai | Delete Auxiliar model | After Memory:")
        gpu_memory_status()
    
    model_kwargs={
        'target': target, 
        'observed_mask': observed_mask,
        'sample_id': sample_id,
        'time_id': time_id,
        'variate_id': variate_id,
        'prediction_mask': prediction_mask,
        'patch_size': torch.ones_like(sample_id, dtype = torch.float32)*patch_size
    } 
    if verbose > 0: 
        ut.print_flush(f"get_enc_embs_moirai | About to get activations")
    embs = get_acts(
        model  = enc_model, 
        module = enc_model.encoder.norm, 
        cpu    = cpu,
        verbose = verbose,
        retry = True,
        acts_indices = [0],
        **model_kwargs # Parameters of the model
    )
    
    if average_seq_dim :
        if verbose > 0: 
            ut.print_flush(f"get_enc_embs_moirai | About to reduce activations", verbose = verbose)
        embs = embs.mean(dim = 1)
    
    if not cpu:
        ut.print_flush(f"get_enc_embs_moirai | enc_model to cpu", verbose = verbose)
        enc_model.cpu()
        ut.print_flush(f"get_enc_embs_moirai | torch cuda empty cache", verbose = verbose)
        torch.cuda.empty_cache()
    if to_numpy: 
        embs = embs.cpu().numpy()
        if not cpu:
            torch.cuda.empty_cache()
    if verbose > 0: 
        ut.print_flush(f"get_enc_embs_moirai | embs ~ {embs.shape}", verbose = verbose)
        ut.print_flush("get_enc_embs_moirai -->", verbose = verbose)
    return embs

In [28]:
#| export 
def get_enc_embs(
    X               , 
    enc_learn       : Learner, 
    module          : str  = None, 
    cpu             : bool = False, 
    average_seq_dim : bool = True, 
    to_numpy        : bool = True,
    verbose         : int  = 0,
    **kwargs        
):
    embs = None
    enc_learn_class = str(enc_learn.__class__)[8:-2]
    match enc_learn_class:
        case "momentfm.models.moment.MOMENTPipeline":
            match enc_learn.task_name:
                case "embedding":
                    embs = get_enc_embs_moment(X, enc_learn, cpu, to_numpy, verbose, average_seq_dim, **kwargs)
                case "reconstruction":
                    embs = get_enc_embs_moment_reconstruction(X, enc_learn, cpu, to_numpy, verbose, average_seq_dim, **kwargs)
                case _:
                    ut.print_flush(f"Model embeddings for moment-{enc_learn.task_name} is not yet implemented.", verbose = verbose)
        case "fastai.learner.Learner":
            embs = get_enc_embs_MVP_set_stride_set_batch_size(X, enc_learn, stride, batch_size, module, cpu, average_seq_dim, to_numpy, verbose, False, 0, False)
        case "uni2ts.model.moirai.module.MoiraiModule":
            embs = get_enc_embs_moirai(
                enc_input  = X, 
                enc_model  = enc_learn,
                cpu        = cpu, 
                average_seq_dim = average_seq_dim,
                verbose    = verbose,
                **kwargs
            )
        case _:
            ut.print_flush(f"Model embeddings implementation is not yet implemented for {enc_learn_class}.", verbose = verbose)
    return embs

In [29]:
#| export
def get_enc_embs_set_stride_set_batch_size(
    X                  : List [ List [ List [ float ] ] ], 
    enc_learn          : Learner, 
    stride             : int, 
    batch_size         : int, 
    module             : str  = None, 
    cpu                : bool = False, 
    average_seq_dim    : bool = True, 
    to_numpy           : bool = True, 
    verbose            : int  = 0, 
    time_flag          : bool = False, 
    chunk_size         : int  = 0, 
    check_memory_usage : bool = False,
    **kwargs
):
    ut.print_flush("--> get_enc_embs_set_stride_set_batch_size", verbose = verbose)
    embs = None
    enc_learn_class = str(enc_learn.__class__)[8:-2]
    match enc_learn_class:
        case "momentfm.models.moment.MOMENTPipeline":
            if verbose > 0: 
                ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | Moment | {average_seq_dim}", verbose = verbose)
            match enc_learn.task_name:
                case "embedding":
                    embs = get_enc_embs_moment( X = X, enc_learn = enc_learn, cpu = cpu, to_numpy = to_numpy, verbose = verbose, average_seq_dim = average_seq_dim)
                case "reconstruction":
                    embs = get_enc_embs_moment_reconstruction(X= X, enc_learn = enc_learn, cpu = cpu, to_numpy = to_numpy, verbose = verbose, average_seq_dim = average_seq_dim, **kwargs)
                case _:
                    ut.print_flush(f"Model embeddings for moment-{enc_learn.task_name} is not yet implemented.", verbose = verbose)
        case "fastai.learner.Learner":
            if verbose > 0: 
                ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | MVP | {average_seq_dim}", verbose = verbose)
            if verbose > 1:
                ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | X ~{X.shape}", verbose = verbose)
            embs = get_enc_embs_MVP_set_stride_set_batch_size(
                X = X, 
                enc_learn = enc_learn, 
                stride = stride, 
                batch_size = batch_size, 
                module = module, 
                cpu = cpu, 
                average_seq_dim = average_seq_dim,
                to_numpy = to_numpy, 
                verbose = verbose, 
                time_flag = time_flag, 
                chunk_size = chunk_size, 
                check_memory_usage = check_memory_usage
            )
        case "uni2ts.model.moirai.module.MoiraiModule":
            if verbose > 0: 
                ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | Moirai | {average_seq_dim}", verbose = verbose)
            embs = get_enc_embs_moirai(
                enc_input  = X, 
                enc_model  = enc_learn,
                cpu        = cpu, 
                average_seq_dim = average_seq_dim,
                verbose    = verbose,
                to_numpy = to_numpy,
                **kwargs
            )
        case _:
            ut.print_flush(f"[ get_enc_embs_set_stride_set_batch_size ] Model embeddings implementation is not yet implemented for {enc_learn_class}.", verbose = verbose)
    # Ñapa: TODO: Gestionar que no se queden en memoria los modelos porque ocupan el 40% de la GPU al llamarlos desde R
    if verbose > 0: ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | Before moving to CPU | embs~{embs.shape}", verbose = verbose)
    if cpu:
        #X.cpu()
        enc_learn.cpu()
        try: 
            enc_lear.dls.cpu()
        except Exception as e: 
            ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | Exception: {e}", verbose = verbose)
        #kwargs_to_cpu_(**kwargs)
    if verbose > 0: ut.print_flush(f"get_enc_embs_set_stride_set_batch_size | embs~{embs.shape} -->", verbose = verbose)
    return embs

## Fine-tuning
> Take a look at [HuggingFace - Fine-tune a pretrained model](https://huggingface.co/docs/transformers/training) if you are not familiar with few-shot learning or fine-tuning models.

Steps: 

1) Prepare the dataset
2) Batch the data
   - Remember to split the data into train & test datasets
   - Remember to use a DataLoader to iterate over batches
3) Load the trained model and check whether any modification is needed
   - Check whether any layer may be replaced by an "identity" layer if it is not needed for your case
   - Check whether any dimension in a conversion layer may be changed to fit your dataset.
4) Select an optimizer from torch.optim (e.g. Adam)
5) If using a transformer, consider adding a learning-rate scheduler (lr_scheduler)
6) Training loop
     

### Utils

In [30]:
#| export
from tqdm.auto import tqdm
from transformers import get_scheduler
import evaluate
from torch.nn.modules.loss import _Loss
from tsai.data.preparation import SlidingWindow
from dvats.utils import find_dominant_window_sizes_list

In [31]:
#| export
def random_windows(
    X           : List [ List [ List [ float ]]], 
    n_windows   : int       = None, 
    percent     : float     = None, 
    mssg        : ut.Mssg   = ut.Mssg()
):
    """
    Randomly select a subset of windows and return it as a torch tensor.

    Parameters:
    - X: numpy array of windows. Expected shape: [batch_size or n_samples, n_vars, window_len].
    - n_windows: number of windows to select, capped at len(X).
    - percent: used only when `n_windows` is None; selects ceil(percent * len(X)) windows,
      capped at len(X).
    - mssg: logging helper. A deep copy is used, so the caller's object is not modified.

    Returns:
    - torch tensor with the selected windows. If both `n_windows` and `percent` are None,
      all windows are returned in their original order.

    Windows are sampled without replacement, so no window is selected twice.
    """
    mssg_ = deepcopy(mssg)
    mssg_.initial(func_name=f"{mssg.function} | {ut.funcname()}")
    mssg_.print(f"N windows: {n_windows}")
    if n_windows is None and percent is None:
        windows = torch.from_numpy(X)
    else: 
        n_total = int(X.shape[0])
        if n_windows is not None:
            n_windows = int(min(n_total, n_windows))
        else:
            n_windows = int(min(n_total, np.ceil(percent * n_total)))
        mssg_.print(f"n_windows: {n_windows}")
        # Without replacement: every selected window is distinct
        random_indices = np.random.choice(n_total, size = n_windows, replace = False)
        windows = torch.from_numpy(X[ random_indices ])
    mssg_.print(f"windows~{windows.shape}")
    mssg_.final()
    return windows

In [None]:
#| export
def windowed_dataset(
    X                               : Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ],
    stride                          : int           = 1,
    window_sizes                    : List [int]    = None,
    n_window_sizes                  : int           = 1,
    window_sizes_offset             : float         = 0.05,
    windows_min_distance            : int           = 1,
    full_dataset                    : bool          = False,
    mssg                            : ut.Mssg       = ut.Mssg()
): 
    """
    Build sliding-window datasets from `X`, one per window size.

    - A list `X` is first converted to a DataFrame.
    - If `X` is a DataFrame (or `full_dataset` is True), it is windowed with tsai's
      `SlidingWindow` for every window size. When `window_sizes` is None, or has fewer than
      `n_window_sizes` entries, it is completed with Fourier-dominant sizes from
      `find_dominant_window_sizes_list`. User-provided sizes are kept first.
    - Otherwise `X` is assumed to be already windowed and is returned as the single dataset.

    Parameters:
    - X: raw series (list / DataFrame) or an already-windowed array.
    - stride: sliding-window stride (None -> 1).
    - window_sizes: explicit window sizes to use.
    - n_window_sizes: number of window sizes wanted (None -> 1).
    - window_sizes_offset, windows_min_distance: options for `find_dominant_window_sizes_list`.
    - full_dataset: force windowing even if `X` is not a DataFrame.
    - mssg: logging helper.

    Returns:
    - List of windowed arrays (one per window size), or `[X]` if `X` was already windowed.
    """
    stride = 1 if stride is None else stride 
    n_window_sizes = 1 if n_window_sizes is None else n_window_sizes
    window_sizes_offset = 0.05 if window_sizes_offset is None else window_sizes_offset
    windows_min_distance = 1 if windows_min_distance is None else windows_min_distance
    full_dataset = False if full_dataset is None else full_dataset
    mssg = ut.Mssg() if mssg is None else mssg
    mssg.initial(ut.funcname())
    dss = []
    if isinstance(X, list):
        mssg.print("X is a list. Converting to dataFrame")
        X = np.array(X)
        X = pd.DataFrame(X)        
    if ( isinstance(X,pd.DataFrame) or full_dataset): 
        mssg.print(f"X is a DataFrame, X~{X.shape} | window_sizes {len(window_sizes) if window_sizes is not None else 0}, n_window_sizes {n_window_sizes}")
        if window_sizes is None or n_window_sizes > len(window_sizes):
            mssg.print("X is a DataFrame | Selecting Fourier's dominant frequences")
            # Select Fourier's dominant frequencies
            window_sizes_ = find_dominant_window_sizes_list(
                X               = X, 
                nsizes          = n_window_sizes, 
                offset          = window_sizes_offset, 
                min_distance    = windows_min_distance,
                mssg            = mssg
            )
            if window_sizes is None:
                window_sizes = window_sizes_
            else:
                # Order-preserving de-duplication: explicitly requested sizes come first and are never dropped
                window_sizes = list(dict.fromkeys(list(window_sizes) + list(window_sizes_)))[:n_window_sizes]
            mssg.print(f"X is a DataFrame | Window sizes: {len(window_sizes)}", func_name = ut.funcname())
        mssg.print(f"Building the windows")
        for w in window_sizes:
            mssg.print(f"w = {w}", verbose_level = mssg.level+1)
            enc_input, _ = SlidingWindow(window_len = w, stride = stride, get_y=[])(X)
            dss.append(enc_input)
            mssg.print(f"w {w} | enc_input~{enc_input.shape} | dss~{len(dss)}",  verbose_level = mssg.level+1)
    else: 
        mssg.print("X is already windowed")
        dss = [X]
    mssg.print(f"Number of windows: {len(dss)}")
    mssg.final()
    return dss


In [33]:
#| export
def setup_scheduler(
    dl_train                        : DataLoader,
    lr_scheduler_flag               : bool= False,
    lr_scheduler_name               : str = "",
    optimizer                             = None,
    num_epochs                      : int = 10,
    lr_scheduler_num_warmup_steps   : int = None,
    num_training_steps              : int = None,
    lr_scheduler_perc_warmup_steps  : int = 0.02,
    lr_scheduler_max_lr             : float = None,
    lr                              : float = 1e-4
):
    num_training_steps = num_epochs * len(dl_train) if num_training_steps is None else num_training_steps
    lr_scheduler_num_warmup_steps = lr_scheduler_perc_warmup_steps*num_training_steps
    lr_scheduler_max_lr = 5 - 10 * lr if lr_scheduler_max_lr is None else lr_scheduler_max_lr
    if lr_scheduler_flag:
        match lr_scheduler_name:
            case "OneCycleLR": 
                lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                    optimizer           = optimizer,
                    max_lr              = lr_scheduler_max_lr,
                    epochs              = num_epochs,
                    steps_per_epoch     = len(dl_train)
                )
            case _:
                lr_scheduler = get_scheduler(
                    name                = lr_scheduler_name,
                    optimizer           = optimizer,
                    num_warmup_steps    = lr_scheduler_num_warmup_steps,
                    num_training_steps  = num_training_steps
                )
    return lr_scheduler

In [34]:
#| export
def prepare_train_and_eval_dataloaders(
    X                   : Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ],
    batch_size          : int,
    n_windows           : int       = None,
    n_windows_percent   : float     = None,
    training_percent    : float     = 0.4,
    validation_percent  : float     = 0.3,
    shot                : bool      = False,
    eval_pre            : bool      = False,
    eval_post           : bool      = False,
    mssg                : ut.Mssg   = ut.Mssg()
):
    """
    Build the few-shot training and evaluation DataLoaders from a windowed array.

    Parameters:
    - X: numpy array of windows (converted with `torch.from_numpy`).
    - batch_size: batch size of both DataLoaders.
    - n_windows / n_windows_percent: when given, the split sizes are relative to this number
      (or fraction) of windows instead of `len(X)`. They are also forwarded to `random_windows`.
    - training_percent: fraction used to compute the training split size.
    - validation_percent: fraction used to compute the evaluation split size.
    - shot: build the training DataLoader (random windows, shuffled).
    - eval_pre / eval_post: build the evaluation DataLoader (not shuffled).
    - mssg: logging helper. A deep copy is used, so the caller's object is not modified.

    Returns:
    - (dl_eval, dl_train, ds_train). Each is None when not requested. `ds_train` is a float
      tensor of the sampled training windows.
    """
    dl_eval  = None
    ds_train = None
    dl_train = None
    ds_test  = None
    # Work on a copy so neither the caller's Mssg nor the shared default instance
    # accumulates " | prepare_train_and_eval_dataloaders" suffixes across calls.
    mssg = deepcopy(mssg)
    mssg.function = f"{mssg.function} | prepare_train_and_eval_dataloaders"
    if n_windows is None and n_windows_percent is None:
        train_split_index = min(X.shape[0], np.ceil(training_percent * X.shape[0]))
        eval_split_index = min(X.shape[0], np.ceil(validation_percent * X.shape[0]))
    else:
        train_split_index = min(X.shape[0], np.ceil(training_percent * n_windows)) if n_windows is not None else np.ceil(training_percent * n_windows_percent * X.shape[0])
        eval_split_index = min(X.shape[0], np.ceil(validation_percent * n_windows)) if n_windows is not None else np.ceil(validation_percent * n_windows_percent * X.shape[0])
    
    train_split_index = int(train_split_index)
    eval_split_index = int(eval_split_index)
    # NOTE(review): both splits start at index 0, so evaluation windows overlap the
    # training candidates — confirm this is intended.
    if shot: 
        mssg.print(f"Selecting ds train | {train_split_index} windows")
        ds_train = X[:train_split_index]
    if eval_pre or eval_post: 
        mssg.print(f"Selecting validation train | {eval_split_index} windows")
        ds_test  = torch.from_numpy(X[:eval_split_index]).float()
    # -- Select only the small percentage for few-shot
    if shot:
        mssg.print(f"Train DataLoader | Random windows")
        mssg.verbose -= 1
        try:
            ds_train = random_windows(ds_train, n_windows, n_windows_percent, mssg = mssg)
        finally:
            # Restore verbosity even if sampling fails
            mssg.verbose += 1
        ds_train = ds_train.float()
        # Create the dataloader
        mssg.print(f"Train DataLoader | DataLoader")
        dl_train = DataLoader(ds_train, batch_size = batch_size, shuffle = True)
    if eval_pre or eval_post: 
        mssg.print(f"Validation DataLoader")
        dl_eval  = DataLoader(ds_test, batch_size = batch_size, shuffle = False)
    return dl_eval, dl_train, ds_train

### Moment
> Follow the tutorial in the original repository: [Moment - Imputation](https://github.com/moment-timeseries-foundation-model/moment/blob/main/tutorials/imputation.ipynb).

In [35]:
#| export
from momentfm.utils.masking import Masking

In [36]:
#| export
def fine_tune_moment_compute_loss_check_sizes_(
    batch           : List [ List [ List [ float ] ] ], 
    output, 
    verbose         : int   = 0,
    # Print options
    print_to_path   : bool  = False,
    print_path      : str   = "~/data/logs/logs.txt",
    print_mode      : str   = 'a'
):
    """
    Align the last axis of `batch` and `output.reconstruction` to the shorter of the two.

    - If the batch is longer, a clone of the batch truncated to the reconstruction length is returned.
    - If the reconstruction is longer, `output.reconstruction` is truncated in place (the `output`
      object is modified) and an untruncated clone of the batch is returned.

    Parameters:
    - batch: tensor of original windows (its axis 2 is compared).
    - output: model output exposing a `reconstruction` tensor attribute.
    - verbose, print_to_path, print_path, print_mode: logging options for `ut.print_flush`.

    Returns:
    - A clone of `batch`, possibly truncated along the last axis.
    """
    log_opts = dict(print_to_path = print_to_path, print_path = print_path, verbose = verbose, print_time = print_to_path)
    if verbose > 0:
        ut.print_flush("--> fine_tune_moment_compute_loss_check_sizes_", print_mode = print_mode, **log_opts)
    aligned = batch.clone()
    batch_len = batch.shape[2]
    recon_len = output.reconstruction.shape[2]
    if batch_len == recon_len:
        if verbose > 1:
            ut.print_flush(f" Fine tune loop | re_2 {recon_len} == {batch_len} y_2", print_mode = 'a', **log_opts)
    elif batch_len > recon_len:
        if verbose > 0:
            ut.print_flush(f" Fine tune loop | TODO: Why? Original {batch_len} > {recon_len}  Reconstruction", print_mode = 'a', **log_opts)
        aligned = aligned[..., :recon_len]
    else:
        if verbose > 1:
            ut.print_flush(f" Fine tune loop | Why ? Original {batch_len} < {recon_len} Reconstruction ? Padding", print_mode = 'a', **log_opts)
        # Callers read output.reconstruction afterwards, so the truncation is done on the object itself
        output.reconstruction = output.reconstruction[..., :batch_len]
    if verbose > 1:
        ut.print_flush(f"---------- Checking loss  ------- | reconstruction ~ {output.reconstruction.shape} | original_ ~ {aligned.shape}", print_mode = 'a', **log_opts)
    if verbose > 0:
        ut.print_flush("fine_tune_moment_compute_loss_check_sizes_ -->", print_mode = 'a', **log_opts)
    return aligned

In [37]:
#| export
def fine_tune_moment_compute_loss(
    batch, 
    output, 
    criterion   = torch.nn.MSELoss, 
    verbose     = 0, 
    input_mask  = None, 
    mask        = None,
    # Print options
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a'
):
    """
    Compute the masked reconstruction loss of a MOMENT forward pass.

    The batch and `output.reconstruction` are first aligned along the last axis by
    `fine_tune_moment_compute_loss_check_sizes_`. The loss is then
    `nansum(observed_mask * recon_loss) / (nansum(observed_mask) + 1e-7)`, with
    `observed_mask = input_mask * (1 - mask)`.

    Parameters:
    - batch: original windows (tensor).
    - output: model output with `reconstruction`, `input_mask` and `pretrain_mask` attributes.
    - criterion: loss *class*; it is instantiated here with no arguments.
    - verbose, print_to_path, print_path, print_mode: logging options.
    - input_mask: overrides `output.input_mask` when given.
    - mask: overrides `output.pretrain_mask` when given.

    Returns:
    - The loss as a 0-dim tensor.

    NOTE(review): with the default `torch.nn.MSELoss` (reduction='mean'), `recon_loss` is already
    a scalar, so the mask weighting reduces to roughly `recon_loss` over all positions. Per-position
    masking would need a criterion with reduction='none' whose output broadcasts against the
    masks. Confirm which is intended.
    """
    if verbose > 0: ut.print_flush("--> fine_tune_moment_compute_loss", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)
    b = fine_tune_moment_compute_loss_check_sizes_(batch = batch, output = output, verbose = verbose, print_to_path = print_to_path, print_path = print_path, print_mode = 'a')
    if verbose > 0: ut.print_flush(f"fine_tune_moment_compute_loss | b~{b.shape} | o~{output.reconstruction.shape}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    o = output.reconstruction
    # Prefer the batch's device unless it is the CPU. Compare `device.type`, because comparing
    # a torch.device object directly with the string "cpu" is not reliable.
    device = b.device if b.device.type != "cpu" else o.device
    b = b.to(device)
    o = o.to(device)
    compute_loss = criterion()
    recon_loss = compute_loss(o, b)
    batch_masks = output.input_mask if input_mask is None else input_mask
    mask = output.pretrain_mask if mask is None else mask
    batch_masks = batch_masks.to(device)
    mask = mask.to(device)
    if verbose > 1: ut.print_flush(f"fine_tune_moment_compute_loss | batch ~ {b.shape} | {b.device}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    if verbose > 1: ut.print_flush(f"fine_tune_moment_compute_loss | batch_masks ~ {batch_masks.shape} | {batch_masks.device}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    if verbose > 1: ut.print_flush(f"fine_tune_moment_compute_loss | mask ~ {mask.shape} | {mask.device}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    
    # Weight the loss by the positions that are observed in the input and hidden by the pretraining mask
    observed_mask = batch_masks * (1-mask)
    masked_loss = observed_mask * recon_loss
    # The 1e-7 guards against division by zero when nothing is selected
    loss = masked_loss.nansum() / (observed_mask.nansum() + 1e-7)
    if verbose > 2: ut.print_flush(f"Loss type: {type(loss)}",print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)  # Should be <class 'torch.Tensor'>
    if verbose > 1: ut.print_flush(f"fine_tune_moment_compute_loss | loss: {loss.item()}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    if verbose > 0: ut.print_flush("fine_tune_moment_compute_loss -->", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)
    return loss

In [38]:
#| export
def fine_tune_moment_eval_preprocess(
    predictions : List [ List [ float ]],
    references : List [ List [ float ]],
    verbose : int = 0,
    # Print options
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a'
):
    """
    Prepare predictions and references for element-wise metric computation.

    Parameters:
    - predictions: float tensor shaped (b, v, w).
    - references: float tensor shaped (b, v, w).
    - verbose, print_to_path, print_path, print_mode: logging options.

    Returns:
    - (predictions, references). Both are flattened to ((b*v), w) and cut to the smaller of the
      two window lengths. Every position where either tensor is NaN is set to 0.0 in both.
    """
    def _log(text, mode = 'a'):
        ut.print_flush(text, print_to_path = print_to_path, print_path = print_path, print_mode = mode, verbose = verbose, print_time = print_to_path)

    if verbose > 0:
        _log(f"fine_tune_moment_eval | Before reshape | preds~{predictions.shape}", mode = print_mode)
        _log(f"fine_tune_moment_eval | Before reshape | refs~{references.shape}")
    preds = einops.rearrange(predictions, "b v w -> (b v) w")
    refs = einops.rearrange(references, "b v w -> (b v) w")
    # Cut both to the common (shorter) window length so they align element-wise
    width = min(preds.shape[1], refs.shape[1])
    preds, refs = preds[:, :width], refs[:, :width]
    if verbose > 0:
        _log(f"Eval | After reshape | preds~{preds.shape}")
        _log(f"Eval | After reshape | refs~{refs.shape}")

    # A NaN in either tensor zeroes that position in both, keeping them comparable
    has_nan = torch.isnan(preds) | torch.isnan(refs)
    zero = torch.tensor(0.0)
    preds = torch.where(has_nan, zero, preds)
    refs = torch.where(has_nan, zero, refs)
    if verbose > 0:
        _log(f"Eval | After NaN | preds~{preds.shape}")
        _log(f"Eval | After NaN | refs~{refs.shape}")
    return preds, refs

In [39]:
#| export
def fine_tune_moment_eval_step_(
    enc_learn : Learner,
    batch,
    mse_metric, 
    rmse_metric,
    mae_metric,
    smape_metric,
    cpu             : bool = False,
    verbose         : int = 0,
    # Print options
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a'
):
    """
    Evaluate one batch with a MOMENT model and add it to the given metrics.

    The model is run through `sure_eval_moment` without gradients. Its reconstruction and the
    batch are aligned by `fine_tune_moment_eval_preprocess`, then passed to `add_batch` of every
    metric.

    Parameters:
    - enc_learn: MOMENT model.
    - batch: tensor of windows to reconstruct (also used as the reference).
    - mse_metric, rmse_metric, mae_metric, smape_metric: `evaluate` metric objects.
    - cpu: run on CPU instead of the current CUDA device.
    - verbose, print_to_path, print_path, print_mode: logging options.

    Returns:
    - (mse_metric, rmse_metric, mae_metric, smape_metric) with this batch added.
    """
    # Same device selection as the other fine-tune helpers
    device = "cpu" if cpu else torch.cuda.current_device()
    with torch.no_grad():
        output, enc_learn = sure_eval_moment(
            enc_learn = enc_learn, 
            cpu = cpu,
            verbose = verbose,                     
            y = batch, 
            input_mask = None,
            mask = None,
            padd_step = 100, 
            max_trials = 5, 
            acts_indices = None,
            print_to_path = print_to_path, print_path = print_path, print_mode = print_mode
        )
        predictions = output.reconstruction.to(device)
        references = batch.to(device)
        predictions, references = fine_tune_moment_eval_preprocess(predictions = predictions, references = references, verbose = verbose, print_to_path = print_to_path, print_path = print_path, print_mode = print_mode)
        for metric in (mse_metric, rmse_metric, mae_metric, smape_metric):
            metric.add_batch(predictions = predictions, references = references)
        return mse_metric, rmse_metric, mae_metric, smape_metric

In [40]:
#| export
def fine_tune_moment_eval_(
    enc_learn : Learner,
    dl_eval   : DataLoader,
    num_epochs: int = 1,
    cpu       : bool = False,
    verbose   : int = 0,
    # Print options
    print_to_path   : bool          = False,
    print_path      : str           = "~/data/logs/logs.txt",
    print_mode      : str           = 'a'
):
    """
    Evaluate a MOMENT model on `dl_eval` with MSE, RMSE, MAE and sMAPE.

    The model is put in eval mode for the loop and switched back to training mode at the end.

    Parameters:
    - enc_learn: MOMENT model.
    - dl_eval: DataLoader yielding batches of windows.
    - num_epochs: unused; kept for backward compatibility.
    - cpu: run on CPU instead of the current CUDA device.
    - verbose, print_to_path, print_path: logging options, forwarded to each evaluation step.
    - print_mode: accepted for interface compatibility. Steps always append ('a') so the log
      file is not truncated once per batch.

    Returns:
    - dict with keys "mse", "rmse", "mae", "smape". Each value is the result of the corresponding
      `evaluate` metric's `compute()`.
    """
    device = "cpu" if cpu else torch.cuda.current_device()
    # RMSE is computed with the "mse" metric and squared=False
    mse_metric = evaluate.load('mse', "multilist")
    rmse_metric = evaluate.load('mse', "multilist")
    mae_metric = evaluate.load('mae', "multilist")
    smape_metric = evaluate.load("smape", "multilist")
    enc_learn = enc_learn.to(device)
    enc_learn.eval()
    with tqdm(total = len(dl_eval)) as progress_bar:
        for batch in dl_eval:
            batch = batch.to(device)
            mse_metric, rmse_metric, mae_metric, smape_metric = fine_tune_moment_eval_step_(
                enc_learn = enc_learn, 
                batch = batch, 
                mse_metric = mse_metric, 
                rmse_metric = rmse_metric,
                mae_metric = mae_metric,
                smape_metric = smape_metric,
                cpu = cpu,
                verbose = verbose,
                print_to_path = print_to_path,
                print_path = print_path,
                print_mode = 'a'
            )
            progress_bar.update(1)
    # evaluate's "mse" metric: squared=True -> MSE, squared=False -> RMSE
    eval_results = {
        "mse": mse_metric.compute(squared = True),
        "rmse": rmse_metric.compute(squared = False),
        "mae": mae_metric.compute(),
        "smape": smape_metric.compute()
    }
    enc_learn.train()
    return eval_results

In [41]:
#| export
def fine_tune_moment_train_loop_step_(
    enc_learn,
    batch, 
    batch_masks,
    criterion                               = torch.nn.MSELoss, 
    window_mask_percent             : float = 0.3,
    cpu                             : bool  = False,
    verbose                         : int   = 0,
    print_to_path                   : bool  = False,
    print_path                      : str   = "~/data/logs/logs.txt",
    print_mode                      : str   = 'a',
    use_moment_masks                : bool  = False,
    mask_stateful                   : bool  = False,
    mask_future                     : bool  = False,
    mask_sync                       : bool  = False
): 
    """
    Run one fine-tuning forward pass of a MOMENT model on `batch` and compute its loss.

    Steps:
    1. Trim `batch_masks` to the batch size and move tensors to the target device.
    2. Build the `mask` passed to the model. Uses MOMENT's `Masking` generator when
       `use_moment_masks` is set. Otherwise uses `create_future_mask` (if `mask_future`) or
       `create_subsequence_mask` (presumably tsai's MVP helpers — confirm).
    3. Align the mask / batch-mask shapes, run `sure_eval_moment`, and compute the loss with
       `fine_tune_moment_compute_loss`.

    Parameters:
    - enc_learn: MOMENT model being fine-tuned.
    - batch: 3-D tensor of windows; presumably (batch, n_vars, window_len) — TODO confirm.
    - batch_masks: input masks; their rows are trimmed to the batch size.
    - criterion: loss class, instantiated inside `fine_tune_moment_compute_loss`.
    - window_mask_percent: masking ratio passed to the mask generator.
    - cpu: run on CPU instead of the current CUDA device.
    - verbose, print_to_path, print_path, print_mode: logging options.
    - use_moment_masks: use MOMENT's `Masking` instead of the other mask helpers.
    - mask_stateful, mask_future, mask_sync: options for the non-MOMENT mask helpers.

    Returns:
    - (loss, enc_learn). `loss` is 0 when the forward pass failed (output is None).
      `enc_learn` is the model as returned by `sure_eval_moment`.
    """
    device = torch.cuda.current_device() if not cpu else "cpu"
    bms = batch_masks
    if use_moment_masks:
        mask_generator = Masking(mask_ratio = window_mask_percent)
    
    # The last batch of an epoch can be smaller than the precomputed masks
    if batch.shape[0] < batch_masks.shape[0]:  
        bms = batch_masks[:batch.shape[0]]
    if verbose > 1: 
        ut.print_flush(
            f"fine_tune_moment_train_loop_step_ | Fine tune loop | batch ~ {batch.shape} | batch_masks ~ {bms.shape}",
            print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose
        )
    
    batch   = batch.to(device)
    bms     = bms.to(device) 

    if bms.shape[0] > batch.shape[0]: bms = bms[:batch.shape[0]]
    if verbose > 0: 
        ut.print_flush(
            f"fine_tune_moment_train_loop_step_ | Fine tune loop | window_mask_percent {window_mask_percent} | batch ~ {batch.shape}",
            print_to_path=print_to_path, print_path=print_path,
            print_mode = 'a', verbose = verbose
        )
    if use_moment_masks:
        mask = mask_generator.generate_mask(
            x = batch,
            input_mask = bms
        )
    else: 
        # Dummy (batch, window_len) tensor: only its shape is used to build the mask
        o   = torch.zeros(batch.shape[0], batch.shape[2])
        if verbose > 0: 
            ut.print_flush(
                f"fine_tune_moment_train_loop_step_ | Fine tune loop | o ~ {o.shape} | stateful = {mask_stateful} | sync = {mask_sync} | r = {window_mask_percent}",
                print_to_path=print_to_path, print_path=print_path,
                print_mode = 'a', verbose = verbose
            )
        if mask_future:
            mask = create_future_mask(
                o       = o, 
                r       = window_mask_percent, 
                sync    = mask_sync
            )[0,:,:].int() # Only 1 variable / variables are flattened; the masking function adds an extra leading dim
        else:
            mask = create_subsequence_mask(
                o       = o,
                r       = window_mask_percent,
                stateful= mask_stateful,
                sync    = mask_sync
            )[0,:,:].int() # Only 1 variable / variables are flattened; the masking function adds an extra leading dim
        if verbose > 0:
            ut.print_flush(
                f"fine_tune_moment_train_loop_step_ | Fine tune loop | Before shape adjustment | batch ~ {batch.shape} | batch_masks ~ {bms.shape} | mask ~ {mask.shape}",
                print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose
            )

    # Align mask and batch masks: trim rows, right-pad the mask's time axis with zeros
    if mask.shape[0] < bms.shape[0]:  bms = batch_masks[:mask.shape[0]]
    if mask.shape[1]  < batch_masks.shape[1] :
        mask = torch.nn.functional.pad(mask,(0,batch_masks.shape[1]-mask.shape[1]))
    
    batch = batch.to(device)
    mask = mask.to(device)
    bms = bms.to(device)
    # nn.Module.to moves all parameters in place, so no per-parameter move is needed
    enc_learn = enc_learn.to(device)
    if verbose > 1: 
        ut.print_flush(
            f"fine_tune_moment_train_loop_step_ | Fine tune loop | batch ~ {batch.shape} | batch_masks ~ {bms.shape} | mask ~ {mask.shape}",
            print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose
        )
    if verbose > 1: 
        ut.print_flush(
            f"fine_tune_moment_train_loop_step_ | sure_eval_moment | b{batch.device} | m{mask.device} | bm{bms.device}",
            print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose
        )
    output, enc_learn = sure_eval_moment(
        enc_learn = enc_learn, 
        cpu = cpu,
        verbose = verbose, 
        y = batch, 
        input_mask = bms,
        mask = mask,
        padd_step = 100, 
        max_trials = 5, 
        acts_indices = None,
        print_to_path = print_to_path, print_path = print_path, print_mode = 'a',
        continue_if_fail = True
    )
    # Compute output loss
    if output is None:
        ut.print_flush(
            f"fine_tune_moment_train_loop_step_ | Execution failed | Output none ",
            print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose
        )
        loss = 0
    else: 
        loss = fine_tune_moment_compute_loss(batch, output, criterion, verbose = verbose, input_mask = bms, mask = mask, print_to_path = print_to_path, print_path = print_path, print_mode = 'a')
    return loss, enc_learn

In [42]:
#| export
def fine_tune_moment_train_(
    enc_learn                       : Learner, 
    dl_train                        : DataLoader,
    ds_train                        : pd.DataFrame,
    window_mask_percent             : float = 0.3,
    batch_size                      : int   = 1,
    num_epochs                      : int   = 1,
    criterion                               = torch.nn.MSELoss, 
    optimizer                               = None, 
    lr                              : float = 5e-5,  #1 e -4
    lr_scheduler_flag               : bool  = False, 
    lr_scheduler_name               : str   = "linear",
    lr_scheduler_num_warmup_steps   : int   = None,
    cpu                             : bool  = False,
    verbose                         : int   = 0,
    print_to_path                   : bool  = False,
    print_path                      : str   = "~/data/logs/logs.txt",
    print_mode                      : str   = 'a',
    use_moment_masks                : bool  = False,
    mask_stateful                   : bool  = False,
    mask_future                     : bool  = False,
    mask_sync                       : bool  = False
):
    """
    Fine-tune a MOMENT encoder on `dl_train` for `num_epochs` epochs.

    For every batch, `fine_tune_moment_train_loop_step_` runs the forward pass and
    returns the loss. This function then back-propagates it, steps the optimizer
    (and the LR scheduler, when enabled) and finally clears the gradients.

    Args:
        enc_learn: model being fine-tuned; its parameters feed the default AdamW optimizer.
        dl_train: training batches; `len(dl_train)` sizes the training schedule.
        ds_train: training windows; only `.shape` is read, unpacked as
            `(n_samples, n_channels, window_size)`.
        window_mask_percent: forwarded to the step function.
        batch_size: first dimension of the all-ones `batch_masks` tensor.
        num_epochs: number of passes over `dl_train`.
        criterion: NOTE(review): not forwarded to the step function, which therefore
            uses its own default criterion — confirm this is intended.
        optimizer: optimizer to use; `None` builds `torch.optim.AdamW(enc_learn.parameters(), lr)`.
        lr: learning rate for the default optimizer and the scheduler.
        lr_scheduler_flag, lr_scheduler_name, lr_scheduler_num_warmup_steps:
            optional scheduler built with `setup_scheduler`.
        cpu: run on CPU instead of the current CUDA device.
        verbose, print_to_path, print_path, print_mode: logging options for `ut.print_flush`.
        use_moment_masks, mask_stateful, mask_future, mask_sync: forwarded to the step function.

    Returns:
        (losses, enc_learn): `losses` holds exactly one entry per processed batch —
        the loss value, the step function's non-tensor sentinel (`0`) when the forward
        pass failed, or `nan` when backward / optimizer step raised.
    """
    # Select device
    device = "cpu" if cpu else torch.cuda.current_device()
    # Optimizer and learning rate scheduler
    if optimizer is None: 
        optimizer = torch.optim.AdamW(enc_learn.parameters(), lr)
    n_batches = len(dl_train)
    num_training_steps = num_epochs * n_batches
    losses = []
    if lr_scheduler_flag:
        lr_scheduler = setup_scheduler(
            dl_train=dl_train, lr_scheduler_flag=lr_scheduler_flag, lr_scheduler_name=lr_scheduler_name,
            optimizer=optimizer, num_epochs = num_epochs, lr_scheduler_num_warmup_steps = lr_scheduler_num_warmup_steps, 
            num_training_steps= num_training_steps, lr = lr
        )
    else:
        lr_scheduler = None

    # Training loop
    if verbose > 1: ut.print_flush("fine_tune_moment_train_ | Training loop", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)
    # All-ones input mask of shape (batch_size, window_size); NOTE(review): presumably
    # "every time step observed" for MOMENT — confirm against the step function
    n_samples, n_channels, window_size = ds_train.shape
    batch_masks = torch.ones(
        (batch_size, window_size), 
        device = device
    ).long()
    if verbose > 1: ut.print_flush(f"fine_tune_moment_train | Fine tune loop | print_to_path {print_to_path} | batch_masks~{batch_masks}", print_to_path = print_to_path, print_path = print_path, print_mode = 'a', verbose = verbose, print_time = print_to_path)     
    progress_bar = tqdm(range(num_training_steps))
    if verbose > 0: ut.print_flush(f"fine_tune_moment_train | num_epochs {num_epochs} | n_batches {n_batches}", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)
    for epoch in range(num_epochs):
        for i, batch in enumerate(dl_train):
            # Global step index across epochs
            step = epoch * n_batches + i
            if verbose > 0: 
                ut.print_flush(f"fine_tune_moment_train | batch {i} ~ {batch.shape} | epoch {epoch} | train {step} of {num_training_steps} | Before loop step", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path, print_both = False)
            loss, enc_learn = fine_tune_moment_train_loop_step_(
                    enc_learn                       = enc_learn,
                    batch                           = batch,
                    batch_masks                     = batch_masks, 
                    window_mask_percent             = window_mask_percent,
                    cpu                             = cpu,
                    verbose                         = verbose,
                    print_to_path                   = print_to_path,
                    print_path                      = print_path,
                    print_mode                      = 'a',
                    use_moment_masks                = use_moment_masks,
                    mask_stateful                   = mask_stateful,
                    mask_future                     = mask_future,
                    mask_sync                       = mask_sync
                )
            if verbose > 0: ut.print_flush(
                f"fine_tune_moment_train | batch {i} ~ {batch.shape} | epoch {epoch} | train {step} of {num_training_steps} | Loss backward | After loop step ", 
                print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path, print_both = False
            )
            if not torch.is_tensor(loss):
                # The step function returns a plain number when the forward pass failed:
                # there is nothing to back-propagate
                losses.append(loss)
            else:
                try:
                    loss_value = loss.item()
                    loss.backward()
                    optimizer.step()
                    losses.append(loss_value)
                except Exception as e: 
                    ut.print_flush(f"fine_tune_moment_train | batch {i} ~ {batch.shape} | epoch {epoch} | train {step} of {num_training_steps} | Loss backward failed: {e}", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)
                    losses.append(np.nan)
            # Clear gradients only after the optimizer step, so they are actually applied
            # and a failed batch does not leak its gradients into the next one
            optimizer.zero_grad()
            if lr_scheduler_flag: lr_scheduler.step()
            progress_bar.update(1)
    progress_bar.close()
    if verbose > 0:
        ut.print_flush(f"fine_tune_moment_train | -->", print_to_path = print_to_path, print_path = print_path, print_mode = print_mode, verbose = verbose, print_time = print_to_path)
    return losses, enc_learn

In [43]:
#| export
def fine_tune_moment_single_(
    self                : Encoder,
    eval_pre            : bool = False,
    eval_post           : bool = False,
    shot                : bool = True,
    sample_id           : int  = 0,
    use_moment_masks    : bool = False
):
    """
    Fine-tune (and optionally evaluate) `self.model` on one dataset of the encoder input.

    Args:
        eval_pre: run `fine_tune_moment_eval_` before training.
        eval_post: run `fine_tune_moment_eval_` after training and, when verbose,
            show the pre/post evaluation stats.
        shot: run `fine_tune_moment_train_`; training errors are logged, not raised.
        sample_id: index into `self.input.data` of the dataset to use.
        use_moment_masks: forwarded to `fine_tune_moment_train_`.

    Returns:
        (losses, eval_results_pre, eval_results_post, t_shot, t_eval_1, t_eval_2, self.model);
        the times stay 0 unless `self.time_flag` is set.
    """
    self.mssg.initial_("fine_tune_moment_single")
    t_shot              = 0
    t_eval_1            = 0
    t_eval_2            = 0
    losses              = []
    eval_results_pre    = ""
    eval_results_post   = ""
    # Dataset handled by this call; its window length is reported in every log line
    X    = self.input.data[sample_id]
    wlen = X.shape[2]

    if self.time_flag: timer = ut.Time(mssg = self.mssg)
    self.mssg.print(f"fine_tune_moment_single | Prepare the dataset | X ~ {X.shape}")
    # Prepare the dataset
    dl_eval, dl_train, ds_train = prepare_train_and_eval_dataloaders(
        X                   = X, 
        batch_size          = self.input.batch_size, 
        n_windows           = self.input.n_windows, 
        n_windows_percent   = self.input.n_windows_percent,
        training_percent    = self.input.training_percent, 
        validation_percent  = self.input.validation_percent, 
        shot                = shot, 
        eval_pre            = eval_pre, 
        eval_post           = eval_post,
        mssg                = deepcopy(self.mssg)
    )
    if eval_pre:
        self.mssg.print(f"fine_tune_moment_single | Eval Pre | wlen {wlen}")
        if self.time_flag: timer.start()
        eval_results_pre    = fine_tune_moment_eval_(
            enc_learn       = self.model,
            dl_eval         = dl_eval,
            num_epochs      = self.num_epochs,
            cpu             = self.cpu,
            verbose         = self.mssg.verbose-1,
            print_to_path   = self.mssg.to_path, 
            print_path      = self.mssg.path, 
            print_mode      = self.mssg.mode
        )
        if self.time_flag: 
            timer.end()
            t_eval_1 = timer.duration()
            timer.show(verbose = self.mssg.verbose)
    if shot:
        self.mssg.print(f"fine_tune_moment_single | Train | wlen {wlen}")
        # `optim.lr` is either a plain learning rate or an LRScheduler carrying lr + schedule settings
        use_scheduler = isinstance(self.optim.lr, LRScheduler)
        try:
            if self.time_flag: timer.start()
            losses, self.model                  = fine_tune_moment_train_(
                enc_learn                       = self.model,
                dl_train                        = dl_train,
                ds_train                        = ds_train,
                window_mask_percent             = self.input.window_mask_percent,
                batch_size                      = self.input.batch_size,
                num_epochs                      = self.num_epochs,
                criterion                       = self.optim.criterion, 
                optimizer                       = self.optim.optimizer, 
                lr                              = self.optim.lr.lr               if use_scheduler else self.optim.lr, 
                lr_scheduler_flag               = self.optim.lr.flag             if use_scheduler else False, 
                lr_scheduler_name               = self.optim.lr.name             if use_scheduler else False,
                lr_scheduler_num_warmup_steps   = self.optim.lr.num_warmup_steps if use_scheduler else 0,
                cpu                             = self.cpu,
                verbose                         = self.mssg.verbose-1,
                print_to_path                   = self.mssg.to_path, 
                print_path                      = self.mssg.path, 
                print_mode                      = self.mssg.mode,
                use_moment_masks                = use_moment_masks,
                mask_stateful                   = self.mask_stateful,
                mask_future                     = self.mask_future,
                mask_sync                       = self.mask_sync
            )
            if self.time_flag:
                timer.end()
                t_shot = timer.duration()
                timer.show(verbose = self.mssg.verbose)
        except Exception as e:
            # Best effort: a window that cannot be trained is reported and skipped
            self.mssg.print(f"fine_tune_moment_single | Train | Window {wlen} not valid | {e}")
            traceback.print_exc()
    if eval_post:    
        self.mssg.print(f"fine_tune_moment_single | Eval Post | wlen {wlen}")
        if self.time_flag: timer.start()
        eval_results_post = fine_tune_moment_eval_(
            enc_learn       = self.model,
            dl_eval         = dl_eval,
            num_epochs      = self.num_epochs,
            cpu             = self.cpu,
            verbose         = self.mssg.verbose-1,
            print_to_path   = self.mssg.to_path, 
            print_path      = self.mssg.path, 
            print_mode      = 'a'
        )
        if self.time_flag:
            timer.end()
            t_eval_2 = timer.duration()
            if self.mssg.verbose > 0: 
                timer.show()
        # Show the evaluation stats regardless of whether timing is enabled
        if self.mssg.verbose > 0: 
            self.show_eval_stats(
                # Whether pre & post errors were computed or not
                eval_pre        = eval_pre, 
                eval_post       = eval_post, 
                # Results
                eval_stats_pre  = eval_results_pre,
                eval_stats_post = eval_results_post,
                # Function name
                func_name       = ut.funcname()
            )
    self.mssg.final(ut.funcname())
    return losses, eval_results_pre, eval_results_post, t_shot, t_eval_1, t_eval_2, self.model

Encoder.fine_tune_moment_single_ = fine_tune_moment_single_

In [None]:
#| export
def fine_tune_moment_(
        self                : Encoder, 
        eval_pre            : bool = False, 
        eval_post           : bool = False, 
        shot                : bool = False,
        time_flag           : bool = None,
        use_moment_masks    : bool = None
):   
    """
    Fine-tune the MOMENT model on every dataset in `self.input`, one after another.

    `time_flag` / `use_moment_masks` equal to `None` keep the values already stored on
    `self`. `eval_pre` is applied only to the first dataset.

    Raises:
        ValueError: if `self.input.size` is None.

    Returns:
        (lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, self.model)
    """
    self.mssg.initial(ut.funcname())
    self.time_flag = self.time_flag if time_flag is None else time_flag
    self.use_moment_masks = self.use_moment_masks if use_moment_masks is None else use_moment_masks
    # Return values
    lossess             = []
    eval_results_pre    = []
    eval_results_post   = []
    t_shots             = []
    t_shot              = 0
    t_evals             = []
    t_eval              = 0
    if self.input.size is None:
        self.mssg.print(f"Windows: {len(self.input._data)}")
        raise ValueError(f"Invalid number of windows: {self.input.size}")
    self.mssg.print(f"Processing {self.input.size} datasets : {self.input.shape}")
    # Build optimizer
    if self.optim.optimizer is None: 
        self.mssg.print(f"Setting up optimizer as AdamW")
        # `optim.lr` is either an LRScheduler (carrying `.lr`) or a plain learning rate
        lr = self.optim.lr.lr if isinstance(self.optim.lr, LRScheduler) else self.optim.lr
        self.optim.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
    # Compute model for each window in the windowed dataset
    for i in range(self.input.size):
        self.mssg.print(f"Processing wlen {self.input.data[i].shape[2]}")
        ( 
            losses, eval_results_pre_, eval_results_post_, t_shot_, t_eval_1, t_eval_2, self.model
        ) =  self.fine_tune_moment_single_(eval_pre, eval_post, shot, i, self.use_moment_masks)
        lossess.append(losses)
        if (eval_pre): eval_results_pre = eval_results_pre_
        eval_results_post.append(eval_results_post_)
        t_shots.append(t_shot_)
        if eval_pre: t_evals.append(t_eval_1)
        if eval_post: t_evals.append(t_eval_2)
        # Pre-training evaluation only makes sense before the first dataset
        eval_pre = False
    t_shot = sum(t_shots)
    t_eval = sum(t_evals)
    self.mssg.final(ut.funcname())
    return lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, self.model

Encoder.fine_tune_moment_ = fine_tune_moment_

### MVP

In [None]:
#| export
from fastai.metrics import mae
def rmse(preds, targets):
    """Root mean squared error between `preds` and `targets` (mean over all elements)."""
    mean_squared = torch.nn.functional.mse_loss(preds, targets)
    return mean_squared.sqrt()

def smape(preds, targets):
    """
    Symmetric mean absolute percentage error, in percent (range 0-200).

    The denominator `|preds| + |targets|` is clamped to 1e-7, as in `smape_flat` and
    `SMAPELoss`, so elements where both values are 0 contribute 0 instead of NaN.
    """
    denominator = torch.abs(preds) + torch.abs(targets)
    res = 100 * torch.mean(2 * torch.abs(preds - targets) / torch.clamp(denominator, min=1e-7))
    return res

def rmse_flat(preds, targets):
    """
    Computes RMSE while flattening the tensors to ensure compatibility with MSELossFlat.

    `reshape` is used instead of `view` so non-contiguous inputs (e.g. transposed model
    outputs) are accepted.
    """
    preds, targets = preds.reshape(-1), targets.reshape(-1)  # Flatten tensors
    return torch.sqrt(torch.nn.functional.mse_loss(preds, targets))

def smape_flat(preds, targets):
    """
    Computes SMAPE while flattening the tensors to ensure compatibility with MSELossFlat.

    The denominator is clamped to 1e-7 to avoid division by zero, and `reshape` is used
    instead of `view` so non-contiguous inputs are accepted.
    """
    preds, targets = preds.reshape(-1), targets.reshape(-1)  # Flatten tensors
    denominator = (torch.abs(preds) + torch.abs(targets))
    return 100 * torch.mean(2 * torch.abs(preds - targets) / torch.clamp(denominator, min=1e-7))

def mae_flat(preds, targets):
    """
    Computes Mean Absolute Error (MAE) while flattening the tensors to ensure compatibility.

    `reshape` is used instead of `view` so non-contiguous inputs are accepted.
    """
    preds, targets = preds.reshape(-1), targets.reshape(-1)  # Flatten tensors
    return torch.mean(torch.abs(preds - targets))

def mse_loss_flat(preds, targets):
    """
    Computes Mean Squared Error (MSE) while flattening the tensors to ensure compatibility.

    `reshape` is used instead of `view` so non-contiguous inputs are accepted.
    """
    preds, targets = preds.reshape(-1), targets.reshape(-1)  # Flatten tensors
    return torch.mean((preds - targets) ** 2)

In [None]:
#| export
from fastai.losses import BaseLoss
from fastai.losses import MSELossFlat
from fastai.losses import L1LossFlat

class RMSELoss(_Loss):
    """
    Root-mean-squared-error loss: `sqrt(mse_loss(input, target, reduction))`.

    With `reduction='mean'` this is the usual RMSE; with `'sum'` it is the square root of
    the summed squared errors; with `'none'` it is the element-wise absolute error.
    The squared error is clamped below at `eps` before the square root so that a perfect
    prediction yields a zero gradient instead of NaN (sqrt has an infinite slope at 0).
    """
    __constants__ = ["reduction"]
    def __init__(self, size_average = None, reduce = None, reduction: str = "mean", eps: float = 1e-12) -> None:
        super().__init__(size_average, reduce, reduction)
        self.eps = eps
    
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        squared_error = torch.nn.functional.mse_loss(input, target, reduction = self.reduction)
        return torch.sqrt(squared_error.clamp_min(self.eps))

@use_kwargs_dict(reduction='mean')
def RMSELossFlat(
    *args,
    axis:int = -1,
    floatify: bool = True, 
    **kwargs
):
    "Flattened `RMSELoss` wrapped in fastai's `BaseLoss`, mirroring `MSELossFlat`."
    # Regression loss: inputs are not treated as 2D class scores
    base_options = dict(axis = axis, floatify = floatify, is_2d = False)
    return BaseLoss(RMSELoss, *args, **base_options, **kwargs)

class SMAPELoss(_Loss):
    """
    Symmetric mean absolute percentage error loss, in percent.

    Element-wise term: `100 * 2|p - t| / max(|p| + |t|, 1e-7)`. The terms are reduced
    according to `reduction`: `'none'` returns them unreduced, `'sum'` sums them and any
    other value (default `'mean'`) averages them. Honouring `'none'` matters for callers
    that request per-item losses.
    """
    __constants__ = ["reduction"]
    
    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
        """
        Initializes the SMAPE Loss.
        
        Args:
            size_average (bool, optional): Deprecated (use reduction).
            reduce (bool, optional): Deprecated (use reduction).
            reduction (str): Specifies the reduction to apply to the output ('none', 'mean', 'sum').
        """
        super().__init__(size_average, reduce, reduction)
    
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Computes the SMAPE loss, reduced according to `self.reduction`.
        
        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values.
        
        Returns:
            Tensor: element-wise SMAPE terms ('none'), their sum ('sum') or their mean (otherwise).
        """
        terms = self.smape_elementwise(input, target)
        if self.reduction == "none":
            return terms
        if self.reduction == "sum":
            return terms.sum()
        return terms.mean()

    @staticmethod
    def smape_elementwise(preds: Tensor, targets: Tensor) -> Tensor:
        """
        Element-wise SMAPE terms (in percent) for the given predictions and targets.

        The denominator is clamped to 1e-7 so elements where both values are 0 give 0.
        """
        denominator = torch.abs(preds) + torch.abs(targets)
        return 100 * 2 * torch.abs(preds - targets) / torch.clamp(denominator, min=1e-7)
    
    @staticmethod
    def smape_loss(preds: Tensor, targets: Tensor) -> Tensor:
        """
        Computes the mean SMAPE loss for the given predictions and targets.
        
        Args:
            preds (Tensor): Predicted values.
            targets (Tensor): Ground truth values.
        
        Returns:
            Tensor: SMAPE loss (mean over all elements).
        """
        return torch.mean(SMAPELoss.smape_elementwise(preds, targets))


@use_kwargs_dict(reduction="mean")
def SMAPELossFlat(
    *args,
    axis: int = -1,
    floatify: bool = True,
    **kwargs
):
    """
    Flattened `SMAPELoss` wrapped in fastai's `BaseLoss`, mirroring `MSELossFlat`.

    Args:
        axis (int): axis option forwarded to `BaseLoss` (default -1).
        floatify (bool): whether `BaseLoss` casts the target to float (default True).
        **kwargs: forwarded to `BaseLoss` / `SMAPELoss` (e.g. `reduction`).
    """
    base_options = {"axis": axis, "floatify": floatify, "is_2d": False}
    return BaseLoss(SMAPELoss, *args, **base_options, **kwargs)

# MAE under a name matching the other *LossFlat helpers; fastai's L1 loss is the flattened MAE
MAELossFlat = L1LossFlat

In [None]:
#| export
#TODO: Check. Adding lr_scheduler & optimizer to mvp
from fastai.callback.core import Callback

class CustomOptimizerCallback(Callback):
    """
    fastai callback that swaps a user-supplied optimizer into the `Learner` for the
    duration of a fit and advances an optional LR scheduler after each training batch.

    NOTE(review): fastai's loop and its schedulers (e.g. `fit_one_cycle`) expect a fastai
    `Optimizer`/`OptimWrapper`; a bare `torch.optim` optimizer may need wrapping — confirm.
    """
    def __init__(self, optimizer, scheduler):
        self.optimizer = optimizer
        self.scheduler = scheduler
        self._original_opt = None

    def before_fit(self):
        # Keep the optimizer fastai created so it can be put back after the fit
        self._original_opt = self.learn.opt
        # Replace fastai's optimizer with the custom one
        self.learn.opt = self.optimizer

    def after_batch(self):
        # `after_batch` also fires on validation batches: only step the scheduler while training
        if self.scheduler is not None and self.learn.training:
            self.scheduler.step()

    def after_fit(self):
        # Restore the original optimizer instead of deleting `learn.opt`, which would leave
        # the Learner without an `opt` attribute for any later fit/validate call
        self.learn.opt = self._original_opt
        self._original_opt = None

In [None]:
#| export
def _first_validation_value(result):
    """Extract the validation loss from a `Learner.validate()` result, as a scalar when possible."""
    # fastai's validate() returns a record like [valid_loss, *metric_values]
    if result is not None and not hasattr(result, "item"):
        try:
            result = result[0]
        except (TypeError, IndexError, KeyError):
            pass
    return result.item() if hasattr(result, "item") else result

def validate_with_metrics(learner, metrics):
    """
    Run `learner.validate()` once per metric, using that metric as the validation loss.

    Setting an attribute on the learner itself does not change the loss it computes.
    Instead, the metric is installed as `crit` on every callback that owns one
    (NOTE(review): tsai's MVP callback is expected to compute its masked loss through its
    own `crit` — confirm against the installed tsai version). When no such callback
    exists, the metric replaces `learner.loss_func`. The original criteria and loss
    function are restored afterwards, even if validation raises.

    Args:
        learner: fastai-style learner exposing `cbs`, `loss_func` and `validate()`.
        metrics: iterable of callables `(preds, targets) -> loss`.

    Returns:
        list: one validation loss per metric, in the same order as `metrics`.
    """
    # Only callbacks that *own* a `crit` attribute (fastai callbacks forward unknown attributes to the learner)
    crit_cbs = [cb for cb in (getattr(learner, "cbs", None) or []) if "crit" in getattr(cb, "__dict__", {})]
    original_crits = [cb.crit for cb in crit_cbs]
    has_loss_func = hasattr(learner, "loss_func")
    original_loss_func = getattr(learner, "loss_func", None)
    results = []
    try:
        for metric in metrics:
            if crit_cbs:
                for cb in crit_cbs:
                    cb.crit = metric
            else:
                learner.loss_func = metric
            results.append(_first_validation_value(learner.validate()))
    finally:
        for cb, crit in zip(crit_cbs, original_crits):
            cb.crit = crit
        if has_loss_func:
            learner.loss_func = original_loss_func
    return results

In [None]:
#| export
def mvp_format_results(results):
    """
    Label the first four values of `results` as mse, rmse, mae and smape (in that order).

    Extra values are ignored; fewer than four values raise IndexError.
    """
    metric_names = ("mse", "rmse", "mae", "smape")
    return {name: results[position] for position, name in enumerate(metric_names)}

In [51]:
#| export
def fine_tune_mvp_single_(
    self            : Encoder,
    eval_pre        : bool  = False,
    eval_post       : bool  = False,
    shot            : bool  = False,
    show_plot       : bool  = False,
    sample_id       : int   = 0
):
    """
    Train (and optionally evaluate) an InceptionTimePlus learner with tsai's MVP callback
    on one dataset of the encoder input.

    NOTE: `self.model` is replaced by a freshly built `ts_learner` on every call.

    Args:
        eval_pre: validate with `self.metrics` before training.
        eval_post: validate with `self.metrics` after training.
        shot: run `fit_one_cycle` for `self.num_epochs` epochs at the LR suggested by
            `lr_find`'s valley heuristic.
        show_plot: show a sample, the training graph and the LR-finder plot; `None` keeps
            `self.show_plot`.
        sample_id: dataset index passed to `self.get_splits_`.

    Returns:
        (losses, eval_results_pre, eval_results_post, t_shot, t_eval_1, t_eval_2, self.model)
    """
    self.show_plot = self.show_plot if show_plot is None else show_plot
    t_shot = 0
    t_eval_1 = 0
    t_eval_2 = 0
    losses = []
    eval_results_pre = ""
    eval_results_post = ""
    if self.time_flag : timer = ut.Time(mssg = self.mssg)
    self.mssg.initial("fine_tune_mvp_single_")   
    X = self.get_splits_(sample_id)
    self.mssg.print("About to set callbacks", func_name = ut.funcname())
    cbs = L(WandbCallback(log_preds=False)) if self.use_wandb else L()
    # Callbacks used only while fitting
    fit_cbs = [
        EarlyStoppingCallback(
            monitor='valid_loss', 
            min_delta=0.000001, 
            patience=10
        ),
    ]

    self.mssg.print("About to set batch tfms")
    tfms = [ToFloat(), None]
    batch_tfms = [
        TSStandardize(
            by_sample       = self.norm_by_sample, 
            use_single_batch= self.norm_use_single_batch
        )
    ]
    dls = get_ts_dls(X, splits = self.splits, tfms = tfms, bs = self.input.batch_size, batch_tfms = batch_tfms)

    # TODO: plug self.optim.optimizer / LR scheduler in through CustomOptimizerCallback
    # (built with setup_scheduler on dls.train) once that path is validated.

    # MVP's `r` is the masking ratio, so it comes from the window mask percent
    # (not from the learning rate); when unset, MVP's own default is kept.
    mvp_kwargs = dict(
        window_size = X.shape[2]-1,
        future_mask = self.mask_future,
        target_dir  = './models',
        sync        = self.mask_sync,
        stateful    = self.mask_stateful,
        fname       = 'encoder_MVP'
    )
    if self.input.window_mask_percent is not None:
        mvp_kwargs["r"] = self.input.window_mask_percent

    extra_cbs = []
    if self.show_plot: 
        self.mssg.print("Show plot")
        display(dls.show_at(0))
        extra_cbs.append(ShowGraphCallback2())
    else:
        self.mssg.print("Don't show plot")
    self.model = ts_learner(
        dls, 
        InceptionTimePlus,
        cbs = cbs + extra_cbs + MVP(**mvp_kwargs),
        y_range = [X.min(), X.max()]
    )

    device = "cpu" if self.cpu else torch.cuda.current_device()
    self.model.to(device)
    self.mssg.print(f"Model Class {self.model.__class__} | Type: {type(self.model)}")

    # Eval - pre 
    if eval_pre:
        if self.time_flag: timer.start()
        self.mssg.print(f"Eval Pre | wlen {X.shape[2]} | Model: {self.model.__class__} | {type(self.model)} ")
        self.model.eval()
        results = validate_with_metrics(self.model, self.metrics)
        eval_results_pre = mvp_format_results(results)
        if self.time_flag:
            timer.end()
            t_eval_1 = timer.duration()
            timer.show(verbose = self.mssg.verbose)
    # Train 
    if shot:
        if self.time_flag: timer.start()
        self.model.train()
        self.mssg.print(f"Training the model | window size {X.shape[2]} | X ~ {X.shape}")
        lr_valley, _lr_steep = self.model.lr_find(suggest_funcs=(valley, steep), show_plot=self.show_plot)
        self.model.fit_one_cycle(
            n_epoch = self.num_epochs, 
            lr_max  = lr_valley,  
            cbs     = fit_cbs
        )
        losses = self.model.recorder.losses
        if self.time_flag:
            timer.end()
            t_shot= timer.duration()
            timer.show(verbose = self.mssg.verbose)

    # Eval - post
    if eval_post:
        if self.time_flag: timer.start()
        self.mssg.print(f"Eval Post | wlen {X.shape[2]}")
        self.model.eval()
        results = validate_with_metrics(self.model, self.metrics)
        self.mssg.print(f"Format results | results~{len(results)}")
        eval_results_post = mvp_format_results(results)
        if self.time_flag:
            timer.end()
            t_eval_2 = timer.duration()
            timer.show(verbose = self.mssg.verbose)
    self.mssg.final()
    return losses, eval_results_pre, eval_results_post, t_shot, t_eval_1, t_eval_2, self.model
Encoder.fine_tune_mvp_single_ = fine_tune_mvp_single_

In [None]:
#| export
# TODO: Revisar inclusion del optimizer en fine_tune_mvp_

def fine_tune_mvp_(
    self                    : Encoder,
    eval_pre                : bool  = True,
    eval_post               : bool  = True,
    shot                    : bool  = False,
    time_flag               : bool  = None,
    use_wandb               : bool  = None,
    analysis_mode           : str   = None,
    norm_by_sample          : bool  = None,
    norm_use_single_batch   : bool  = None,
    show_plot               : bool  = None
):
    """
    Run `fine_tune_mvp_single_` on every dataset in `self.input`, one after another.

    Arguments equal to `None` keep the corresponding attribute already stored on `self`.
    `eval_pre` is applied only to the first dataset.

    NOTE(review): the optimizer built here is not used by `fine_tune_mvp_single_`,
    which creates its own fastai learner — kept for interface compatibility.

    Raises:
        ValueError: if `self.input.size` is None.

    Returns:
        (lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, self.model)
    """
    self.mssg.initial_("fine_tune_mvp_")
    self.time_flag      = self.time_flag if time_flag is None else time_flag
    self.use_wandb      = self.use_wandb if use_wandb is None else use_wandb
    self.analysis_mode  = self.analysis_mode if analysis_mode is None else analysis_mode
    self.norm_by_sample = self.norm_by_sample if norm_by_sample is None else norm_by_sample
    self.norm_use_single_batch = self.norm_use_single_batch if norm_use_single_batch is None else norm_use_single_batch
    # Return values
    lossess             = []
    eval_results_pre    = []
    eval_results_post   = []
    t_shots             = []
    t_shot              = 0
    t_evals             = []
    t_eval              = 0

    if self.input.size is None:
        self.mssg.print(f"Windows: {len(self.input._data)}")
        raise ValueError(f"Invalid number of windows: {self.input.size}")
    self.mssg.print(f"Processing {self.input.size} datasets: {self.input.shapes}")
    # Build optimizer
    if self.optim.optimizer is None: 
        self.mssg.print(f"Setting up optimizer as AdamW")
        # `optim.lr` is either a plain learning rate or an LRScheduler carrying `.lr`
        lr = self.optim.lr if isinstance(self.optim.lr, (int, float)) else self.optim.lr.lr
        self.optim.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
    # Compute model for each window in the windowed dataset
    for i in range(self.input.size):
        self.mssg.print(f"Processing wlen {self.input.data[i].shape[2]}")
        ( 
            losses, eval_results_pre_, eval_results_post_, t_shot_, t_eval_1, t_eval_2, self.model
        ) =  self.fine_tune_mvp_single_(eval_pre, eval_post, shot, sample_id = i, show_plot = show_plot)
        lossess.append(losses)
        if (eval_pre): eval_results_pre = eval_results_pre_
        eval_results_post.append(eval_results_post_)
        t_shots.append(t_shot_)
        if eval_pre: t_evals.append(t_eval_1)
        if eval_post: t_evals.append(t_eval_2)
        # Pre-training evaluation only makes sense before the first dataset
        eval_pre = False
    t_shot = sum(t_shots)
    t_eval = sum(t_evals)
    self.mssg.final(ut.funcname())
    return lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, self.model

Encoder.fine_tune_mvp_ = fine_tune_mvp_ 

### Moirai

### Global method

In [None]:
#| export
def fine_tune__old(
    X                               : Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ],
    enc_learn                       : Learner, 
    stride                          : int           = 1,      
    batch_size                      : int           = 32,
    cpu                             : bool          = False,
    to_numpy                        : bool          = True, 
    verbose                         : int           = 0, 
    time_flag                       : bool          = False,
    n_windows                       : int           = None,
    n_windows_percent               : float         = None,
    validation_percent              : float         = 0.2, 
    training_percent                : float         = 0.2,
    window_mask_percent             : float         = 0.3,
    num_epochs                      : int           = 3,
    shot                            : bool          = True,
    eval_pre                        : bool          = True,
    eval_post                       : bool          = True,
    criterion                       : _Loss         = torch.nn.MSELoss, 
    optimizer                                       = None, 
    lr                              : float         = 5e-5, #1e-4, 
    lr_scheduler_flag               : bool          = False, 
    lr_scheduler_name               : str           = "linear",
    lr_scheduler_num_warmup_steps   : int           = None,
    window_sizes                    : List [int]    = None,
    n_window_sizes                  : int           = 1,
    window_sizes_offset             : int           = 0.05,
    windows_min_distance            : int           = 1,
    full_dataset                    : bool          = False,
    #- Printing options for debugging
    print_to_path                   : bool          = False,
    print_path                      : str           = "~/data/logs/logs.txt",
    print_mode                      : str           = 'a',
    #- Only for moment
    use_moment_masks                : bool          = False,
    #- Masking options
    mask_stateful                   : bool          = False,
    mask_future                     : bool          = False,
    mask_sync                       : bool          = False
): 
    mssg = ut.Mssg(
            to_path=print_to_path,
            path=print_path,
            mode=print_mode,
            verbose=verbose
        ) 
    mssg.initial()
    
    lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval = ( None, None, None, None, None, None, None )
    
    enc_input = windowed_dataset(
        X, stride, window_sizes, 
        n_window_sizes, window_sizes_offset, 
        windows_min_distance, full_dataset, 
        mssg
    )
    enc_input = EncoderInput(
        _data               = enc_input, 
        stride              = stride,
        batch_size          = batch_size,
        n_windows           = n_windows,
        n_windows_percent   = n_windows_percent,
        validation_percent  = validation_percent,
        training_percent    = training_percent,
        window_mask_percent = window_mask_percent,
    )
    optim = EncoderOptimizer(
        criterion   = criterion,
        optimizer   = optimizer,
        lr          = LRScheduler (
                        flag            = lr_scheduler_flag,
                        name            = lr_scheduler_name,
                        num_warmup_steps= lr_scheduler_num_warmup_steps
        ),
    )
    enc = Encoder(
        model           = enc_learn,
        input           = enc_input,
        mssg            = mssg,
        cpu             = cpu,
        to_numpy        = to_numpy, 
        num_epochs      = num_epochs, 
        optim           = optim,
        mask_stateful   = mask_stateful,
        mask_future     = mask_future,
        mask_sync       = mask_sync,
        eval_stats_pre  = eval_results_pre,
        eval_stats_post = eval_results_post
    )
    enc.set_fine_tune_()
    match enc.fine_tune_.__name__:
        case "fine_tune_moment_":
            ( 
                lossess, eval_results_pre, eval_results_post, 
                t_shots, t_shot, t_evals, t_eval, enc.model 
            ) = enc.fine_tune_(
                eval_pre, eval_post, shot, time_flag, use_moment_masks
            )
        case _:
            ( 
                lossess, eval_results_pre, eval_results_post, 
                t_shots, t_shot, t_evals, t_eval, enc.model 
            ) = enc.fine_tune_(eval_pre, eval_post, shot, time_flag)
    return lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, enc.model

In [None]:
#| export
def _get_mssg(
    mssg : ut.Mssg = None,
    verbose                         : int           = 0, 
    print_to_path                   : bool          = False,
    print_path                      : str           = "~/data/logs/logs.txt",
    print_mode                      : str           = 'a',
):
    """
    Resolve the message/logging helper used by the encoder utilities.

    Parameters
    ----------
    mssg : ut.Mssg, optional
        Existing logger. It is passed through ``ut._check_value``; if the
        checked value is not ``None`` it is returned as-is.
    verbose, print_to_path, print_path, print_mode
        Only used when a new logger must be built; forwarded to ``ut.Mssg``
        as ``verbose``, ``to_path``, ``path`` and ``mode`` respectively.

    Returns
    -------
    ut.Mssg
        The checked ``mssg`` or a freshly constructed one.
    """
    checked_mssg, _ = ut._check_value(mssg, None, "mssg", ut.Mssg)
    if checked_mssg is not None:
        return checked_mssg
    # No usable logger supplied: build one from the printing options.
    return ut.Mssg(
        verbose = verbose,
        to_path = print_to_path,
        path    = print_path,
        mode    = print_mode,
    )

def _get_enc_input(
    mssg                            : ut.Mssg,
    # Encoder Input
    ## -- Using all parameters
    X                               : Optional [ Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ] ],
    stride                          : Optional [ int ]          = None,
    batch_size                      : Optional [ int ]          = None,
    n_windows                       : Optional [ int ]          = None,
    n_windows_percent               : Optional [ float ]        = None,
    validation_percent              : Optional [ float ]        = None, 
    training_percent                : Optional [ float ]        = None,
    window_mask_percent             : Optional [ float ]        = None,
    window_sizes                    : Optional [ List [int] ]   = None,
    n_window_sizes                  : Optional [ int ]          = 1,
    window_sizes_offset             : Optional [ float ]        = 0.05,
    windows_min_distance            : Optional [ int ]          = 1,
    full_dataset                    : Optional [ bool ]         = False,
    ## -- Using Type
    enc_input                       : Optional [ EncoderInput ] = None
): 
    """
    Return an ``EncoderInput``, building one from ``X`` when none is given.

    ``enc_input`` is first passed through ``ut._check_value``. When the
    checked value is not ``None`` it is returned without windowing ``X``.
    Otherwise ``X`` is split into windows with ``windowed_dataset`` (using
    ``stride``, ``window_sizes``, ``n_window_sizes``, ``window_sizes_offset``,
    ``windows_min_distance`` and ``full_dataset``), and the windows are
    wrapped in an ``EncoderInput`` together with ``stride`` and the batching /
    split options (``batch_size``, ``n_windows``, ``n_windows_percent``,
    ``validation_percent``, ``training_percent``, ``window_mask_percent``).

    Parameters
    ----------
    mssg : ut.Mssg
        Logger for progress messages; it is used unconditionally, so it must
        not be ``None``.
    X : list | pandas.DataFrame, optional
        Raw data; only read when ``enc_input`` is not usable.
    window_sizes_offset : float, optional
        Forwarded to ``windowed_dataset``. Annotated as ``float`` because its
        default (``0.05``) is fractional.
    enc_input : EncoderInput, optional
        Pre-built encoder input that short-circuits the windowing step.

    Returns
    -------
    EncoderInput
    """
    mssg.initial_(func_name = ut.funcname())
    enc_input, _ = ut._check_value(enc_input, None, "enc_input", EncoderInput, True, False, False)
    mssg.print(f"is none enc_input? {enc_input is None}")
    if enc_input is None:
        mssg.print(f"About to get the windows")
        # Step 1: raw data -> windows.
        enc_input = windowed_dataset(
            X                       = X,
            stride                  = stride,
            window_sizes            = window_sizes,
            n_window_sizes          = n_window_sizes,
            window_sizes_offset     = window_sizes_offset,
            windows_min_distance    = windows_min_distance,
            full_dataset            = full_dataset,
            mssg                    = mssg
        )
        mssg.print(f"About to get the encoder input | windows~{len(enc_input)}", func_name = ut.funcname())
        # Step 2: windows + batching/split options -> EncoderInput.
        enc_input = EncoderInput(
            _data               = enc_input, 
            stride              = stride,
            batch_size          = batch_size,
            n_windows           = n_windows,
            n_windows_percent   = n_windows_percent,
            validation_percent  = validation_percent,
            training_percent    = training_percent,
            window_mask_percent = window_mask_percent,
        )
        mssg.print(f"Enc input obtained | enc_input~{enc_input.shape}")
    mssg.final()
    return enc_input

def _get_optimizer(
    mssg                            : ut.Mssg,
    optim                           : EncoderOptimizer = None,
    criterion                       : _Loss         = torch.nn.MSELoss, 
    optimizer                                       = None, 
    lr                              : float         = 5e-5, #1e-4, 
    lr_scheduler_flag               : bool          = False, 
    lr_scheduler_name               : str           = "linear",
    lr_scheduler_num_warmup_steps   : int           = None
):
    """
    Return an ``EncoderOptimizer``, building one when none is supplied.

    ``optim`` is passed through ``ut._check_value``. If the checked value is
    ``None``, a new ``EncoderOptimizer`` is created from ``criterion`` and
    ``optimizer``, with an ``LRScheduler`` built from ``lr`` and the
    ``lr_scheduler_*`` options as its ``lr`` field.

    Parameters
    ----------
    mssg : ut.Mssg
        Logger; ``initial``/``final`` bracket this call.

    Returns
    -------
    EncoderOptimizer
    """
    mssg.initial(ut.funcname())
    resolved_optim, _ = ut._check_value(optim, None, "optim", EncoderOptimizer, True)
    if resolved_optim is None:
        scheduler = LRScheduler(
            lr               = lr,
            flag             = lr_scheduler_flag,
            name             = lr_scheduler_name,
            num_warmup_steps = lr_scheduler_num_warmup_steps,
        )
        resolved_optim = EncoderOptimizer(
            criterion = criterion,
            optimizer = optimizer,
            lr        = scheduler,
        )
    mssg.final()
    return resolved_optim

def _get_encoder(
    ## -- Using all parammeters
    X                               : Optional [ Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ] ],
    stride                          : Optional [ int ]          = None,
    batch_size                      : Optional [ int ]          = None,
    n_windows                       : Optional [ int ]          = None,
    n_windows_percent               : Optional [ float ]        = None,
    validation_percent              : Optional [ float ]        = None, 
    training_percent                : Optional [ float ]        = None,
    window_mask_percent             : Optional [ float ]        = None,
    window_sizes                    : Optional [ List [int] ]   = None,
    n_window_sizes                  : Optional [ int ]          = 1,
    window_sizes_offset             : Optional [ int ]          = 0.05,
    windows_min_distance            : Optional [ int ]          = 1,
    full_dataset                    : Optional [ bool ]         = False,
    ##-- Given by Type 
    enc_input                       : Optional [ EncoderInput ] = None,
    # Optimizer
    optim                           : Optional [ EncoderOptimizer ] = None,
    ## -- Using all parameters
    criterion                       : Optional [ _Loss ]            = torch.nn.MSELoss, 
    optimizer                                                       = None, 
    lr                              : Optional [ float ]            = 5e-5, #1e-4, 
    lr_scheduler_flag               : Optional [ bool ]             = False, 
    lr_scheduler_name               : Optional [ str ]              = "linear",
    lr_scheduler_num_warmup_steps   : Optional [ int ]              = None,
    # Mssg
    ## -- Using all parameters
    verbose                         : Optional[ int ]               = 0, 
    print_to_path                   : Optional[ bool ]              = False,
    print_path                      : Optional[ str ]               = "~/data/logs/logs.txt",
    print_mode                      : Optional[ str ]               = 'a',
    ## -- Using Type
    mssg                            : Optional [ ut.Mssg ]          = None,
    ## Encoder 
    enc                             : Optional [ Encoder ]          = None,
    ## -- Using all parameters
    num_epochs                      : Optional [ int]               = 3,
    enc_learn                       : Optional [Learner]            = None, 
    cpu                             : Optional [ bool ]             = False,
    to_numpy                        : Optional [ bool ]             = True,
    #- Masking options
    mask_stateful                   : Optional [ bool ]             = False,
    mask_future                     : Optional [ bool ]             = False,
    mask_sync                       : Optional [ bool ]             = False,
    #- Loss criterions
    metrics                         : Optional [ List [ Callable ]] = None
):
    """
    Return ``enc`` when one is supplied, otherwise assemble a new ``Encoder``.

    ``enc`` is passed through ``ut._check_value``. If the checked value is
    ``None``, the pieces are resolved with the sibling helpers:

    * logger: ``_get_mssg`` (from ``mssg`` or the ``print_*``/``verbose`` options),
    * input: ``_get_enc_input`` (from ``enc_input`` or ``X`` plus window options),
    * optimizer: ``_get_optimizer`` (from ``optim`` or ``criterion``/``optimizer``/``lr*``),

    and combined with the model, masking and metric options into an
    ``Encoder`` whose ``eval_stats_pre``/``eval_stats_post`` start as ``None``.
    ``enc.mssg.final`` is called before returning in both cases.
    """
    enc, _ = ut._check_value(enc, None, "enc", Encoder, True)

    if enc is not None:
        # NOTE(review): on this path `final` runs without a matching `initial`
        # in this function, and the `mssg` argument is ignored — confirm that
        # is intended.
        enc.mssg.final(ut.funcname())
        return enc

    logger = _get_mssg(mssg, verbose, print_to_path, print_path, print_mode)
    logger.initial(ut.funcname())

    logger.print("About to exec _get_enc_input")
    resolved_input = _get_enc_input(
        mssg                 = logger,
        X                    = X,
        stride               = stride,
        batch_size           = batch_size,
        n_windows            = n_windows,
        n_windows_percent    = n_windows_percent,
        validation_percent   = validation_percent,
        training_percent     = training_percent,
        window_mask_percent  = window_mask_percent,
        window_sizes         = window_sizes,
        n_window_sizes       = n_window_sizes,
        window_sizes_offset  = window_sizes_offset,
        windows_min_distance = windows_min_distance,
        full_dataset         = full_dataset,
        enc_input            = enc_input,
    )
    logger.print(f"enc_input~{resolved_input.shape}")

    logger.print("About to exec _get_optimizer")
    resolved_optim = _get_optimizer(
        mssg                          = logger,
        optim                         = optim,
        criterion                     = criterion,
        optimizer                     = optimizer,
        lr                            = lr,
        lr_scheduler_flag             = lr_scheduler_flag,
        lr_scheduler_name             = lr_scheduler_name,
        lr_scheduler_num_warmup_steps = lr_scheduler_num_warmup_steps,
    )

    enc = Encoder(
        model           = enc_learn,
        input           = resolved_input,
        mssg            = logger,
        cpu             = cpu,
        to_numpy        = to_numpy,
        num_epochs      = num_epochs,
        optim           = resolved_optim,
        mask_stateful   = mask_stateful,
        mask_future     = mask_future,
        mask_sync       = mask_sync,
        eval_stats_pre  = None,
        eval_stats_post = None,
        metrics         = metrics
    )
    enc.mssg.final(ut.funcname())
    return enc

In [None]:
#| export
def fine_tune(
    # Optional parameters
    ## Encoder Input
    ## -- Using all parammeters
    X                               : Optional [ Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ] ],
    stride                          : Optional [ int ]          = None,
    batch_size                      : Optional [ int ]          = None,
    n_windows                       : Optional [ int ]          = None,
    n_windows_percent               : Optional [ float ]        = None,
    validation_percent              : Optional [ float ]        = None, 
    training_percent                : Optional [ float ]        = None,
    window_mask_percent             : Optional [ float ]        = None,
    window_sizes                    : Optional [ List [int] ]   = None,
    n_window_sizes                  : Optional [ int ]          = 1,
    window_sizes_offset             : Optional [ int ]          = 0.05,
    windows_min_distance            : Optional [ int ]          = 1,
    full_dataset                    : Optional [ bool ]         = False,
    ##-- Given by Type 
    enc_input                       : Optional [ EncoderInput ] = None,
    # Optimizer
    optim                           : Optional [ EncoderOptimizer ] = None,
    ## -- Using all parameters
    criterion                       : Optional [ _Loss ]            = torch.nn.MSELoss, 
    optimizer                                                       = None, 
    lr                              : Optional [ float ]            = 5e-5, #1e-4, 
    lr_scheduler_flag               : Optional [ bool ]             = False, 
    lr_scheduler_name               : Optional [ str ]              = "linear",
    lr_scheduler_num_warmup_steps   : Optional [ int ]              = None,
    # Mssg
    ## -- Using all parameters
    verbose                         : Optional[ int ]               = 0, 
    print_to_path                   : Optional[ bool ]              = False,
    print_path                      : Optional[ str ]               = "~/data/logs/logs.txt",
    print_mode                      : Optional[ str ]               = 'a',
    ## -- Using Type
    mssg                            : Optional [ ut.Mssg ]          = None,
    
    ## Encoder 
    enc                             : Optional [ Encoder ]          = None,
    ## -- Using all parameters
    num_epochs                      : Optional [ int]               = 3,
    enc_learn                       : Optional [Learner]            = None, 
    cpu                             : Optional [ bool ]             = False,
    to_numpy                        : Optional [ bool ]             = True,
    #- Only for moment
    use_moment_masks                : Optional [ bool ]             = False,
    #- Masking options
    mask_stateful                   : Optional [ bool ]             = False,
    mask_future                     : Optional [ bool ]             = False,
    mask_sync                       : Optional [ bool ]             = False,
    # Non-Optional parameters
    time_flag                       : bool          = False,
    shot                            : bool          = True, 
    eval_pre                        : bool          = True, 
    eval_post                       : bool          = True,
    use_wandb                       : bool          = None,
    analysis_mode                   : str           = None,
    norm_by_sample                  : bool          = None,
    norm_use_single_batch           : bool          = None,
    show_plot                       : bool          = False,
    metrics                                         = None
): 
    """
    Build (or reuse) an ``Encoder`` and run its fine-tuning routine.

    Every encoder-construction argument is forwarded to ``_get_encoder``
    (which skips construction when ``enc`` is supplied). The routine selected
    by ``enc.set_fine_tune_()`` is then invoked with
    ``(eval_pre, eval_post, shot, time_flag)`` plus routine-specific extras,
    chosen by the routine's ``__name__``:

    * ``"fine_tune_moment_"``: ``use_moment_masks`` as a fifth positional argument;
    * ``"fine_tune_mvp_"``: ``use_wandb``, ``analysis_mode``, ``norm_by_sample``,
      ``norm_use_single_batch`` and ``show_plot`` as keyword arguments;
    * anything else: no extras.

    Returns
    -------
    tuple
        ``(lossess, eval_results_pre, eval_results_post, t_shots, t_shot,
        t_evals, t_eval, model)``; ``model`` is also stored in ``enc.model``.
    """
    enc = _get_encoder(
        X                               = X,
        stride                          = stride,
        batch_size                      = batch_size,
        n_windows                       = n_windows,
        n_windows_percent               = n_windows_percent,
        validation_percent              = validation_percent,
        training_percent                = training_percent,
        window_mask_percent             = window_mask_percent,
        window_sizes                    = window_sizes,
        n_window_sizes                  = n_window_sizes,
        window_sizes_offset             = window_sizes_offset,
        windows_min_distance            = windows_min_distance,
        full_dataset                    = full_dataset,
        enc_input                       = enc_input,
        optim                           = optim,
        criterion                       = criterion,
        optimizer                       = optimizer,
        lr                              = lr,
        lr_scheduler_flag               = lr_scheduler_flag,
        lr_scheduler_name               = lr_scheduler_name,
        lr_scheduler_num_warmup_steps   = lr_scheduler_num_warmup_steps,
        verbose                         = verbose,
        print_to_path                   = print_to_path,
        print_path                      = print_path,
        print_mode                      = print_mode,
        mssg                            = mssg,
        enc                             = enc,
        num_epochs                      = num_epochs,
        enc_learn                       = enc_learn,
        cpu                             = cpu,
        to_numpy                        = to_numpy,
        mask_stateful                   = mask_stateful,
        mask_future                     = mask_future,
        mask_sync                       = mask_sync,
        metrics                         = metrics
    )
    enc.mssg.initial_("fine_tune")
    enc.mssg.print(f"Original enc_learn { enc_learn }  | Final model { enc.model }")

    enc.set_fine_tune_()
    routine      = enc.fine_tune_
    routine_name = routine.__name__
    # Arguments shared by every fine-tuning routine, in positional order.
    common_args  = (eval_pre, eval_post, shot, time_flag)

    if routine_name == "fine_tune_moment_":
        enc.mssg.print("Use fine_tune_moment parameters")
        outcome = routine(*common_args, use_moment_masks)
    elif routine_name == "fine_tune_mvp_":
        enc.mssg.print("Use fine_tune_mvp parameters")
        outcome = routine(
            *common_args,
            use_wandb             = use_wandb,
            analysis_mode         = analysis_mode,
            norm_by_sample        = norm_by_sample,
            norm_use_single_batch = norm_use_single_batch,
            show_plot             = show_plot,
        )
    else:
        enc.mssg.print("Use generic fine_tune parameters")
        outcome = routine(*common_args)

    # The routine returns 8 values; the last one is the (fine-tuned) model.
    (
        lossess, eval_results_pre, eval_results_post,
        t_shots, t_shot, t_evals, t_eval, enc.model
    ) = outcome
    enc.mssg.final()
    return lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, enc.model

In [None]:
def fine_tune__(
    self                            : Encoder,
    # Optional parameters
    ## Encoder Input
    ## -- Using all parammeters
    X                               : Optional [ Union [ List [ List [ List [ float ]]], List [ float ], pd.DataFrame ] ],
    stride                          : Optional [ int ]          = None,
    batch_size                      : Optional [ int ]          = None,
    n_windows                       : Optional [ int ]          = None,
    n_windows_percent               : Optional [ float ]        = None,
    validation_percent              : Optional [ float ]        = None, 
    training_percent                : Optional [ float ]        = None,
    window_mask_percent             : Optional [ float ]        = None,
    window_sizes                    : Optional [ List [int] ]   = None,
    n_window_sizes                  : Optional [ int ]          = 1,
    window_sizes_offset             : Optional [ int ]          = 0.05,
    windows_min_distance            : Optional [ int ]          = 1,
    full_dataset                    : Optional [ bool ]         = False,
    ##-- Given by Type 
    enc_input                       : Optional [ EncoderInput ] = None,
    # Optimizer
    optim                           : Optional [ EncoderOptimizer ] = None,
    ## -- Using all parameters
    criterion                       : Optional [ _Loss ]            = torch.nn.MSELoss, 
    optimizer                                                       = None, 
    lr                              : Optional [ float ]            = 5e-5, #1e-4, 
    lr_scheduler_flag               : Optional [ bool ]             = False, 
    lr_scheduler_name               : Optional [ str ]              = "linear",
    lr_scheduler_num_warmup_steps   : Optional [ int ]              = None,
    # Mssg
    ## -- Using all parameters
    verbose                         : Optional[ int ]               = 0, 
    print_to_path                   : Optional[ bool ]              = False,
    print_path                      : Optional[ str ]               = "~/data/logs/logs.txt",
    print_mode                      : Optional[ str ]               = 'a',
    ## -- Using Type
    mssg                            : Optional [ ut.Mssg ]          = None,
    # Non-Optional parameters
    use_moment_masks                : bool          = True,
    time_flag                       : bool          = False,
    shot                            : bool          = True, 
    eval_pre                        : bool          = True, 
    eval_post                       : bool          = True,
    # MVP
    use_wandb                       : bool          = None,
    norm_by_sample                  : bool          = None,
    norm_use_single_batch           : bool          = None
):
    if self.mssg == ut.Mssg():
        mssg = _get_mssg(mssg, verbose, print_to_path, print_path, print_mode)
    if self.input == EncoderInput():
        enc_input = _get_enc_input(mssg, X, stride, batch_size, n_windows, n_windows_percent, validation_percent, training_percent, window_mask_percent, window_sizes, n_window_sizes, window_sizes_offset, windows_min_distance, full_dataset, enc_input)
    if self.optim == EncoderOptimizer():
        optim = _get_optimizer(mssg, optim, criterion, optimizer, lr, lr_scheduler_flag, lr_scheduler_name, lr_scheduler_num_warmup_steps)
    
    self.input  = enc_input
    self.mssg   = mssg 
    self.optim  = optim
    lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval = ( None, None, None, None, None, None, None )
    self.set_fine_tune_()
    
    if self.fine_tune_ == fine_tune_moment_:
        ( 
            lossess, eval_results_pre, eval_results_post, 
            t_shots, t_shot, t_evals, t_eval, self.model 
        ) = self.fine_tune_(
            eval_pre, eval_post, shot, time_flag, use_moment_masks
        )
    elif self.fine_tune_ == fine_tune_mvp_:
        ( 
            lossess, eval_results_pre, eval_results_post, 
            t_shots, t_shot, t_evals, t_eval, self.model
        ) = self.fine_tune_(
            eval_pre, eval_post, shot, time_flag, use_wandb, norm_by_sample, norm_use_single_batch
        )
    else:
        ( 
            lossess, eval_results_pre, eval_results_post, 
            t_shots, t_shot, t_evals, t_eval, self.model
        ) = self.fine_tune_(eval_pre, eval_post, shot, time_flag)
    mssg.final()
    return lossess, eval_results_pre, eval_results_post, t_shots, t_shot, t_evals, t_eval, self.model
Encoder.fine_tune = fine_tune__

In [54]:
#| hide
#import wandb
#from dvats.utils import *
#wandb_api = wandb.Api()
#enc_artifact = wandb_api.artifact('deepvats/mvp-SWV:latest')
#enc_learner = enc_artifact.to_obj()
#X = torch.rand(9, 1, 48)

In [55]:
#| hide
#%time
#embs = get_enc_embs(X, enc_learner, cpu=True)
#test_eq(embs.shape[0], X.shape[0])
#embs.shape, embs.__class__

In [56]:
#| hide
#%%time #TODO dont work with nb2py
#embs = get_enc_embs(X, enc_learner, cpu=False, to_numpy=False)
#test_eq(embs.shape[0], X.shape[0])
#embs.shape, embs.__class__, embs.device

In [None]:
#| hide 
#%%time #TODO --> dont works with nb2py
#embs = get_enc_embs(X, enc_learner, cpu=False, to_numpy=True)
#test_eq(embs.shape[0], X.shape[0])
#embs.shape, embs.__class__

In [57]:
#| hide
# Calls `beep(1)` five times in a row at the very end of the notebook.
# Presumably an audible "run finished" notification — TODO confirm what
# `dvats.imports.beep` does and what its argument means.
# Kept as explicit calls (not a loop) so the cell's last-expression display
# behaviour stays unchanged.
from dvats.imports import beep
beep(1)
beep(1)
beep(1)
beep(1)
beep(1)