In [None]:
#|default_exp models.MultiRocketPlus

# MultiRocketPlus

>MultiRocket: Multiple pooling operators and transformations for fast and effective time series classification.

This is a PyTorch implementation of MultiRocket, developed by Malcolm McLean and Ignacio Oguiza, based on:

Tan, C. W., Dempster, A., Bergmeir, C., & Webb, G. I. (2022). MultiRocket: multiple pooling operators and transformations for fast and effective time series classification. Data Mining and Knowledge Discovery, 36(5), 1623-1646.

Original paper: https://link.springer.com/article/10.1007/s10618-022-00844-1

Original repository: https://github.com/ChangWeiTan/MultiRocket

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from collections import OrderedDict
import itertools
from tsai.models.layers import rocket_nd_head

In [None]:
#| export
class Flatten(nn.Module):
    "Flatten every dimension after the first one: (bs, d1, d2, ...) -> (bs, d1 * d2 * ...)"
    def forward(self, x):
        # `reshape` returns a view whenever `view` would have worked, and only falls back to a copy
        # for non-contiguous inputs (e.g. after `transpose`/`permute`), where `view` raises.
        return x.reshape(x.size(0), -1)

In [None]:
#| export
def _LPVV(o, dim=2):
    "Longest stretch of positive values along a dimension(-1, 1)"

    seq_len = o.shape[dim]
    binary_tensor = (o > 0).float()

    diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),
                      binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)

    groups = (diff > 0).cumsum(dim)

    # Ensure groups are within valid index bounds
    groups = groups * binary_tensor.long()
    valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))

    counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)

    longest_stretch = counts.max(dim)[0]

    return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)

def _MPV(o, dim=2):
    "Mean of Positive Values (any positive value)"
    o = torch.where(o > 0, o, torch.nan)
    o = torch.nanmean(o, dim)
    return torch.nan_to_num(o)

def _RSPV(o, dim=2):
    "Relative Sum of Positive Values (-1, 1)"
    o_sum = torch.clamp_min(torch.abs(o).sum(dim), 1e-8)
    o_pos_sum = torch.nansum(F.relu(o), dim)
    return (o_pos_sum / o_sum) * 2 - 1

def _MIPV(o, o_pos, dim=2):
    "Mean of Indices of Positive Values (-1, 1)"
    seq_len = o.shape[dim]
    o_arange_shape = [1] * o_pos.ndim
    o_arange_shape[dim] = -1
    o_arange = torch.arange(o_pos.shape[dim], device=o.device).reshape(o_arange_shape)
    o = torch.where(o_pos, o_arange, torch.nan)
    o = torch.nanmean(o, dim)
    return (torch.nan_to_num(o) / seq_len) * 2 - 1

def _PPV(o_pos, dim=2):
    "Proportion of Positive Values (-1, 1)"
    return (o_pos).float().mean(dim) * 2 - 1

In [None]:
from tsai.imports import default_device

In [None]:
# Random values in [-0.3, 0.7) so that each slice mixes positive and non-positive entries
o = torch.rand(2, 3, 5, 4).to(default_device()) - .3
print(o)

output = _LPVV(o, dim=2)
assert output.shape == (2, 3, 4)  # `dim=2` (length 5) is reduced away
print(output)

