In [None]:
#|default_exp models.multimodal

# Multimodal

>Functionality used for multiple data modalities.

A common scenario in time-series related tasks is the use of multiple types of inputs:

* static: data that doesn't change with time
* observed: temporal data only available in the past
* known: temporal data available in the past and in the future

At the same time, these different modalities may contain:

* categorical data
* continuous or numerical data

Based on that, there are situations where we have up to 6 different types of input features:

* s_cat: static categorical variables
* s_cont: static continuous variables
* o_cat: observed categorical variables
* o_cont: observed continuous variables
* k_cat: known categorical variables
* k_cont: known continuous variables

In [None]:
#| export
import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
from fastcore.test import test_eq
from fastcore.xtras import listify
from fastcore.xtras import L
from fastai.tabular.model import emb_sz_rule
from tsai.imports import default_device
from tsai.data.core import TSDataLoaders
from tsai.data.preprocessing import PatchEncoder
from tsai.learner import get_arch
from tsai.models.utils import build_ts_model, output_size_calculator
from tsai.models.layers import Reshape, LinBnDrop, get_act_fn, lin_nd_head, rocket_nd_head, GAP1d

In [None]:
#| export
def _to_list(idx):
    if idx is None:
        return []
    elif isinstance(idx, int):
        return [idx]
    elif isinstance(idx, list):
        return idx


def get_o_cont_idxs(c_in, s_cat_idxs=None, s_cont_idxs=None, o_cat_idxs=None):
    """Calculate the indices of the observed continuous features.

    Returns every index in `range(c_in)` that is not used by `s_cat_idxs`, `s_cont_idxs` or `o_cat_idxs`
    (each may be None, a single int, or any iterable of ints).
    Raises ValueError if an index is out of range or is listed more than once.
    """
    remaining = np.arange(c_in).tolist()
    for idxs in (s_cat_idxs, s_cont_idxs, o_cat_idxs):
        if idxs is None:
            continue
        # accept a single index as well as lists, tuples, ranges or numpy arrays of indices
        if isinstance(idxs, (int, np.integer)):
            idxs = [idxs]
        for idx in idxs:
            if idx not in remaining:
                raise ValueError(f"index {idx} is out of range for c_in={c_in} or is assigned to more than one feature group")
            remaining.remove(idx)
    return remaining


def get_feat_idxs(c_in, s_cat_idxs=None, s_cont_idxs=None, o_cat_idxs=None, o_cont_idxs=None):
    """Calculate the indices of the features used for training.

    Each group is normalized to a list with `_to_list`. When `o_cont_idxs` ends up empty, it is
    inferred as every remaining index in `range(c_in)` via `get_o_cont_idxs`.
    Returns `(s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs)`.
    """
    s_cat = _to_list(s_cat_idxs)
    s_cont = _to_list(s_cont_idxs)
    o_cat = _to_list(o_cat_idxs)
    o_cont = _to_list(o_cont_idxs)
    if not o_cont:
        o_cont = get_o_cont_idxs(c_in, s_cat_idxs=s_cat, s_cont_idxs=s_cont, o_cat_idxs=o_cat)
    return s_cat, s_cont, o_cat, o_cont

In [None]:
c_in = 7
s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs = get_feat_idxs(
    c_in, s_cat_idxs=3, s_cont_idxs=[1, 4, 5], o_cat_idxs=None, o_cont_idxs=None)

# a single int becomes a one-element list, a missing group becomes an empty list,
# and the observed continuous indices are inferred as the remaining channels
test_eq(s_cat_idxs, [3])
test_eq(s_cont_idxs, [1, 4, 5])
test_eq(o_cat_idxs, [])
test_eq(o_cont_idxs, [0, 2, 6])

In [None]:
#| export
class TensorSplitter(nn.Module):
    """Split a 3D input tensor (bs, c_in, seq_len) into one slice per feature group.

    Groups are returned in the order s_cat, s_cont, o_cat, o_cont and, when `horizon` is set, k_cat, k_cont.
    - static groups take only the first time step -> (bs, n_vars); s_cat is cast to int64, s_cont keeps the input dtype
    - observed groups drop the last `horizon` steps when `horizon` is set -> (bs, n_vars, seq_len - horizon)
    - known groups keep the full sequence -> (bs, n_vars, seq_len)
    Empty groups yield empty tensors with the matching shape.
    """
    def __init__(self,
        s_cat_idxs:list=None, # list of indices for static categorical variables
        s_cont_idxs:list=None, # list of indices for static continuous variables
        o_cat_idxs:list=None, # list of indices for observed categorical variables
        o_cont_idxs:list=None, # list of indices for observed continuous variables
        k_cat_idxs:list=None, # list of indices for known categorical variables
        k_cont_idxs:list=None, # list of indices for known continuous variables
        horizon:int=None, # number of time steps to predict ahead
        ):
        super(TensorSplitter, self).__init__()
        # normalize first so that index 0, tuples and numpy arrays are handled correctly by the checks below
        self.s_cat_idxs = self._to_list(s_cat_idxs)
        self.s_cont_idxs = self._to_list(s_cont_idxs)
        self.o_cat_idxs = self._to_list(o_cat_idxs)
        self.o_cont_idxs = self._to_list(o_cont_idxs)
        self.k_cat_idxs = self._to_list(k_cat_idxs)
        self.k_cont_idxs = self._to_list(k_cont_idxs)
        assert self.s_cat_idxs or self.s_cont_idxs or self.o_cat_idxs or self.o_cont_idxs, "must specify at least one of s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs"
        if self.k_cat_idxs or self.k_cont_idxs:
            assert horizon is not None, "must specify horizon if using known variables"
        assert horizon is None or isinstance(horizon, int), "horizon must be an integer"
        self.horizon = horizon
        # group order matters: forward() uses the position in this list to decide how to slice each group
        self.idx_list = [self.s_cat_idxs, self.s_cont_idxs, self.o_cat_idxs, self.o_cont_idxs]
        if horizon is not None:
            self.idx_list += [self.k_cat_idxs, self.k_cont_idxs]
        self._check_overlap()

    def _check_overlap(self):
        "Raise a ValueError if any index is assigned to more than one feature group."
        indices = [i for idxs in self.idx_list for i in idxs]
        if len(indices) != len(set(indices)):
            raise ValueError("Indices must not overlap between s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs, k_cat_idxs and k_cont_idxs")

    @staticmethod
    def _to_list(idx):
        "Normalize an index spec (None, a single int, or an iterable of ints) to a list of indices."
        if idx is None:
            return []
        if isinstance(idx, (int, np.integer)):
            return [int(idx)]
        if isinstance(idx, list):
            return idx
        return [int(i) for i in idx]

    def forward(self, input_tensor):
        bs, seq_len = input_tensor.size(0), input_tensor.size(2)
        # observed variables exclude the last `horizon` steps; using an explicit end index keeps horizon=0 correct
        obs_len = seq_len if self.horizon is None else seq_len - self.horizon
        slices = []
        for group, idxs in enumerate(self.idx_list):
            if group < 2:  # s_cat_idxs or s_cont_idxs: first time step only
                if idxs:
                    out = input_tensor[:, idxs, 0]
                    if group == 0:  # only static *categorical* values are cast to integers
                        out = out.long()
                else:
                    out = input_tensor.new_empty((bs, 0))  # 2D empty tensor
            else:  # observed groups are cut to obs_len, known groups keep the full sequence
                end = obs_len if group < 4 else seq_len
                out = input_tensor[:, idxs, :end] if idxs else input_tensor.new_empty((bs, 0, end))
            slices.append(out)
        return slices


