In [None]:
# default_exp model.blocks

In [None]:
# export
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from profetorch.data.data import convert_date

# Blocks
> Blocks that compose the `Model`.

For instance, the median model is composed of a `Trend` and a `Seasonal` component (the latter being in turn composed of `FourierModel` components).

## Linear Models

In [None]:
# export
class Trend(nn.Module):
    """
    Broken (continuous, piecewise-linear) trend model, with breakpoints as defined by user.

    Parameters
    ----------
    breakpoints : int, sequence of dates, or None
        * ``None`` or ``0``: a single straight line (``nn.Linear(1, 1)``).
        * positive ``int`` k: k breakpoints evenly spaced strictly inside
          ``moment['t_range']``.
        * anything else: passed through ``convert_date`` and used as the
          breakpoint locations.
    moment : dict, optional
        Only read when ``breakpoints`` is a positive int; must provide
        ``moment['t_range']`` as a ``(start, end)`` pair.
    """
    def __init__(self, breakpoints=None, moment=None):
        super().__init__()
        self.init_layer = nn.Linear(1, 1)  # slope/intercept of the first segment

        if isinstance(breakpoints, (int, np.integer)):
            if breakpoints > 0:
                # k+1 equal intervals over t_range; drop the range start itself
                breakpoints = np.linspace(*moment['t_range'], breakpoints + 1, endpoint=False)[1:]
            else:
                breakpoints = None  # zero breakpoints -> plain linear trend
        elif breakpoints is not None:
            breakpoints = convert_date(breakpoints)

        if breakpoints is not None:
            # change in slope at each breakpoint; zeros start as a straight line
            self.deltas = nn.Parameter(torch.zeros(len(breakpoints)))

        self.bpoints = breakpoints

    def __copy2array(self):
        """
        Expand (init_layer, deltas) into per-segment ``[slope, intercept]`` rows.

        Each new segment's intercept is chosen so that it meets the previous
        segment at the breakpoint, keeping the trend continuous. The list of
        pairs is stored in ``self.params`` and the stacked
        ``(n_segments, 2)`` tensor in ``self.wb``.
        """
        w = self.init_layer.weight
        b = self.init_layer.bias
        self.params = [[w, b]]
        for d, x1 in zip(self.deltas, self.bpoints):
            y1 = w * x1 + b   # end point (x1, y1) of the current segment
            w = w + d         # slope of the next segment
            b = y1 - w * x1   # intercept that keeps the line continuous at x1
            self.params.append([w, b])

        # stacking keeps the result on the parameters' device and in the autograd graph
        self.wb = torch.stack([torch.cat([sw.reshape(1), sb.reshape(1)]) for sw, sb in self.params])

    def forward(self, t: torch.Tensor):
        """
        Evaluate the trend at times ``t``.

        ``t`` must be a column of shape ``(N, 1)`` (as required by
        ``nn.Linear(1, 1)``); the result has shape ``(N, 1)``.
        """
        if self.bpoints is None:
            return self.init_layer(t)

        self.__copy2array()
        bpoints = torch.as_tensor(np.asarray(self.bpoints, dtype=np.float64),
                                  dtype=t.dtype, device=t.device)
        # segment index of each t = number of breakpoints at or before it
        x_sec = (t >= bpoints).sum(1)
        seg = self.wb[x_sec]
        return t * seg[:, :1] + seg[:, 1:]

In [None]:
# export
class LinearX(nn.Module):
    """
    Bias-free linear model of the non-time (exogenous) inputs.

    ``dims`` is the number of input features. When ``dims <= 0`` no weights
    are created, and ``forward`` is only expected to receive ``x=None``.
    """
    def __init__(self, dims):
        super().__init__()
        # weights only exist when there are features to weight
        if dims > 0:
            self.linear = nn.Linear(in_features=dims, out_features=1, bias=False)

    def forward(self, x):
        # no exogenous inputs -> contributes nothing to the sum
        if x is None:
            return 0
        return self.linear(x)

## Periodic Functions

In [None]:
# export
_TWOPI = 2*np.pi

