In [None]:
# Detect Colab by checking whether the google.colab package is loaded in this kernel.
# Unlike str(get_ipython()), this does not raise NameError when run outside IPython.
import sys
IN_COLAB = 'google.colab' in sys.modules

In [None]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils_data

In [None]:
if IN_COLAB:
  !pip install -q pytorch-lightning

import pytorch_lightning as pl

In [None]:
# ALL GLOBAL VARIABLES

GLOBAL_DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
GLOBAL_SEED = 42
# event-type index reserved for padding (also used as the embedding's padding_idx)
PADDING_CONST = 0

# Apply the seed to every RNG the notebook uses, so that Restart & Run All reproduces
# the same weight initialisation and random batches (torch.manual_seed covers all devices).
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

In [None]:
class HawkesTransformer(nn.Module):
    """
    Causal Transformer encoder for marked event sequences.

    Each step of an input sequence is a (time, event type) pair. The event type is
    embedded and the time gets a sinusoidal temporal encoding. The sum of the two is
    passed through a Transformer encoder that can only attend to current/past,
    non-padded steps.
    """

    def __init__(self, n_event_types, device, d_model=512, n_heads=8, n_layers=6, dropout=0.1):
        """
        Input parameters:
          n_event_types (int) - number of event types in the data; the embedding table has
                                n_event_types + 1 rows so that index PADDING_CONST is reserved
                                for padding,
          device (torch.device) - device on which the temporal-encoding buffer is created,
          d_model (int) - size of model's latent dimension,
          n_heads (int) - attention heads per encoder layer (must divide d_model),
          n_layers (int) - number of stacked encoder layers,
          dropout (float) - dropout probability inside each encoder layer.
        """
        super(HawkesTransformer, self).__init__()

        self.d_model = d_model
        self.device = device

        # frequencies for the sinusoidal temporal encoding (registered as a buffer)
        self.init_temporal_encoding()

        # event type embedding; the padding index maps to a fixed zero vector
        self.event_embedding = nn.Embedding(n_event_types + 1, d_model, padding_idx=PADDING_CONST)

        # Inputs are (batch, seq, feature); without batch_first=True the encoder would
        # treat the batch axis as the sequence axis and mix different sequences.
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, dropout=dropout, batch_first=True)
        layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.transformer_layers = nn.TransformerEncoder(encoder_layer, n_layers, norm=layer_norm)

        # output prediction heads: one scalar per step for time, one score per event type
        self.time_predictor  = nn.Linear(d_model, 1)
        self.event_predictor = nn.Linear(d_model, n_event_types)

    def generate_subsequent_mask(self, seq):
        """
        Causal (look-ahead) mask for masked self-attention.

        Returns a (seq_len, seq_len) bool tensor in which True marks a *blocked* pair.
        This is the convention of nn.TransformerEncoder's boolean `mask`: query i may
        only attend to keys j <= i.
        """
        seq_len = seq.size(1)
        return torch.triu(torch.ones(seq_len, seq_len, device=seq.device, dtype=torch.bool), diagonal=1)

    def generate_key_padding_mask(self, seq):
        """
        Boolean (batch, seq_len) mask that is True at padded steps (value == PADDING_CONST).
        """
        return seq.eq(PADDING_CONST)

    def init_temporal_encoding(self):
        """
        Precompute the per-dimension frequencies 1 / 10000^(2 * (i // 2) / d_model) for the
        sinusoidal temporal encoding.

        They are stored as a non-persistent buffer, so they follow `model.to(device)` but
        are not added to the state_dict.
        """
        encoding_constant = torch.tensor(10000.0)

        # exp(-k * log(c)) rather than c ** -k, for better numerical stability
        div_term = torch.exp(2.0 * (torch.arange(0, self.d_model) // 2) * -torch.log(encoding_constant) / self.d_model)
        self.register_buffer('te_div_term', div_term.to(self.device), persistent=False)

    def temporal_encoding(self, t, non_padded_mask):
        """
        Sinusoidal encoding of event times.

        Input:
          t (B, S) - event times,
          non_padded_mask (B, S) bool - True at real (non-padded) steps.
        Output:
          (B, S, d_model) tensor: sin on even dimensions, cos on odd dimensions,
          zero at padded steps.
        """
        angles = t.unsqueeze(-1) * self.te_div_term
        even_dims = torch.arange(self.d_model, device=angles.device) % 2 == 0

        # torch.where instead of in-place slice assignment keeps autograd valid if t requires grad
        temporal_enc = torch.where(even_dims, torch.sin(angles), torch.cos(angles))

        return temporal_enc * non_padded_mask.unsqueeze(-1)

    def forward(self, input_seq, return_predictions=False):
        """
        Input:
          input_seq (B, S, F) - input sequence of size (batch size, sequence length, features);
                                column 0 holds event times and column 1 holds event types
                                (PADDING_CONST at padded steps). It may be integer or float,
          return_predictions (bool) - if True, also return the outputs of the prediction heads.
        Output:
          x (B, S, d_model) - encoder hidden states. If return_predictions is True, the tuple
          (x, time_pred (B, S), event_pred (B, S, n_event_types)) is returned instead, with
          both predictions zeroed at padded steps.
        """
        event_times = input_seq[:, :, 0]
        # The embedding needs integer indices; float times would otherwise make this column float.
        event_types = input_seq[:, :, 1].long()

        # generate masks
        src_key_padding_mask = self.generate_key_padding_mask(event_types)
        src_non_padded_mask = ~src_key_padding_mask
        src_mask = self.generate_subsequent_mask(event_types)

        # perform encodings
        temp_enc  = self.temporal_encoding(event_times, src_non_padded_mask)
        event_enc = self.event_embedding(event_types)

        # The causal mask blocks attention to future events; the key padding mask hides padded steps.
        x = self.transformer_layers(event_enc + temp_enc, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        if not return_predictions:
            return x

        time_pred  = self.time_predictor(x).squeeze(-1) * src_non_padded_mask
        event_pred = self.event_predictor(x) * src_non_padded_mask.unsqueeze(-1)

        return x, time_pred, event_pred

In [None]:
# Smoke test on a random integer batch: column 0 = event "times", column 1 = event
# types in 1..N_TEST_EVENT_TYPES. The last steps of half the batch are set to
# PADDING_CONST, so the padding path is exercised too.
N_TEST_EVENT_TYPES = 2
TEST_BATCH_SIZE, TEST_SEQ_LEN, TEST_N_PADDED = 16, 20, 5

test_batch = torch.randint(low=1, high=N_TEST_EVENT_TYPES + 1, size=(TEST_BATCH_SIZE, TEST_SEQ_LEN, 2))
test_batch[TEST_BATCH_SIZE // 2:, -TEST_N_PADDED:, :] = PADDING_CONST

transformer = HawkesTransformer(N_TEST_EVENT_TYPES, torch.device('cpu'))
test_output = transformer(test_batch)
assert test_output.shape == (TEST_BATCH_SIZE, TEST_SEQ_LEN, transformer.d_model)
test_output.size()