In [None]:
# Example usage: static and observed variables only (no forecast horizon)
bs = 4
s_cat_idxs = 1
s_cont_idxs = [0, 2]
o_cat_idxs = [3, 4, 5]
o_cont_idxs = None
k_cat_idxs = None
k_cont_idxs = None
horizon = None
input_tensor = torch.randn(bs, 6, 10)  # (bs, c_in, seq_len)
splitter = TensorSplitter(s_cat_idxs=s_cat_idxs, s_cont_idxs=s_cont_idxs,
                          o_cat_idxs=o_cat_idxs, o_cont_idxs=o_cont_idxs,
                          k_cat_idxs=k_cat_idxs, k_cont_idxs=k_cont_idxs, horizon=horizon)
slices = splitter(input_tensor)
for i, slice_tensor in enumerate(slices):
    print(f"Slice {i+1}: {slice_tensor.shape} {slice_tensor.dtype}")
test_eq([tuple(s.shape) for s in slices], [(bs, 1), (bs, 2), (bs, 3, 10), (bs, 0, 10)])

Slice 1: torch.Size([4, 1]) torch.int64
Slice 2: torch.Size([4, 2]) torch.int64
Slice 3: torch.Size([4, 3, 10]) torch.float32
Slice 4: torch.Size([4, 0, 10]) torch.float32


In [None]:
# Example usage with known variables: observed slices drop the last `horizon` steps, known slices keep them
bs = 4
s_cat_idxs = 1
s_cont_idxs = [0, 2]
o_cat_idxs = [3, 4, 5]
o_cont_idxs = None
k_cat_idxs = [6, 7]
k_cont_idxs = 8
horizon = 3
input_tensor = torch.randn(bs, 9, 10)  # (bs, c_in, seq_len)
splitter = TensorSplitter(s_cat_idxs=s_cat_idxs, s_cont_idxs=s_cont_idxs,
                          o_cat_idxs=o_cat_idxs, o_cont_idxs=o_cont_idxs,
                          k_cat_idxs=k_cat_idxs, k_cont_idxs=k_cont_idxs, horizon=horizon)
slices = splitter(input_tensor)
for i, slice_tensor in enumerate(slices):
    print(f"Slice {i+1}: {slice_tensor.shape} {slice_tensor.dtype}")
test_eq([tuple(s.shape) for s in slices],
        [(bs, 1), (bs, 2), (bs, 3, 7), (bs, 0, 7), (bs, 2, 10), (bs, 1, 10)])


Slice 1: torch.Size([4, 1]) torch.int64
Slice 2: torch.Size([4, 2]) torch.int64
Slice 3: torch.Size([4, 3, 7]) torch.float32
Slice 4: torch.Size([4, 0, 7]) torch.float32
Slice 5: torch.Size([4, 2, 10]) torch.float32
Slice 6: torch.Size([4, 1, 10]) torch.float32


In [None]:
#| export
class Embeddings(nn.Module):
    "Embedding layers for each categorical variable in a 2D or 3D tensor"
    def __init__(self,
        n_embeddings:list, # List of num_embeddings for each categorical variable
        embedding_dims:list=None, # List of embedding dimensions for each categorical variable
        padding_idx:int=0, # Embedding padding_idx
        embed_dropout:float=0., # Dropout probability for `Embedding` layer
        **kwargs
        ):
        super().__init__()
        if not isinstance(n_embeddings, list): n_embeddings = [n_embeddings]
        if embedding_dims is None:
            embedding_dims = [emb_sz_rule(s) for s in n_embeddings]
        if not isinstance(embedding_dims, list): embedding_dims = [embedding_dims]
        embedding_dims = [emb_sz_rule(s) if s is None else s for s in n_embeddings]
        assert len(n_embeddings) == len(embedding_dims)
        self.embedding_dims = sum(embedding_dims)
        self.embedding_layers = nn.ModuleList([nn.Sequential(nn.Embedding(n,d,padding_idx=padding_idx, **kwargs),
                                                             nn.Dropout(embed_dropout)) for n,d in zip(n_embeddings, embedding_dims)])

    def forward(self, x):
        if x.ndim == 2:
            return torch.cat([e(x[:,i].long()) for i,e in enumerate(self.embedding_layers)],1)
        elif x.ndim == 3:
            return torch.cat([e(x[:,i].long()).transpose(1,2) for i,e in enumerate(self.embedding_layers)],1)