class FourierModel(nn.Module):
    """
    Block that outputs a learned weighted sum of cos/sin basis functions.

    Parameters
    ----------
    p : float
        Period in unscaled time units.
    scale : float
        Time scaling; the period as seen by ``forward`` is ``p / scale``.
    n : int
        Number of harmonics, giving ``2 * n`` basis functions (n cosines
        followed by n sines). With ``n <= 0`` the block is disabled and
        ``forward`` returns ``0``.
    """
    def __init__(self, p:float=365.25, scale:float=1, n:int=7):
        super().__init__()
        # configuration kept on the instance (read by `plot`)
        self.p, self.scale, self.n = p, scale, n
        # (harmonic number, effective period) pairs
        self.np = [(i+1, p/scale) for i in range(n)]
        if n > 0:
            self.linear = nn.Linear(n * 2, 1, bias=False)

    def forward(self, t:torch.Tensor):
        """
        ``t`` is expected as a column ``(N, 1)``; features are concatenated
        along dim 1 before the linear layer, giving an ``(N, 1)`` output.
        """
        if len(self.np) > 0:
            cos = [torch.cos(_TWOPI * n * t / p) for n,p in self.np]
            sin = [torch.sin(_TWOPI * n * t / p) for n,p in self.np]

            return self.linear(torch.cat(cos + sin, dim=1))

        else:
            return 0

    def plot(self):
        """Plot the learned curve over one full period of the (scaled) input time axis."""
        if self.n > 0:
            weight = self.linear.weight
            # build t on the model's device/dtype so plotting works after .to(...)
            t = torch.linspace(0, self.p / self.scale, steps=100,
                               device=weight.device, dtype=weight.dtype)
            with torch.no_grad():
                y = self.forward(t[:, None])
            fig, ax = plt.subplots(figsize=(12, 5))
            ax.plot(t.cpu().numpy(), y.cpu().numpy())
            ax.set(title=f'Fourier component (period {self.p})', xlabel='t', ylabel='contribution')
            plt.show()


class Seasonal(nn.Module):
    """
    Sum of three Fourier blocks: yearly, monthly and weekly seasonality.

    ``y_n``/``m_n``/``w_n`` set the number of harmonics of each block (0
    disables it) and ``y_p``/``m_p``/``w_p`` their periods in unscaled time
    units; ``scale`` is shared by all three.
    """
    def __init__(self, y_n=7, m_n=5, w_n=0,
                 y_p=365.25, m_p=30.5, w_p=7, scale=1):
        super().__init__()
        # registration order (yearly, monthly, weekly) kept: fixes init RNG and parameter order
        self.yearly = FourierModel(p=y_p, scale=scale, n=y_n)
        self.monthly = FourierModel(p=m_p, scale=scale, n=m_n)
        self.weekly = FourierModel(p=w_p, scale=scale, n=w_n)

    def forward(self, t):
        total = self.yearly(t)
        total = total + self.monthly(t)
        return total + self.weekly(t)

## Holiday Functions
Holidays can also be modelled with `LinearX` by adding indicator columns to the inputs, for instance `df['June'] = df.ds.dt.month == 6`.

In [None]:
# export
class Holiday(nn.Module):
    """
    Learned effect on a date that recurs every ``repeat_every`` time units.

    ``holiday`` and ``repeat_every`` are given in original units and are
    standardised with ``mean``/``scale`` so they can be compared with ``t``.
    ``forward`` returns ``w`` where ``t`` lies within ``tol`` (in the same
    standardised units as ``t``) of an occurrence, and 0 elsewhere. The
    default ``tol=0`` reproduces exact matching; a small positive ``tol``
    makes matching robust to floating-point error from the standardisation.
    """
    def __init__(self, holiday, repeat_every=365, mean=0, scale=1, tol=0):
        super().__init__()
        self.holiday = (holiday - mean) / scale
        self.repeat_every = repeat_every / scale
        self.tol = tol
        self.w = nn.Parameter(torch.zeros(1)+0.05)

    def forward(self, t):
        rem = torch.remainder(t - self.holiday, self.repeat_every)
        # distance to the nearest occurrence, whether t is just after or just before it
        dist = torch.min(rem, self.repeat_every - rem)
        return (dist <= self.tol).float() * self.w


class HolidayRange(nn.Module):
    """
    Learned effect on time ranges (e.g. multi-day holidays).

    ``forward`` returns ``w`` times the number of ranges containing each
    ``t`` (overlapping ranges count once each); 0 outside all ranges.
    """
    def __init__(self, holidays):
        """
        holidays: list of ``[lower, upper]`` pairs; both bounds are inclusive
            and are compared directly against the ``t`` passed to ``forward``.
            May be empty, in which case the block contributes zeros.
        """
        super().__init__()
        self.holidays = holidays
        self.w = nn.Parameter(torch.zeros(1)+0.05)

    def forward(self, t):
        # start from zeros so an empty `holidays` list still yields a tensor
        count = torch.zeros_like(t, dtype=torch.float)
        for low, up in self.holidays:
            count = count + ((low <= t) & (t <= up)).float()
        return count * self.w

## Miscellaneous

In [None]:
# export
class Squasher(nn.Module):
    """
    Softly squashes output to lie between `low` and `high`.

    Values past a bound are not clipped. Their overshoot is scaled by
    ``alpha``, so gradients keep flowing with slope ``alpha``. Bounds are
    given in original units and standardised with ``mean`` and ``sd``;
    ``None`` disables that side.
    """
    def __init__(self, low=None, high=None, mean=0, sd=1, alpha=0.01):
        super().__init__()
        if low is not None:
            low = (low - mean) / sd
        if high is not None:
            high = (high - mean) / sd
        self.L, self.H, self.alpha = low, high, alpha

    def forward(self, t):
        """
        Return the squashed values. When a bound is set, a new tensor is
        returned and the input is left untouched, which also keeps this safe
        for leaf tensors that require grad.
        """
        if self.L is not None:
            t = torch.where(t < self.L, self.alpha * (t - self.L) + self.L, t)
        if self.H is not None:
            t = torch.where(t > self.H, self.alpha * (t - self.H) + self.H, t)
        return t
    

