In [1]:
# Detect whether the notebook is running on Google Colab (read by the install cell below).
# `get_ipython` only exists inside IPython/Jupyter; fall back to False when the code is run
# as a plain Python script so the rest of the notebook stays importable.
try:
    IN_COLAB = 'google.colab' in str(get_ipython())
except NameError:
    IN_COLAB = False

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils_data

In [3]:
if IN_COLAB:
  !pip install -q pytorch-lightning

import pytorch_lightning as pl

[K     |████████████████████████████████| 563kB 9.4MB/s 
[K     |████████████████████████████████| 92kB 8.1MB/s 
[K     |████████████████████████████████| 829kB 17.6MB/s 
[K     |████████████████████████████████| 276kB 39.2MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone


In [4]:
# ALL GLOBAL VARIABLES

# preferred device: first CUDA GPU if available, otherwise CPU
GLOBAL_DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# seed for the random number generators; applied right away so that weight initialisation
# and Monte Carlo sampling are reproducible on "Restart & Run All"
GLOBAL_SEED = 42
torch.manual_seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

# value marking padded positions in event-type sequences (also the embedding padding index),
# so real event types are expected to be 1..n_event_types
PADDING_CONST = 0

In [59]:
def compute_integral_mc(cond_lam, mc_type_lam, time, src_padding_mask, type_encoding, alpha=-0.1, n_samples=100):
    """
    Monte Carlo estimate of the non-event (integral) term of the log-likelihood.

    For each consecutive pair of events (t_j, t_{j+1}) the integral of the conditional
    intensity over the gap is approximated as gap length times the average intensity
    evaluated at `n_samples` uniformly drawn points inside the gap.

    Input:
      cond_lam - only its device is used (for the random draws),
      mc_type_lam - per-interval linear term added inside the softplus; as built in
                    HawkesTransformer.loss_function it has shape (B, S-1, 1),
      time (B, S) - event timestamps,
      src_padding_mask (B, S) - bool mask, True at padded positions,
      type_encoding - not used here (kept for interface compatibility),
      alpha (float) - weight of the "current" term (t - t_j) / (t_j + 1),
      n_samples (int) - number of Monte Carlo samples per interval
    Output:
      (B, S-1) tensor with the estimated integral for every interval; intervals ending
      in a padded position contribute 0.
    """
    # gap lengths t_{j+1} - t_j, zeroed where the interval ends in padding
    not_padded_next = ~src_padding_mask[:, 1:]
    gaps = (time[:, 1:] - time[:, :-1]) * not_padded_next

    # u used in eq. (9) from the paper: uniform points inside each gap, divided by (t_j + 1)
    # (the +1 presumably guards against t_j == 0)
    draws = torch.rand([*gaps.size(), n_samples], device=cond_lam.device)
    u = gaps.unsqueeze(2) * draws / (time[:, :-1] + 1).unsqueeze(2)

    # average intensity over the sampled points
    sampled_lam = nn.Softplus(threshold=10)(alpha * u + mc_type_lam)
    mean_lam = sampled_lam.sum(dim=2) / n_samples

    return gaps * mean_lam

In [173]:
class HawkesTransformer(nn.Module):
    """
    Transformer Hawkes Process (THP) model.

    Each event in an input sequence is a feature vector whose feature 0 is the timestamp and
    feature 1 is the event type (1..n_event_types, PADDING_CONST marks padding). The sequence is
    encoded with a causally-masked Transformer encoder; the model outputs type-specific
    intensities plus predictions of the next event's time and type.
    """

    def __init__(self, n_event_types, device, d_model=512, n_heads=8, n_layers=6, dropout=0.1):
        """
        Input parameters:
          n_event_types (int) - number of event types in the data (types are encoded as
                                1..n_event_types; PADDING_CONST is reserved for padding),
          device (torch.device) - device on which internal constants and masks are created,
          d_model (int) - size of model's latent dimension,
          n_heads (int) - number of heads in the Multihead Attention module,
          n_layers (int) - number of Transformer encoder layers,
          dropout (float) - dropout rate used inside the Transformer encoder layers
        """
        super(HawkesTransformer, self).__init__()

        self.d_model = d_model
        self.device = device

        # initialize div term for temporal encoding
        self.init_temporal_encoding()

        # event type embedding (index PADDING_CONST is the padding row)
        self.event_embedding = nn.Embedding(n_event_types + 1, d_model, padding_idx=PADDING_CONST)

        # transformer encoder layers; `dropout` must be forwarded explicitly, otherwise
        # PyTorch's own default is used and the constructor argument has no effect
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, dropout=dropout)
        layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.transformer_layers = nn.TransformerEncoder(encoder_layer, n_layers, norm=layer_norm)

        # linear transformation of hidden states ("history" and "base" terms in eq.(6) of the THP paper)
        self.transform = nn.Linear(d_model, n_event_types)
        self.softplus = nn.Softplus(threshold=10)

        # output prediction layers
        self.time_predictor  = nn.Linear(d_model, 1)
        self.event_predictor = nn.Linear(d_model, n_event_types)

        # small constant added before taking logs of intensities
        self.eps = torch.tensor([1e-8], device=self.device)

    def generate_subsequent_mask(self, seq):
        """
        Causal mask for masked self-attention: position i may only attend to positions <= i.
        Input:
          seq (B, S) - batch of sequences (only its length S is used).
        Output:
          (S, S) bool tensor, True strictly above the diagonal (attention not allowed).
        """
        ls = seq.size(1)
        return torch.triu(torch.ones(ls, ls, device=self.device, dtype=torch.bool), diagonal=1)

    def generate_key_padding_mask(self, seq):
        """
        Masking the padded part of the sequence.
        Input:
          seq (B, S) - batch of event-type sequences.
        Output:
          (B, S) bool tensor, True where the element equals PADDING_CONST.
        """
        return seq.eq(PADDING_CONST)

    def init_temporal_encoding(self):
        """
        Initializing the internal temporal encoding tensors.

        Sets self.te_div_term (d_model,) with entries 10000 ** (-2 * (i // 2) / d_model).
        """
        encoding_constant = torch.tensor(10000.0)

        # computed via exp/log for better numerical stability
        self.te_div_term = torch.exp(2.0 * (torch.arange(0, self.d_model) // 2) * -torch.log(encoding_constant) / self.d_model).to(self.device)

    def temporal_encoding(self, t, non_padded_mask):
        """
        Sinusoidal encoding of the input timestamps.
        Input:
          t (B, S) - batch of timestamp sequences,
          non_padded_mask (B, S) - bool mask, True for real events and False for padding
        Output:
          (B, S, d_model) tensor: sin on even channels, cos on odd channels,
          zeroed at padded positions.
        """
        temporal_enc = t.unsqueeze(-1) * self.te_div_term

        temporal_enc[:, :, 0::2] = torch.sin(temporal_enc[:, :, 0::2])
        temporal_enc[:, :, 1::2] = torch.cos(temporal_enc[:, :, 1::2])

        return temporal_enc * non_padded_mask.unsqueeze(-1)

    def forward(self, input_seq):
        """
        Input:
          input_seq (B, S, F) - input sequence of size (batch size, sequence length, features);
                                feature 0 is the timestamp, feature 1 the event type
        Output:
          x (B, S, d_model) - raw model output,
          lam (B, S, n_event_types) - type-specific intensity functions,
          (time_pred (B, S), event_pred (B, S, n_event_types)) - next-event time prediction and
          unnormalized next-event type scores, both zeroed at padded positions
        """
        # generate masks
        src_key_padding_mask = self.generate_key_padding_mask(input_seq[:, :, 1])
        src_non_padded_mask = ~src_key_padding_mask
        src_mask = self.generate_subsequent_mask(input_seq[:, :, 1])

        # perform encodings
        temp_enc  = self.temporal_encoding(input_seq[:, :, 0], src_non_padded_mask)
        event_enc = self.event_embedding(input_seq[:, :, 1])

        # make pass through transformer encoder layers (they expect (S, B, d_model))
        x = event_enc + temp_enc
        x = self.transformer_layers(x.permute(1, 0, 2), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        x = x.permute(1, 0, 2)

        # calculate type-specific intensity function
        lam = self.softplus(self.transform(x))

        # make predictions
        time_pred  = self.time_predictor(x).squeeze(2) * src_non_padded_mask
        event_pred = self.event_predictor(x) * src_non_padded_mask.unsqueeze(-1)

        return x, lam, (time_pred, event_pred)

    def loss_function(self, x, lam, time_pred, event_pred, tgt_seq, alpha=-0.1):
        """
        Negative log-likelihood plus prediction losses, averaged over the batch.

        Input:
          x (B, S, d_model) - raw network output,
          lam (B, S, K) - type specific intensity functions (K = n_event_types),
          time_pred (B, S) - predicted time until the next event,
          event_pred (B, S, K) - unnormalized scores for the next event type,
          tgt_seq (B, S, F) - original sequence of size (batch size, sequence length, features);
                              feature 0 is the timestamp, feature 1 the event type,
          alpha (float) - weight coefficient for "current" influence in eq. (6) from the paper
        Output:
          scalar tensor: batch mean of (-log_likelihood + event_error + 0.01 * time_error)
        """
        n_types = lam.size(-1)
        event_types = tgt_seq[:, :, 1]

        # one-hot encoding of the event type at each position. Its width must equal the number
        # of event types (lam.size(-1)), not the number of input features. Positions whose type is
        # outside 1..K (e.g. padding when PADDING_CONST == 0) get an all-zero row.
        type_ids = torch.arange(1, n_types + 1, device=event_types.device)
        type_encoding = (event_types.unsqueeze(-1) == type_ids).to(device=self.device, dtype=lam.dtype)

        src_padding_mask = self.generate_key_padding_mask(event_types)

        # compute conditional intensity function of the observed type at each event
        cond_lam = (lam * type_encoding).sum(dim=2)

        # compute event log-likelihood; padded positions are filled with 1 so their log is 0
        event_part = (cond_lam + self.eps).masked_fill_(src_padding_mask, 1.0).log()
        event_part = event_part.sum(dim=1)

        # compute non-event log-likelihood:
        # lambda for (t_{j+1}) for M-C integration, projected onto the observed event type
        mc_type_lam = self.transform(x[:, 1:, :])
        mc_type_lam = (mc_type_lam * type_encoding[:, 1:, :]).sum(dim=2, keepdim=True)

        non_event_part = compute_integral_mc(cond_lam, mc_type_lam, tgt_seq[:, :, 0], src_padding_mask, type_encoding, alpha).sum(dim=1)

        # compute total log-likelihood
        log_likelihood = event_part - non_event_part

        # compute timestamp forecasting error: the prediction at step j targets t_{j+1} - t_j.
        # Transitions into a padded position are masked out; otherwise the last real event would be
        # penalized against (padded timestamp - t_last). The integer target is cast to the
        # prediction dtype before the loss.
        scale = 0.01 # for numerical stability
        non_padded_next = (~src_padding_mask[:, 1:]).to(time_pred.dtype)
        time_ground_truth = (tgt_seq[:, 1:, 0] - tgt_seq[:, :-1, 0]).to(time_pred.dtype) * non_padded_next
        time_pred = time_pred[:, :-1] * non_padded_next

        time_error = nn.MSELoss(reduction='none')(time_pred, time_ground_truth).sum(dim=1)

        # compute event prediction error through cross entropy loss;
        # types are shifted to 0..K-1, so padding (type 0) becomes -1 and is ignored
        event_ground_truth = event_types[:, 1:] - 1
        event_pred = event_pred[:, :-1, :]

        event_error = nn.CrossEntropyLoss(reduction='none', ignore_index=-1)(event_pred.transpose(1, 2), event_ground_truth).sum(dim=1)

        return (-log_likelihood + event_error + time_error * scale).mean()

In [134]:
# Toy batch with two sequences (X0 and X1) of four events each;
# every event is a [timestamp, event type] pair.
X_batch = torch.tensor(
    [
        [[1, 1], [2, 2], [4, 2], [5, 1]],    # X0
        [[7, 1], [8, 1], [10, 2], [11, 1]],  # X1
    ],
    dtype=torch.long,
)

In [174]:
# Smoke test: a model with two event types on the CPU, run on the toy batch.
transformer = HawkesTransformer(n_event_types=2, device=torch.device('cpu'))
x, lam, preds = transformer(X_batch)