In [None]:
t1 = torch.randint(0, 7, (16, 1))
t2 = torch.randint(0, 5, (16, 1))
t = torch.cat([t1, t2], 1).float()
emb = Embeddings([7, 5], None, embed_dropout=0.1)
test_eq(emb(t).shape, (16, 12))

In [None]:
t1 = torch.randint(0, 7, (16, 1))
t2 = torch.randint(0, 5, (16, 1))
t = torch.cat([t1, t2], 1).float()
emb = Embeddings([7, 5], [4, 3])
test_eq(emb(t).shape, (16, 12))

In [None]:
t1 = torch.randint(0, 7, (16, 1, 10))
t2 = torch.randint(0, 5, (16, 1, 10))
t = torch.cat([t1, t2], 1).float()
emb = Embeddings([7, 5], None)
test_eq(emb(t).shape, (16, 12, 10))

In [None]:
#| export
class StaticBackbone(nn.Module):
    """Static backbone model to embed static features.

    Flattens the input to (bs, c_in * seq_len) and passes it through one `LinBnDrop` block per entry in `layers`.
    `c_out` and `d` are accepted for API compatibility but are not used. The output width is stored in `head_nf`.
    """
    def __init__(self, c_in, c_out, seq_len, d=None, layers=[200, 100], dropouts=[0.1, 0.2], act=nn.ReLU(inplace=True), use_bn=False, lin_first=False):
        super().__init__()
        sizes, drops = L(layers), L(dropouts)
        # a single dropout value (or none at all) is repeated once per layer
        if len(drops) <= 1: drops = drops * len(sizes)
        assert len(sizes) == len(drops), '#layers and #dropout must match'
        self.flatten = Reshape()
        # feature widths from the flattened input through every hidden layer
        widths = [c_in * seq_len, *sizes]
        self.mlp = nn.ModuleList([
            LinBnDrop(widths[i], widths[i + 1], bn=use_bn, p=drops[i], act=get_act_fn(act), lin_first=lin_first)
            for i in range(len(sizes))
        ])
        self.head_nf = widths[-1]

    def forward(self, x):
        out = self.flatten(x)
        for block in self.mlp:
            out = block(out)
        return out

In [None]:
# Example usage: the static backbone flattens (bs, c_in, seq_len) and maps it to `head_nf` features
bs, c_in, c_out, seq_len = 4, 6, 8, 10
input_tensor = torch.randn(bs, c_in, seq_len)  # 3D input tensor
backbone = StaticBackbone(c_in, c_out, seq_len)
output_tensor = backbone(input_tensor)
print(f"Input shape: {input_tensor.shape} Output shape: {output_tensor.shape}")
backbone

Input shape: torch.Size([4, 6, 10]) Output shape: torch.Size([4, 100])


StaticBackbone(
  (flatten): Reshape(bs)
  (mlp): ModuleList(
    (0): LinBnDrop(
      (0): Dropout(p=0.1, inplace=False)
      (1): Linear(in_features=60, out_features=200, bias=True)
      (2): ReLU(inplace=True)
    )
    (1): LinBnDrop(
      (0): Dropout(p=0.2, inplace=False)
      (1): Linear(in_features=200, out_features=100, bias=True)
      (2): ReLU(inplace=True)
    )
  )
)

In [None]:

# class MultInputWrapper(nn.Module):
#     "Model wrapper for input tensors with static and/ or observed, categorical and/ or numerical features."

#     def __init__(self,
#         arch,
#         c_in:int=None, # number of input variables
#         c_out:int=None, # number of output variables
#         seq_len:int=None, # input sequence length
#         d:tuple=None, # shape of the output tensor
#         dls:TSDataLoaders=None, # TSDataLoaders object
#         s_cat_idxs:list=None, # list of indices for static categorical variables
#         s_cat_embeddings:list=None, # list of num_embeddings for each static categorical variable
#         s_cat_embedding_dims:list=None, # list of embedding dimensions for each static categorical variable
#         s_cont_idxs:list=None, # list of indices for static continuous variables
#         o_cat_idxs:list=None, # list of indices for observed categorical variables
#         o_cat_embeddings:list=None, # list of num_embeddings for each observed categorical variable
#         o_cat_embedding_dims:list=None, # list of embedding dimensions for each observed categorical variable
#         o_cont_idxs:list=None, # list of indices for observed continuous variables. All features not in s_cat_idxs, s_cont_idxs, o_cat_idxs are considered observed continuous variables.
#         patch_len:int=None, # Number of time steps in each patch.
#         patch_stride:int=None, # Stride of the patch.
#         flatten:bool=False, # boolean indicating whether to flatten bacbone's output tensor
#         use_bn:bool=False, # boolean indicating whether to use batch normalization in the head
#         fc_dropout:float=0., # dropout probability for the fully connected layer in the head
#         custom_head=None, # custom head to replace the default head
#         **kwargs
#     ):
#         super().__init__()

#         # attributes
#         c_in = c_in or dls.vars
#         c_out = c_out or dls.c
#         seq_len = seq_len or dls.len
#         d = d or (dls.d if dls is not None else None)
#         self.c_in, self.c_out, self.seq_len, self.d = c_in, c_out, seq_len, d

#         # tensor splitter
#         if o_cont_idxs is None:
#             o_cont_idxs = get_o_cont_idxs(c_in, s_cat_idxs=s_cat_idxs, s_cont_idxs=s_cont_idxs, o_cat_idxs=o_cat_idxs)
#         self.splitter = TensorSplitter(s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs)
#         s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs = self.splitter.s_cat_idxs, self.splitter.s_cont_idxs, self.splitter.o_cat_idxs, self.splitter.o_cont_idxs
#         assert c_in == sum([len(s_cat_idxs), len(s_cont_idxs), len(o_cat_idxs), len(o_cont_idxs)])