class RandomWalk(nn.Module):
    """
    Piecewise-constant, strictly positive level that changes at ``breaks``.

    The level of section k is ``softplus(cumsum(w)[k])``, so each section's
    value is the previous one's pre-activation plus a learned step. Section
    k contains the times with exactly k breakpoints at or before them, so
    ``n`` must be at least ``len(breaks) + 1``.
    """
    def __init__(self, n, breaks):
        super().__init__()
        self.w = nn.Parameter(torch.randn(n,1))
        self.breaks = breaks

    def forward(self, x):
        """``x`` is expected as a column ``(N, 1)``; returns ``(N, 1)``."""
        w = F.softplus(self.w.cumsum(0))
        breaks = torch.as_tensor(self.breaks, dtype=x.dtype, device=x.device)
        # section index = number of breaks at or before x (same rule as Trend)
        x_sec = (x >= breaks).sum(1)
        return w[x_sec]

In [None]:
# export
class DefaultModel(nn.Module):
    """
    Sum of a (broken) linear trend, seasonality and an exogenous linear term,
    passed through a Squasher.

    Parameters
    ----------
    moments : dict
        Data statistics. ``moments['t'][1]`` is the time scale given to the
        seasonal block; ``moments['y']`` is unpacked into the Squasher's
        trailing positional arguments (``mean``, ``sd``, ...); when present,
        ``moments['x'][0].shape[1]`` is the number of exogenous features.
    breakpoints : forwarded to ``Trend`` together with ``moments``.
    y_n, m_n, w_n : number of yearly / monthly / weekly Fourier harmonics.
    l, h : optional lower / upper bounds for the Squasher.
    """
    def __init__(self, moments, breakpoints=None, y_n=7, m_n=5, w_n=0, l=None, h=None):
        super().__init__()
        n_features = moments['x'][0].shape[1] if 'x' in moments else 0

        # construction order (trend, seasonal, linear, squash) fixes init RNG and parameter order
        self.trend = Trend(breakpoints, moments)
        self.seasonal = Seasonal(y_n, m_n, w_n, scale=moments['t'][1])
        self.linear = LinearX(n_features)
        self.squash = Squasher(l, h, *moments['y'])

    def forward(self, t, x=None):
        total = self.seasonal(t) + self.trend(t)
        total = total + self.linear(x)
        return self.squash(total)
    
class DefaultQModel(nn.Module):
    """
    Same as DefaultModel but with multiple outputs corresponding to quantiles.

    One DefaultModel is built per quantile. Only the 0.5 (median) model gets
    seasonal terms. Every other output is the median plus (q > 0.5) or minus
    (q < 0.5) a softplus-positive offset, so before the final squash no other
    quantile crosses the median. Output column i corresponds to ``quantiles[i]``.
    """
    def __init__(self, moments, breakpoints=None, y_n=7, m_n=5, w_n=0, l=None, h=None, quantiles=[0.05, 0.5, 0.95]):
        super().__init__()
        assert 0.5 in quantiles, f'0.5 needs to be in quantiles. Provided {quantiles} as quantiles.'
        self.idx = quantiles.index(0.5)
        signs = [q-0.5 for q in quantiles]
        # direction of each non-median quantile relative to the median, shape (1, Q-1)
        self.signs = torch.Tensor([-1 if s<0 else 1 for i,s in enumerate(signs) if i != self.idx])[None,:]
        self.idxs = [i for i in range(len(quantiles)) if i != self.idx]
        median_args = {'y_n': y_n, 'm_n': m_n, 'w_n': w_n}
        other_args = {'y_n': 0, 'm_n': 0, 'w_n': 0}
        args = [median_args if q==0.5 else other_args for q in quantiles]
        self.models = nn.ModuleList([DefaultModel(moments, breakpoints, **arg) for arg in args])
        self.squash = Squasher(l, h, *moments['y'])

    def forward(self, t, x=None):
        prediction = torch.cat([m(t,x) for m in self.models], -1)
        median = prediction[:, [self.idx]]
        # `signs` is a plain attribute, not a registered buffer, so it does not follow
        # `.to(device)` / `.double()`; align it with the prediction here
        signs = self.signs.to(prediction)
        prediction[:, self.idxs] = median + F.softplus(prediction[:, self.idxs]) * signs

        return self.squash(prediction)

In [None]:
# hide
# export every `# export` cell of the project's notebooks into the library modules
from nbdev.export import notebook2script
notebook2script()

Converted 99_index.ipynb.
Converted blocks.ipynb.
Converted callbacks.ipynb.
Converted data.ipynb.
Converted losses.ipynb.
Converted model.ipynb.