tensor([[[[ 0.5644, -0.0509, -0.0390,  0.4091],
          [ 0.0517, -0.1471,  0.6458,  0.5593],
          [ 0.4516, -0.0821,  0.1271,  0.0592],
          [ 0.4151,  0.4376,  0.0763,  0.3780],
          [ 0.2653, -0.1817,  0.0156,  0.4993]],

         [[-0.0779,  0.0858,  0.1982,  0.3224],
          [ 0.1130,  0.0714, -0.1779,  0.5360],
          [-0.1848, -0.2270, -0.0925, -0.1217],
          [ 0.2820, -0.0205, -0.2777,  0.3755],
          [-0.2490,  0.2613,  0.4237,  0.4534]],

         [[-0.0162,  0.6368,  0.0016,  0.1467],
          [ 0.6035, -0.1365,  0.6930,  0.6943],
          [ 0.2790,  0.3818, -0.0731,  0.0167],
          [ 0.6442,  0.3443,  0.4829, -0.0944],
          [ 0.2932,  0.6952,  0.5541,  0.5946]]],


        [[[ 0.6757,  0.5740,  0.3071,  0.4400],
          [-0.2344, -0.1056,  0.4773,  0.2432],
          [ 0.2595, -0.1528, -0.0866,  0.6201],
          [ 0.0657,  0.1220,  0.4849,  0.4254],
          [ 0.3399, -0.1609,  0.3465,  0.2389]],

         [[-0.0765,  0.0516,  

In [None]:
output = _MPV(o, dim=2)
assert output.shape == (2, 3, 4)  # `dim=2` is reduced away
print(output)

tensor([[[0.3496, 0.4376, 0.2162, 0.3810],
         [0.1975, 0.1395, 0.3109, 0.4218],
         [0.4550, 0.5145, 0.4329, 0.3631]],

        [[0.3352, 0.3480, 0.4040, 0.3935],
         [0.5023, 0.3078, 0.3968, 0.5221],
         [0.3679, 0.3380, 0.2460, 0.4079]]], device='mps:0')


In [None]:
output = _RSPV(o, dim=2)
assert output.shape == (2, 3, 4)  # `dim=2` is reduced away
print(output)

tensor([[[ 1.0000, -0.0270,  0.9138,  1.0000],
         [-0.1286,  0.2568,  0.0630,  0.8654],
         [ 0.9823,  0.8756,  0.9190,  0.8779]],

        [[ 0.7024,  0.2482,  0.8983,  1.0000],
         [ 0.6168,  0.2392,  0.8931,  0.9715],
         [ 0.5517,  0.8133,  0.7065,  0.8244]]], device='mps:0')


In [None]:
# `_PPV` averages a boolean mask of the positive positions; feeding it the raw values instead
# just rescales their mean and can leave the [-1, 1] range.
output = _PPV(o > 0, dim=2)
assert output.shape == (2, 3, 4)  # `dim=2` is reduced away
assert output.min() >= -1 and output.max() <= 1
print(output)

tensor([[[-0.3007, -1.0097, -0.6697, -0.2381],
         [-1.0466, -0.9316, -0.9705, -0.3738],
         [-0.2786, -0.2314, -0.3366, -0.4569]],

        [[-0.5574, -0.8893, -0.3883, -0.2130],
         [-0.5401, -0.8574, -0.4009, -0.1767],
         [-0.6861, -0.5149, -0.7555, -0.4102]]], device='mps:0')


In [None]:
#|export
class MultiRocketFeaturesPlus(nn.Module):
    """Fixed (non-trainable) dilated convolutions followed by 4 pooling operators (PPV, MPV, MIPV, LSPV).

    Every kernel has `kernel_size` weights: `ceil(kernel_size / 3)` positive ones and -1 elsewhere, scaled so
    that the weights sum to zero. Each kernel is applied with several dilations; each (kernel, dilation) output
    is compared against a set of biases (quantiles of the convolution output, estimated from data) and pooled
    along time. The output of `forward` has `4 * self.num_features` columns.

    Biases are data dependent: they are estimated either by an explicit call to `fit` or lazily on the first
    regular `forward` call, and are then kept in buffers.

    Args:
        c_in: number of input channels; the input to `forward` is convolved as (batch, c_in, length).
        seq_len: length of the series, used to derive the dilations.
        num_features: requested number of output features. It is divided by 4 (one share per pooling operator)
            and rounded down to a multiple of the number of kernels.
        max_dilations_per_kernel: upper bound on the number of dilations per kernel.
        kernel_size: number of weights per kernel.
        max_num_channels: upper bound on the number of channels combined per kernel (multivariate input only).
            A falsy value means "up to all channels".
        max_num_kernels: upper bound on the number of kernels.
        diff: not used by this class; kept for backward compatibility of the signature.
    """
    # True only while `fit` is running: `forward` then estimates the biases and skips feature extraction.
    fitting = False

    def __init__(self, c_in, seq_len, num_features=10_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=9, max_num_kernels=84, diff=False):
        super(MultiRocketFeaturesPlus, self).__init__()

        self.c_in, self.seq_len = c_in, seq_len
        self.kernel_size, self.max_num_channels = kernel_size, max_num_channels

        # Kernels: start from all -1 and write the positive value at the selected positions
        indices, pos_values = self.get_indices(kernel_size, max_num_kernels)
        self.num_kernels = len(indices)
        kernels = (-torch.ones(self.num_kernels, 1, self.kernel_size)).scatter_(2, indices, pos_values)
        self.indices = indices
        # Same kernel bank repeated for every input channel (used with `groups=c_in` in `forward`)
        self.kernels = nn.Parameter(kernels.repeat(c_in, 1, 1), requires_grad=False)
        # 4 pooling operators share the feature budget; keep a whole number of features per kernel
        num_features = num_features // 4
        self.num_features = num_features // self.num_kernels * self.num_kernels
        self.max_dilations_per_kernel = max_dilations_per_kernel

        # Dilations
        self.set_dilations(seq_len)

        # Channel combinations (multivariate)
        if c_in > 1:
            self.set_channel_combinations(c_in, max_num_channels)

        # Bias: one uninitialised buffer per dilation, filled in by `forward`/`fit`
        for i in range(self.num_dilations):
            self.register_buffer(f'biases_{i}', torch.empty(
                (self.num_kernels, self.num_features_per_dilation[i])))
        # Flag stored as a buffer so that it is saved/restored together with the biases
        self.register_buffer('prefit', torch.BoolTensor([False]))

    def forward(self, x):
        """Compute the pooled features of `x`.

        Returns a tensor with `4 * self.num_features` columns, or `None` when called from `fit`
        (in that mode only the biases are estimated).
        """

        _features = []
        for i, (dilation, padding) in enumerate(zip(self.dilations, self.padding)):
            # Alternates between 0 and 1 so that consecutive dilations swap which half of the kernels
            # is pooled over the padded output and which half over the un-padded one
            _padding1 = i % 2

            # Convolution
            C = F.conv1d(x, self.kernels, padding=padding,
                         dilation=dilation, groups=self.c_in)
            if self.c_in > 1:  # multivariate
                # Keep only the channels selected for each kernel and add them up
                C = C.reshape(x.shape[0], self.c_in, self.num_kernels, -1)
                channel_combination = getattr(
                    self, f'channel_combinations_{i}')
                C = torch.mul(C, channel_combination)
                C = C.sum(1)

            # Bias
            if not self.prefit or self.fitting:
                num_features_this_dilation = self.num_features_per_dilation[i]
                bias_this_dilation = self.get_bias(
                    C, num_features_this_dilation)
                setattr(self, f'biases_{i}', bias_this_dilation)
                if self.fitting:
                    if i < self.num_dilations - 1:
                        continue
                    else:
                        # In-place update keeps the flag on the same device as the rest of the module
                        self.prefit.fill_(True)
                        return
                elif i == self.num_dilations - 1:
                    self.prefit.fill_(True)
            else:
                bias_this_dilation = getattr(self, f'biases_{i}')

            # Features: half of the kernels see the padded output, the other half only the un-padded part
            _features.append(self.apply_pooling_ops(
                C[:, _padding1::2], bias_this_dilation[_padding1::2]))
            _features.append(self.apply_pooling_ops(
                C[:, 1-_padding1::2, padding:-padding], bias_this_dilation[1-_padding1::2]))

        return torch.cat(_features, dim=1)

    def fit(self, X, chunksize=None):
        """Estimate the biases from a random subset of `X` (numpy array or tensor, samples on the first axis).

        `chunksize` is the number of samples drawn (without replacement); it defaults to
        `num_dilations * num_kernels` and is capped at the number of samples.
        """
        num_samples = X.shape[0]
        if chunksize is None:
            chunksize = min(num_samples, self.num_dilations * self.num_kernels)
        else:
            chunksize = min(num_samples, chunksize)
        idxs = np.random.choice(num_samples, chunksize, False)
        self.fitting = True
        try:
            if isinstance(X, np.ndarray):
                self(torch.from_numpy(X[idxs]).to(self.kernels.device))
            else:
                self(X[idxs].to(self.kernels.device))
        finally:
            # Always leave fitting mode, otherwise a failed fit would make later `forward` calls return None
            self.fitting = False

    def apply_pooling_ops(self, C, bias):
        """Pool the convolution output `C` against every bias and concatenate PPV, MPV, MIPV and LSPV.

        `C` gets a trailing singleton axis and `bias` is viewed as (1, kernels, 1, biases) so that each kernel
        output is compared with each of its biases; pooling runs along axis 2.
        """
        C = C.unsqueeze(-1)
        bias = bias.view(1, bias.shape[0], 1, bias.shape[1])
        pos_vals = (C > bias)
        ppv = _PPV(pos_vals).flatten(1)
        mpv = _MPV(C - bias).flatten(1)
        mipv = _MIPV(C, pos_vals).flatten(1)
        lspv = _LPVV(pos_vals).flatten(1)
        # NOTE: `_RSPV` is available as an alternative to MPV but is not part of the feature set.
        return torch.cat((ppv, mpv, mipv, lspv), dim=1)

    def set_dilations(self, input_length):
        """Choose the dilations, how many features each one contributes per kernel, and the matching paddings."""
        num_features_per_kernel = self.num_features // self.num_kernels
        true_max_dilations_per_kernel = min(
            num_features_per_kernel, self.max_dilations_per_kernel)
        multiplier = num_features_per_kernel / true_max_dilations_per_kernel
        # Largest exponent such that the dilated kernel still fits in the series
        max_exponent = np.log2((input_length - 1) / (self.kernel_size - 1))
        # Log-spaced candidates truncated to integers; duplicates are merged and their counts become weights
        dilations, num_features_per_dilation = \
            np.unique(np.logspace(0, max_exponent, true_max_dilations_per_kernel, base=2).astype(
                np.int32), return_counts=True)
        num_features_per_dilation = (
            num_features_per_dilation * multiplier).astype(np.int32)
        # Hand out what truncation left over, one feature at a time, cycling over the dilations
        remainder = num_features_per_kernel - num_features_per_dilation.sum()
        i = 0
        while remainder > 0:
            num_features_per_dilation[i] += 1
            remainder -= 1
            i = (i + 1) % len(num_features_per_dilation)
        self.num_features_per_dilation = num_features_per_dilation
        self.num_dilations = len(dilations)
        self.dilations = dilations
        # Half of the dilated kernel span on each side
        self.padding = [((self.kernel_size - 1) * dilation) // 2 for dilation in dilations]

    def set_channel_combinations(self, num_channels, max_num_channels):
        """Draw, for every (dilation, kernel) pair, the random subset of channels that is summed in `forward`.

        The 0/1 masks are registered as buffers `channel_combinations_{i}`, one per dilation, each of shape
        (1, num_channels, num_kernels, 1).
        """
        num_combinations = self.num_kernels * self.num_dilations
        if max_num_channels:
            max_num_channels = min(num_channels, max_num_channels)
        else:
            max_num_channels = num_channels
        # Subset sizes are 2 ** U(0, log2(max + 1)) truncated to int, i.e. small subsets are more frequent
        max_exponent_channels = np.log2(max_num_channels + 1)
        num_channels_per_combination = (
            2 ** np.random.uniform(0, max_exponent_channels, num_combinations)).astype(np.int32)
        self.num_channels_per_combination = num_channels_per_combination
        channel_combinations = torch.zeros(
            (1, num_channels, num_combinations, 1))
        for i in range(num_combinations):
            channel_combinations[:, np.random.choice(
                num_channels, num_channels_per_combination[i], False), i] = 1
        channel_combinations = torch.split(
            channel_combinations, self.num_kernels, 2)  # split by dilation
        for i, channel_combination in enumerate(channel_combinations):
            self.register_buffer(
                f'channel_combinations_{i}', channel_combination)  # per dilation

    def get_quantiles(self, n):
        """Return `n` quantile levels in [0, 1): the fractional parts of k * golden ratio, k = 1..n."""
        golden_ratio = (np.sqrt(5) + 1) / 2
        return torch.tensor([(k * golden_ratio) % 1 for k in range(1, n + 1)]).float()

    def get_bias(self, C, num_features_this_dilation):
        """Compute the biases of one dilation as quantiles of the convolution output.

        For every kernel a random sample of the batch is picked and the quantiles of that kernel's output on
        that sample are taken. Returns a tensor of shape (num_kernels, num_features_this_dilation).
        """
        # One random sample index per kernel
        isp = torch.randint(C.shape[0], (self.num_kernels,))
        # C[isp] is (kernels, kernels, length); its diagonal pairs kernel k with its own sample isp[k]
        samples = C[isp].diagonal().T
        biases = torch.quantile(samples, self.get_quantiles(
            num_features_this_dilation).to(C.device), dim=1).T
        return biases

    def get_indices(self, kernel_size, max_num_kernels):
        """Return the positions of the positive weights of every kernel and the positive weight value.

        For `kernel_size <= 9` every combination of positions is enumerated (and randomly subsampled down to
        `max_num_kernels` if there are more); for larger kernels `max_num_kernels` random combinations are drawn.
        The positive value is `num_neg / num_pos`, which makes each kernel sum to zero given the -1 elsewhere.
        """
        num_pos_values = math.ceil(kernel_size / 3)
        num_neg_values = kernel_size - num_pos_values
        pos_values = num_neg_values / num_pos_values
        if kernel_size > 9:
            random_kernels = [np.sort(np.random.choice(kernel_size, num_pos_values, False)).reshape(
                1, -1) for _ in range(max_num_kernels)]
            indices = torch.from_numpy(
                np.concatenate(random_kernels, 0)).unsqueeze(1)
        else:
            indices = torch.LongTensor(list(itertools.combinations(
                np.arange(kernel_size), num_pos_values))).unsqueeze(1)
            if max_num_kernels and len(indices) > max_num_kernels:
                indices = indices[np.sort(np.random.choice(
                    len(indices), max_num_kernels, False))]
        return indices, pos_values

In [None]:
#| export
class MultiRocketBackbonePlus(nn.Module):
    """MultiRocket backbone: features of the raw series, optionally concatenated with features of its first difference.

    The feature budget `num_features` is split evenly between the branches when `use_diff` is True.
    `self.num_features` holds the actual width of the output of `forward` (4 pooled features per
    kernel/dilation/bias combination, summed over the branches); it can be lower than the requested value
    because of rounding inside each branch.
    """
    def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True):
        super(MultiRocketBackbonePlus, self).__init__()

        num_features_per_branch = num_features // (1 + use_diff)
        self.branch_x = MultiRocketFeaturesPlus(c_in, seq_len, num_features=num_features_per_branch, max_dilations_per_kernel=max_dilations_per_kernel,
                                                kernel_size=kernel_size, max_num_channels=max_num_channels, max_num_kernels=max_num_kernels)
        if use_diff:
            # The differenced series is one step shorter than the input
            self.branch_x_diff = MultiRocketFeaturesPlus(c_in, seq_len - 1, num_features=num_features_per_branch, max_dilations_per_kernel=max_dilations_per_kernel,
                                                         kernel_size=kernel_size, max_num_channels=max_num_channels, max_num_kernels=max_num_kernels)
            self.num_features = (self.branch_x.num_features + self.branch_x_diff.num_features) * 4 # 4 types of features
        else:
            self.num_features = self.branch_x.num_features * 4
        self.use_diff = use_diff

    def forward(self, x):
        "Return the features of `x`, followed by those of `torch.diff(x)` when `use_diff` is True"
        x_features = self.branch_x(x)
        if not self.use_diff:
            return x_features
        # The differenced series goes through its own branch, which was built for length `seq_len - 1`
        # and keeps its own dilations, channel combinations and biases.
        x_diff_features = self.branch_x_diff(torch.diff(x))
        return torch.cat([x_features, x_diff_features], dim=-1)

In [None]:
#| export
class MultiRocketPlus(nn.Sequential):
    """MultiRocket model: a `MultiRocketBackbonePlus` feature extractor ('backbone') followed by a 'head'.

    The head is, in order of precedence:
      * `custom_head`, used as is when it is an `nn.Module`, otherwise called as `custom_head(head_nf, c_out, 1)`;
      * `rocket_nd_head(...)` when `d` is given;
      * otherwise Flatten -> [BatchNorm1d if `use_bn`] -> [Dropout if `fc_dropout`] -> Linear, the linear layer
        being zero-initialised when `zero_init` is True.

    `head_nf` holds the number of features produced by the backbone.
    """

    def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,
                 use_bn=True, fc_dropout=0, custom_head=None, zero_init=True, use_diff=True):

        # Backbone
        backbone = MultiRocketBackbonePlus(
            c_in, seq_len,
            num_features=num_features,
            max_dilations_per_kernel=max_dilations_per_kernel,
            kernel_size=kernel_size,
            max_num_channels=max_num_channels,
            max_num_kernels=max_num_kernels,
            use_diff=use_diff,
        )
        # From here on use the width the backbone actually produces, not the requested one
        num_features = backbone.num_features
        self.head_nf = num_features

        # Head
        if custom_head is not None:
            head = custom_head if isinstance(custom_head, nn.Module) else custom_head(self.head_nf, c_out, 1)
        elif d is not None:
            head = rocket_nd_head(num_features, c_out, seq_len=None, d=d, use_bn=use_bn, fc_dropout=fc_dropout, zero_init=zero_init)
        else:
            classifier = nn.Linear(num_features, c_out)
            if zero_init:
                for tensor in (classifier.weight, classifier.bias):
                    nn.init.constant_(tensor.data, 0)
            norm = [nn.BatchNorm1d(num_features)] if use_bn else []
            drop = [nn.Dropout(fc_dropout)] if fc_dropout else []
            head = nn.Sequential(Flatten(), *norm, *drop, classifier)

        named_stages = OrderedDict()
        named_stages['backbone'] = backbone
        named_stages['head'] = head
        super().__init__(named_stages)

MultiRocket = MultiRocketPlus

In [None]:
from tsai.imports import default_device

In [None]:
# Classification head (d=None) with the extra first-difference branch: one row of logits per sample
xb = torch.randn(16, 5, 20).to(default_device())

model = MultiRocketPlus(5, 3, 20, d=None, use_diff=True).to(default_device())
output = model(xb)
assert output.shape == (16, 3)
output.shape

torch.Size([16, 3])

In [None]:
# Classification head (d=None) on the raw series only (use_diff=False)
xb = torch.randn(16, 5, 20).to(default_device())

model = MultiRocketPlus(5, 3, 20, d=None, use_diff=False).to(default_device())
output = model(xb)
assert output.shape == (16, 3)
output.shape

torch.Size([16, 3])

In [None]:
# n-dimensional head (d=20): the output gains an extra axis of size d before the class axis
xb = torch.randn(16, 5, 20).to(default_device())

model = MultiRocketPlus(5, 3, 20, d=20, use_diff=True).to(default_device())
output = model(xb)
assert output.shape == (16, 20, 3)
output.shape

torch.Size([16, 20, 3])

In [None]:
#|eval: false
#|hide
# Look up this notebook's name, then run the notebook-to-script conversion for it
from tsai.export import get_nb_name
nb_name = get_nb_name(locals())
from tsai.imports import create_scripts
create_scripts(nb_name)

<IPython.core.display.Javascript object>

/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 10:53:13
Correct notebook to script conversion! 😃
Sunday 11/02/24 10:53:16 CET