#         # embeddings
#         self.s_embeddings = Embeddings(s_cat_embeddings, s_cat_embedding_dims)
#         self.o_embeddings = Embeddings(o_cat_embeddings, o_cat_embedding_dims)

#         # patch encoder
#         if patch_len is not None:
#             patch_stride = patch_stride or patch_len
#             self.patch_encoder = PatchEncoder(patch_len, patch_stride, seq_len=seq_len)
#             c_mult = patch_len
#             seq_len = (seq_len + self.patch_encoder.pad_size - patch_len) // patch_stride + 1
#         else:
#             self.patch_encoder = nn.Identity()
#             c_mult = 1

#         # backbone
#         n_s_features = len(s_cont_idxs) + self.s_embeddings.embedding_dims
#         n_o_features = (len(o_cont_idxs) + self.o_embeddings.embedding_dims) * c_mult
#         s_backbone = StaticBackbone(c_in=n_s_features, c_out=c_out, seq_len=1, **kwargs)
#         if isinstance(arch, str):
#             arch = get_arch(arch)
#         if isinstance(arch, nn.Module):
#             o_model = arch
#         else:
#             o_model = build_ts_model(arch, c_in=n_o_features, c_out=c_out, seq_len=seq_len, d=d, **kwargs)
#         assert hasattr(o_model, "backbone"), "the selected arch must have a backbone"
#         o_backbone = getattr(o_model, "backbone")

#         # head
#         o_head_nf = output_size_calculator(o_backbone, n_o_features, seq_len)[0]
#         s_head_nf = s_backbone.head_nf
#         self.backbone = nn.ModuleList([o_backbone, s_backbone])
#         self.head_nf = o_head_nf + s_head_nf
#         if custom_head is not None:
#             if isinstance(custom_head, nn.Module): self.head = custom_head
#             else:self. head = custom_head(self.head_nf, c_out, seq_len, d=d)
#         else:
#             if "rocket" in o_model.__name__.lower():
#                 self.head = rocket_nd_head(self.head_nf, c_out, seq_len=seq_len, d=d, use_bn=use_bn, fc_dropout=fc_dropout)
#             else:
#                 self.head = lin_nd_head(self.head_nf, c_out, seq_len=seq_len, d=d, flatten=flatten, use_bn=use_bn, fc_dropout=fc_dropout)

#     def forward(self, x):
#         # split x into static cat, static cont, observed cat, and observed cont
#         s_cat, s_cont, o_cat, o_cont = self.splitter(x)

#         # create categorical embeddings
#         s_cat = self.s_embeddings(s_cat)
#         o_cat = self.o_embeddings(o_cat)

#         # contatenate static and observed features
#         s_x = torch.cat([s_cat, s_cont], 1)
#         o_x = torch.cat([o_cat, o_cont], 1)

#         # patch encoder
#         o_x = self.patch_encoder(o_x)

#         # pass static and observed features through their respective backbones
#         for i,(b,xi) in enumerate(zip(self.backbone, [o_x, s_x])):
#             if i == 0:
#                 x = b(xi)
#                 if x.ndim == 2:
#                     x = x[..., None]
#             else:
#                 x = torch.cat([x,  b(xi)[..., None].repeat(1, 1, x.shape[-1])], 1)

#         # head
#         x = self.head(x)
#         return x

In [None]:
# from tsai.models.InceptionTimePlus import InceptionTimePlus

In [None]:
# c_in = 6
# c_out = 3
# seq_len = 97
# d = None

# s_cat_idxs=2
# s_cont_idxs=4
# o_cat_idxs=[0, 3]
# o_cont_idxs=None
# s_cat_embeddings = 5
# s_cat_embedding_dims = None
# o_cat_embeddings = [7, 3]
# o_cat_embedding_dims = [3, None]

# t0 = torch.randint(0, 7, (16, 1, seq_len)) # cat
# t1 = torch.randn(16, 1, seq_len)
# t2 = torch.randint(0, 5, (16, 1, seq_len)) # cat
# t3 = torch.randint(0, 3, (16, 1, seq_len)) # cat
# t4 = torch.randn(16, 1, seq_len)
# t5 = torch.randn(16, 1, seq_len)

# t = torch.cat([t0, t1, t2, t3, t4, t5], 1).float()

# patch_lens = [None, 5, 5, 5, 5]
# patch_strides = [None, None, 1, 3, 5]
# for patch_len, patch_stride in zip(patch_lens, patch_strides):
#     for arch in ["InceptionTimePlus", InceptionTimePlus, "MultiRocketPlus"]:
#         print(f"arch: {arch}, patch_len: {patch_len}, patch_stride: {patch_stride}")

#         model = MultInputWrapper(
#             arch=arch,
#             c_in=c_in,
#             c_out=c_out,
#             seq_len=seq_len,
#             d=d,
#             s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
#             s_cont_idxs=s_cont_idxs,
#             o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims,
#             o_cont_idxs=o_cont_idxs,
#             patch_len=patch_len,
#             patch_stride=patch_stride,
#         )

#         test_eq(model(t).shape, (16,3))

In [None]:
#| export
class FusionMLP(nn.Module):
    """Concatenate `x_cat`, `x_cont` and `x_emb` along dim 1 and pass the result through an MLP.

    A 3D `x_emb` is average-pooled with `GAP1d` first so that it can be concatenated with the 2D inputs.
    Each entry in `layers` adds [BatchNorm1d] -> [Dropout] -> Linear -> [activation].
    With no layers, the MLP is an identity.
    """
    def __init__(self, comb_dim, layers, act='relu', dropout=0., use_bn=True):
        super().__init__()
        self.avg_pool = GAP1d(1)
        layers = listify(layers)
        dropouts = dropout if isinstance(dropout, list) else [dropout]
        if len(dropouts) != len(layers): dropouts = dropouts * len(layers)
        modules = []
        n_in = comb_dim
        for i, n_out in enumerate(layers):
            if use_bn: modules.append(nn.BatchNorm1d(n_in))
            if dropouts[i]: modules.append(nn.Dropout(dropouts[i]))
            modules.append(nn.Linear(n_in, n_out))
            if act: modules.append(get_act_fn(act))
            n_in = n_out
        self.mlp = nn.Sequential(*modules) if modules else nn.Identity()

    def forward(self, x_cat, x_cont, x_emb):
        pooled = self.avg_pool(x_emb) if x_emb.ndim == 3 else x_emb
        return self.mlp(torch.cat([x_cat, x_cont, pooled], 1))

In [None]:
# FusionMLP with a 3D (bs, emb_dim, seq_len) input that gets average-pooled before concatenation
bs, emb_dim, seq_len = 16, 128, 20
cat_dim, cont_feat = 24, 3

comb_dim = emb_dim + cat_dim + cont_feat
emb = torch.randn(bs, emb_dim, seq_len)
cat = torch.randn(bs, cat_dim)
cont = torch.randn(bs, cont_feat)
fusion_mlp = FusionMLP(comb_dim, layers=comb_dim, act='relu', dropout=.1)
output = fusion_mlp(cat, cont, emb)
test_eq(output.shape, (bs, comb_dim))

In [None]:
# FusionMLP with a wide 2D (bs, emb_dim) input and a single hidden layer of 128 units
bs, emb_dim = 16, 50000
cat_dim, cont_feat = 24, 3

comb_dim = emb_dim + cat_dim + cont_feat
emb = torch.randn(bs, emb_dim)
cat = torch.randn(bs, cat_dim)
cont = torch.randn(bs, cont_feat)
fusion_mlp = FusionMLP(comb_dim, layers=[128], act='relu', dropout=.1)
output = fusion_mlp(cat, cont, emb)
test_eq(output.shape, (bs, 128))

In [None]:
#| export
class MultInputBackboneWrapper(nn.Module):
    "Model backbone wrapper for input tensors with static and/ or observed, categorical and/ or numerical features."

    def __init__(self,
        arch,
        c_in:int=None, # number of input variables
        seq_len:int=None, # input sequence length
        d:tuple=None, # shape of the output tensor
        dls:TSDataLoaders=None, # TSDataLoaders object
        s_cat_idxs:list=None, # list of indices for static categorical variables
        s_cat_embeddings:list=None, # list of num_embeddings for each static categorical variable
        s_cat_embedding_dims:list=None, # list of embedding dimensions for each static categorical variable
        s_cont_idxs:list=None, # list of indices for static continuous variables
        o_cat_idxs:list=None, # list of indices for observed categorical variables
        o_cat_embeddings:list=None, # list of num_embeddings for each observed categorical variable
        o_cat_embedding_dims:list=None, # list of embedding dimensions for each observed categorical variable
        o_cont_idxs:list=None, # list of indices for observed continuous variables. All features not in s_cat_idxs, s_cont_idxs, o_cat_idxs are considered observed continuous variables.
        patch_len:int=None, # Number of time steps in each patch.
        patch_stride:int=None, # Stride of the patch.
        fusion_layers:list=[128], # list of layer dimensions for the fusion MLP
        fusion_act:str='relu', # activation function for the fusion MLP
        fusion_dropout:float=0., # dropout probability for the fusion MLP
        fusion_use_bn:bool=True, # boolean indicating whether to use batch normalization in the fusion MLP
        **kwargs
    ):
        super().__init__()

        # attributes (fall back to `dls` when sizes are not passed explicitly)
        c_in = c_in or dls.vars
        seq_len = seq_len or dls.len
        d = d or (dls.d if dls is not None else None)
        self.c_in, self.seq_len, self.d = c_in, seq_len, d

        # tensor splitter: any feature not assigned to another group is treated as observed continuous
        if o_cont_idxs is None:
            o_cont_idxs = get_o_cont_idxs(c_in, s_cat_idxs=s_cat_idxs, s_cont_idxs=s_cont_idxs, o_cat_idxs=o_cat_idxs)
        self.splitter = TensorSplitter(s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs)
        s_cat_idxs, s_cont_idxs, o_cat_idxs, o_cont_idxs = self.splitter.s_cat_idxs, self.splitter.s_cont_idxs, self.splitter.o_cat_idxs, self.splitter.o_cont_idxs
        assert c_in == sum([len(s_cat_idxs), len(s_cont_idxs), len(o_cat_idxs), len(o_cont_idxs)]), \
            "c_in must equal the total number of s_cat, s_cont, o_cat and o_cont indices"

        # embeddings (identity when a group has no categorical variables)
        self.s_embeddings = self._make_embeddings(s_cat_idxs, s_cat_embeddings, s_cat_embedding_dims, "s_cat")
        self.o_embeddings = self._make_embeddings(o_cat_idxs, o_cat_embeddings, o_cat_embedding_dims, "o_cat")

        # patch encoder: each patch of `patch_len` steps becomes extra channels, shortening the sequence
        if patch_len is not None:
            patch_stride = patch_stride or patch_len
            self.patch_encoder = PatchEncoder(patch_len, patch_stride, seq_len=seq_len)
            c_mult = patch_len
            seq_len = (seq_len + self.patch_encoder.pad_size - patch_len) // patch_stride + 1
        else:
            self.patch_encoder = nn.Identity()
            c_mult = 1

        # backbone: built for the observed features only; static features go straight to the fusion MLP
        n_s_features = len(s_cont_idxs) + (self.s_embeddings.embedding_dims if s_cat_idxs else 0)
        n_o_features = (len(o_cont_idxs) + (self.o_embeddings.embedding_dims if o_cat_idxs else 0)) * c_mult
        if isinstance(arch, str):
            arch = get_arch(arch)
        if isinstance(arch, nn.Module):
            o_model = arch
        else:
            o_model = build_ts_model(arch, c_in=n_o_features, c_out=1, seq_len=seq_len, d=d, **kwargs)
        assert hasattr(o_model, "backbone"), "the selected arch must have a backbone"
        o_backbone = getattr(o_model, "backbone")
        self.o_backbone = o_backbone
        backbone_features = output_size_calculator(o_backbone, n_o_features, seq_len)[0]

        # fusion layer: combines static features with the observed-backbone output
        fusion_layers = listify(fusion_layers)
        self.fusion_layer = FusionMLP(n_s_features + backbone_features, layers=fusion_layers, act=fusion_act, dropout=fusion_dropout, use_bn=fusion_use_bn)
        self.head_nf = fusion_layers[-1]

    @staticmethod
    def _make_embeddings(idxs, n_embeddings, embedding_dims, name):
        "Build an `Embeddings` layer with one table per index in `idxs`, or an identity when `idxs` is empty."
        if not idxs:
            return nn.Identity()
        n_emb = n_embeddings if isinstance(n_embeddings, list) else [n_embeddings]
        # without this check, extra categorical variables would be silently ignored by `Embeddings.forward`
        assert n_embeddings is not None and len(n_emb) == len(idxs), \
            f"{name}_embeddings must provide one num_embeddings value per index in {name}_idxs"
        return Embeddings(n_embeddings, embedding_dims)

    def forward(self, x):
        # split x into static cat, static cont, observed cat, and observed cont
        s_cat, s_cont, o_cat, o_cont = self.splitter(x)

        # create categorical embeddings (identity for groups without categorical variables)
        s_cat = self.s_embeddings(s_cat)
        o_cat = self.o_embeddings(o_cat)

        # concatenate observed features along the channel dimension
        o_x = torch.cat([o_cat, o_cont], 1)

        # patch encoder
        o_x = self.patch_encoder(o_x)

        # only the observed features pass through the time-series backbone
        o_x = self.o_backbone(o_x)

        # static features are fused directly with the backbone output
        return self.fusion_layer(s_cat, s_cont, o_x)

In [None]:
#| export
class MultInputWrapper(nn.Sequential):
    """Model for inputs with static and/or observed, categorical and/or continuous features.

    A `MultInputBackboneWrapper` (named `backbone`) followed by a head (named `head`). The default head is
    `nn.Linear(head_nf, c_out)`. `custom_head` may be an `nn.Module` (used as is) or a callable
    `custom_head(head_nf, c_out, seq_len, d=d)`.
    """
    def __init__(self,
        arch,
        c_in:int=None, # number of input variables
        c_out:int=1, # number of output variables
        seq_len:int=None, # input sequence length
        d:tuple=None, # shape of the output tensor
        dls:TSDataLoaders=None, # TSDataLoaders object
        s_cat_idxs:list=None, # list of indices for static categorical variables
        s_cat_embeddings:list=None, # list of num_embeddings for each static categorical variable
        s_cat_embedding_dims:list=None, # list of embedding dimensions for each static categorical variable
        s_cont_idxs:list=None, # list of indices for static continuous variables
        o_cat_idxs:list=None, # list of indices for observed categorical variables
        o_cat_embeddings:list=None, # list of num_embeddings for each observed categorical variable
        o_cat_embedding_dims:list=None, # list of embedding dimensions for each observed categorical variable
        o_cont_idxs:list=None, # list of indices for observed continuous variables. All features not in s_cat_idxs, s_cont_idxs, o_cat_idxs are considered observed continuous variables.
        patch_len:int=None, # Number of time steps in each patch.
        patch_stride:int=None, # Stride of the patch.
        fusion_layers:list=128, # list of layer dimensions for the fusion MLP
        fusion_act:str='relu', # activation function for the fusion MLP
        fusion_dropout:float=0., # dropout probability for the fusion MLP
        fusion_use_bn:bool=True, # boolean indicating whether to use batch normalization in the fusion MLP
        custom_head=None, # custom head to replace the default head
        **kwargs
    ):

        # create backbone
        backbone = MultInputBackboneWrapper(arch, c_in=c_in, seq_len=seq_len, d=d, dls=dls, s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
                                            s_cont_idxs=s_cont_idxs, o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims, o_cont_idxs=o_cont_idxs,
                                            patch_len=patch_len, patch_stride=patch_stride, fusion_layers=fusion_layers, fusion_act=fusion_act, fusion_dropout=fusion_dropout, fusion_use_bn=fusion_use_bn, **kwargs)

        # the backbone resolves seq_len/d from `dls` when they were not passed explicitly
        head_nf = backbone.head_nf
        seq_len, d = backbone.seq_len, backbone.d

        # create head (any nn.Module is used as is, even one that is falsy such as an empty nn.Sequential)
        if isinstance(custom_head, nn.Module):
            head = custom_head
        elif custom_head:
            head = custom_head(head_nf, c_out, seq_len, d=d)
        else:
            head = nn.Linear(head_nf, c_out)
        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))

        # plain attributes are set only after nn.Module.__init__ has run
        self.head_nf = head_nf
        self.c_out = c_out
        self.seq_len = seq_len


In [None]:
from tsai.models.InceptionTimePlus import InceptionTimePlus

In [None]:
bs = 8
c_in = 6
c_out = 3
seq_len = 97
d = None

# Channel layout for this test:
#   0: observed categorical (7 levels)   1: observed continuous
#   2: static categorical (5 levels)     3: observed categorical (3 levels)
#   4: static continuous                 5: observed continuous
s_cat_idxs=2
s_cont_idxs=4
o_cat_idxs=[0, 3]
o_cont_idxs=None # None -> every channel not listed above (1 and 5) is observed continuous
s_cat_embeddings = 5
s_cat_embedding_dims = None
o_cat_embeddings = [7, 3]
o_cat_embedding_dims = [3, None]

fusion_layers = 128

# Seed so the random inputs and the model weight initialisation are reproducible.
torch.manual_seed(0)

t0 = torch.randint(0, 7, (bs, 1, seq_len)) # observed cat (7 levels)
t1 = torch.randn(bs, 1, seq_len)           # observed cont
t2 = torch.randint(0, 5, (bs, 1, seq_len)) # static cat (5 levels)
t3 = torch.randint(0, 3, (bs, 1, seq_len)) # observed cat (3 levels)
t4 = torch.randn(bs, 1, seq_len)           # static cont
t5 = torch.randn(bs, 1, seq_len)           # observed cont

t = torch.cat([t0, t1, t2, t3, t4, t5], 1).float().to(default_device())

patch_lens = [None, 5, 5, 5, 5]
patch_strides = [None, None, 1, 3, 5]
for patch_len, patch_stride in zip(patch_lens, patch_strides):
    for arch in ["InceptionTimePlus", InceptionTimePlus, "TSiTPlus"]:
        print(f"arch: {arch}, patch_len: {patch_len}, patch_stride: {patch_stride}")

        model = MultInputWrapper(
            arch=arch,
            c_in=c_in,
            c_out=c_out,
            seq_len=seq_len,
            d=d,
            s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
            s_cont_idxs=s_cont_idxs,
            o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims,
            o_cont_idxs=o_cont_idxs,
            patch_len=patch_len,
            patch_stride=patch_stride,
            fusion_layers=fusion_layers,
        ).to(default_device())

        # Shape-only check: the forward pass does not need an autograd graph.
        with torch.no_grad():
            test_eq(model(t).shape, (bs, c_out))

arch: InceptionTimePlus, patch_len: None, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: None, patch_stride: None
arch: TSiTPlus, patch_len: None, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: None
arch: TSiTPlus, patch_len: 5, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: 1
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 1
arch: TSiTPlus, patch_len: 5, patch_stride: 1
arch: InceptionTimePlus, patch_len: 5, patch_stride: 3
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 3
arch: TSiTPlus, patch_len: 5, patch_stride: 3
arch: InceptionTimePlus, patch_len: 5, patch_stride: 5
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 5
arch: TSiTPlus, patch_len: 5, patch_str

In [None]:
bs = 8
c_in = 6
c_out = 3
seq_len = 97
d = None

# Channel layout for this test (no static categorical features):
#   0: observed categorical (7 levels)   1: observed continuous
#   2: observed continuous (integer-valued, but not declared categorical)
#   3: observed categorical (3 levels)   4: static continuous
#   5: observed continuous
s_cat_idxs=None
s_cont_idxs=4
o_cat_idxs=[0, 3]
o_cont_idxs=None # None -> every channel not listed above (1, 2 and 5) is observed continuous
s_cat_embeddings = None
s_cat_embedding_dims = None
o_cat_embeddings = [7, 3]
o_cat_embedding_dims = [3, None]

fusion_layers = 128

# Seed so the random inputs and the model weight initialisation are reproducible.
torch.manual_seed(0)

t0 = torch.randint(0, 7, (bs, 1, seq_len)) # observed cat (7 levels)
t1 = torch.randn(bs, 1, seq_len)           # observed cont
t2 = torch.randint(0, 5, (bs, 1, seq_len)) # observed cont here (s_cat_idxs is None)
t3 = torch.randint(0, 3, (bs, 1, seq_len)) # observed cat (3 levels)
t4 = torch.randn(bs, 1, seq_len)           # static cont
t5 = torch.randn(bs, 1, seq_len)           # observed cont

t = torch.cat([t0, t1, t2, t3, t4, t5], 1).float().to(default_device())

patch_lens = [None, 5, 5, 5, 5]
patch_strides = [None, None, 1, 3, 5]
for patch_len, patch_stride in zip(patch_lens, patch_strides):
    for arch in ["InceptionTimePlus", InceptionTimePlus, "TSiTPlus"]:
        print(f"arch: {arch}, patch_len: {patch_len}, patch_stride: {patch_stride}")

        model = MultInputWrapper(
            arch=arch,
            c_in=c_in,
            c_out=c_out,
            seq_len=seq_len,
            d=d,
            s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
            s_cont_idxs=s_cont_idxs,
            o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims,
            o_cont_idxs=o_cont_idxs,
            patch_len=patch_len,
            patch_stride=patch_stride,
            fusion_layers=fusion_layers,
        ).to(default_device())

        # Shape-only check: the forward pass does not need an autograd graph.
        with torch.no_grad():
            test_eq(model(t).shape, (bs, c_out))

arch: InceptionTimePlus, patch_len: None, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: None, patch_stride: None
arch: TSiTPlus, patch_len: None, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: None
arch: TSiTPlus, patch_len: 5, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: 1
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 1
arch: TSiTPlus, patch_len: 5, patch_stride: 1
arch: InceptionTimePlus, patch_len: 5, patch_stride: 3
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 3
arch: TSiTPlus, patch_len: 5, patch_stride: 3
arch: InceptionTimePlus, patch_len: 5, patch_stride: 5
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 5
arch: TSiTPlus, patch_len: 5, patch_str

In [None]:
bs = 8
c_in = 6
c_out = 3
seq_len = 97
d = None

# Channel layout for this test (no observed categorical features):
#   0: observed continuous (integer-valued, but not declared categorical)
#   1: observed continuous
#   2: static categorical (5 levels)
#   3: observed continuous (integer-valued, but not declared categorical)
#   4: static continuous                 5: observed continuous
s_cat_idxs=2
s_cont_idxs=4
o_cat_idxs=None
o_cont_idxs=None # None -> every channel not listed above (0, 1, 3 and 5) is observed continuous
s_cat_embeddings = 5
s_cat_embedding_dims = None
o_cat_embeddings = None
o_cat_embedding_dims = None

fusion_layers = 128

# Seed so the random inputs and the model weight initialisation are reproducible.
torch.manual_seed(0)

t0 = torch.randint(0, 7, (bs, 1, seq_len)) # observed cont here (o_cat_idxs is None)
t1 = torch.randn(bs, 1, seq_len)           # observed cont
t2 = torch.randint(0, 5, (bs, 1, seq_len)) # static cat (5 levels)
t3 = torch.randint(0, 3, (bs, 1, seq_len)) # observed cont here (o_cat_idxs is None)
t4 = torch.randn(bs, 1, seq_len)           # static cont
t5 = torch.randn(bs, 1, seq_len)           # observed cont

t = torch.cat([t0, t1, t2, t3, t4, t5], 1).float().to(default_device())

patch_lens = [None, 5, 5, 5, 5]
patch_strides = [None, None, 1, 3, 5]
for patch_len, patch_stride in zip(patch_lens, patch_strides):
    for arch in ["InceptionTimePlus", InceptionTimePlus, "TSiTPlus"]:
        print(f"arch: {arch}, patch_len: {patch_len}, patch_stride: {patch_stride}")

        model = MultInputWrapper(
            arch=arch,
            c_in=c_in,
            c_out=c_out,
            seq_len=seq_len,
            d=d,
            s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
            s_cont_idxs=s_cont_idxs,
            o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims,
            o_cont_idxs=o_cont_idxs,
            patch_len=patch_len,
            patch_stride=patch_stride,
            fusion_layers=fusion_layers,
        ).to(default_device())

        # Shape-only check: the forward pass does not need an autograd graph.
        with torch.no_grad():
            test_eq(model(t).shape, (bs, c_out))

arch: InceptionTimePlus, patch_len: None, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: None, patch_stride: None
arch: TSiTPlus, patch_len: None, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: None
arch: TSiTPlus, patch_len: 5, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: 1
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 1
arch: TSiTPlus, patch_len: 5, patch_stride: 1
arch: InceptionTimePlus, patch_len: 5, patch_stride: 3
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 3
arch: TSiTPlus, patch_len: 5, patch_stride: 3
arch: InceptionTimePlus, patch_len: 5, patch_stride: 5
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 5
arch: TSiTPlus, patch_len: 5, patch_str

In [None]:
bs = 8
c_in = 6
c_out = 3
seq_len = 97
d = None

# Channel layout for this test: no index lists are given, so all six channels
# (including the integer-valued ones) are treated as observed continuous.
s_cat_idxs=None
s_cont_idxs=None
o_cat_idxs=None
o_cont_idxs=None # None -> every channel (0-5) is observed continuous
s_cat_embeddings = None
s_cat_embedding_dims = None
o_cat_embeddings = None
o_cat_embedding_dims = None

fusion_layers = 128

# Seed so the random inputs and the model weight initialisation are reproducible.
torch.manual_seed(0)

t0 = torch.randint(0, 7, (bs, 1, seq_len)) # observed cont here (integer-valued)
t1 = torch.randn(bs, 1, seq_len)           # observed cont
t2 = torch.randint(0, 5, (bs, 1, seq_len)) # observed cont here (integer-valued)
t3 = torch.randint(0, 3, (bs, 1, seq_len)) # observed cont here (integer-valued)
t4 = torch.randn(bs, 1, seq_len)           # observed cont
t5 = torch.randn(bs, 1, seq_len)           # observed cont

t = torch.cat([t0, t1, t2, t3, t4, t5], 1).float().to(default_device())

patch_lens = [None, 5, 5, 5, 5]
patch_strides = [None, None, 1, 3, 5]
for patch_len, patch_stride in zip(patch_lens, patch_strides):
    for arch in ["InceptionTimePlus", InceptionTimePlus, "TSiTPlus"]:
        print(f"arch: {arch}, patch_len: {patch_len}, patch_stride: {patch_stride}")

        model = MultInputWrapper(
            arch=arch,
            c_in=c_in,
            c_out=c_out,
            seq_len=seq_len,
            d=d,
            s_cat_idxs=s_cat_idxs, s_cat_embeddings=s_cat_embeddings, s_cat_embedding_dims=s_cat_embedding_dims,
            s_cont_idxs=s_cont_idxs,
            o_cat_idxs=o_cat_idxs, o_cat_embeddings=o_cat_embeddings, o_cat_embedding_dims=o_cat_embedding_dims,
            o_cont_idxs=o_cont_idxs,
            patch_len=patch_len,
            patch_stride=patch_stride,
            fusion_layers=fusion_layers,
        ).to(default_device())

        # Shape-only check: the forward pass does not need an autograd graph.
        with torch.no_grad():
            test_eq(model(t).shape, (bs, c_out))

arch: InceptionTimePlus, patch_len: None, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: None, patch_stride: None
arch: TSiTPlus, patch_len: None, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: None
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: None
arch: TSiTPlus, patch_len: 5, patch_stride: None
arch: InceptionTimePlus, patch_len: 5, patch_stride: 1
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 1
arch: TSiTPlus, patch_len: 5, patch_stride: 1
arch: InceptionTimePlus, patch_len: 5, patch_stride: 3
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 3
arch: TSiTPlus, patch_len: 5, patch_stride: 3
arch: InceptionTimePlus, patch_len: 5, patch_stride: 5
arch: <class 'tsai.models.InceptionTimePlus.InceptionTimePlus'>, patch_len: 5, patch_stride: 5
arch: TSiTPlus, patch_len: 5, patch_str

In [None]:
#|eval: false
#|hide
# Housekeeping cell, not part of the documented API: it resolves this
# notebook's file name from the current namespace, then regenerates the
# library script(s) from it. The directives above (kept as the first lines so
# nbdev can parse them) presumably exclude it from evaluation and from the docs.
from tsai.export import get_nb_name; nb_name = get_nb_name(locals())
from tsai.imports import create_scripts; create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/077_models.multimodal.ipynb saved at 2024-02-10 21:58:47
Correct notebook to script conversion! 😃
Saturday 10/02/24 21:58:50 CET
