In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Normal

In [2]:
# Display the installed PyTorch version so the environment of this run is recorded.
torch.__version__

'1.0.1'

In [3]:
# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())
# Use the GPU when one is available, otherwise fall back to the CPU.
# Later cells move the model onto this device with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

True


# Define RMTPP Module

In [4]:
class RMTPP(nn.Module):
    """Recurrent marked temporal point process over (marker, time) sequences.

    Markers and time features are embedded, concatenated and passed through a
    ReLU RNN. The hidden state that precedes each event parameterises
      * an element-wise Gaussian over that event's marker, and
      * the log-density of that event's interval (see ``time_log_likelihood``).
    ``forward`` returns the summed negative log-likelihood of both parts.

    All sequence tensors are time-major: T x BS x feature.

        Input:
        marker_type: accepted for interface compatibility; no code in this
                     class reads it (only real-valued markers are handled).
        marker_dim: size of the last dimension of the marker tensors.
    """

    def __init__(self, marker_type='real', marker_dim=20):
        super().__init__()

        self.marker_dim = marker_dim

        # Dimensions for embedding inputs
        self.marker_embed_dim = 64
        self.time_embed_dim = 64
        # Dimension of the recurrent state that encodes the history
        self.hidden_layer_dim = 128
        # Width of the hidden layer shared by the marker mean and variance heads
        self.marker_shared_dim = 64

        # Sub-networks are created in a fixed order so that parameter
        # registration (and seeded initialisation) stays stable.
        self.create_embedding_nets()
        self.create_rnn()
        self.create_marker_generation_net()
        self.create_time_likelihood_net()

    ############ UTILITY METHODS #############

    def _one_hot_marker(self, marker_seq):
        """
            Placeholder for one-hot encoding of markers; not implemented yet
            and currently returns None.
            Input:
            marker_seq: Tensor of shape TxBSx1
        """
        pass

    def _reparameterize(self, mu, var):
        """
            Draw mu + eps * sqrt(var) with eps sampled from a standard normal
            of the same shape as mu.
        """
        epsilon = torch.randn_like(mu)
        return mu + epsilon * var.sqrt()

    ############ NETWORKS #############

    def create_rnn(self):
        """
            Build the recurrent layer that consumes the concatenated
            [marker embedding, time embedding] at every step.
            Uses:
            marker_embed_dim: dimension of embedded markers
            time_embed_dim: dimension of embedded times,intervals
            hidden_layer_dim: dimension of hidden state of recurrent layer
        """
        self.rnn = nn.RNN(
            input_size=self.marker_embed_dim + self.time_embed_dim,
            hidden_size=self.hidden_layer_dim,
            nonlinearity='relu'
        )

    def create_embedding_nets(self):
        """
            Build the input embeddings: a single linear map for markers and a
            two-layer MLP for the 2-dimensional time features.
        """
        # marker_dim is passed. timeseries_dim is 2
        self.marker_embedding_net = nn.Linear(self.marker_dim, self.marker_embed_dim)

        self.time_embedding_net = nn.Sequential(
            nn.Linear(2, self.time_embed_dim),
            nn.ReLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )

    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer: a shared ReLU layer followed by separate
            mean and variance heads.
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        # Softplus keeps the predicted variance positive.
        self.generated_marker_var = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )

    def create_time_likelihood_net(self):
        """
            Create the parameters of the time log-likelihood: a bias-free
            linear read-out of the hidden state, a scalar base intensity
            (initialised to 0) and a scalar interval weight (initialised to 0.01).
        """
        self.h_influence = nn.Linear(self.hidden_layer_dim, 1, bias=False)
        # Shaped 1x1x1 so both scalars broadcast against T x BS x 1 tensors.
        self.base_intensity = nn.Parameter(torch.zeros(1, 1, 1))
        self.wt = nn.Parameter(torch.tensor([[[.01]]]))

    ############ METHODS #############

    def _embed_data(self, marker_seq, time_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx marker_dim
            time_seq: Tensor of shape TxBSx 2
            Output:
            marker_seq_emb: Tensor of shape T x BS x marker_embed_dim
            time_seq_emb: Tensor of shape T x BS x time_embed_dim
        """
        marker_seq_emb = self.marker_embedding_net(marker_seq)
        time_seq_emb = self.time_embedding_net(time_seq)
        return marker_seq_emb, time_seq_emb

    def marker_log_likelihood(self, h_seq, marker_seq):
        """
        Use the h_seq to generate the distribution for the markers,
        and use that distribution to compute the log likelihood of the marker_seq
            Input:
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    marker_seq : Tensor of shape T x BS x marker_dim
            Output:
                    log_likelihood_marker_seq : T x BS x marker_dim
        """
        marker_gen_shared = self.marker_gen_hidden(h_seq)
        mu = self.generated_marker_mu(marker_gen_shared)
        var = self.generated_marker_var(marker_gen_shared)

        # Normal takes a standard deviation, hence the square root of the variance.
        marker_gen_dist = Normal(mu, var.sqrt())
        return marker_gen_dist.log_prob(marker_seq)

    def time_log_likelihood(self, h_seq, time_seq):
        """
        Use the h_seq to compute the log likelihood of the time_seq
        using the formula in the paper. With a = h_influence(h) + base_intensity
        and d the interval, the value computed is
            (a + wt*d) + (exp(a) - exp(a + wt*d)) / wt
            Input:
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim (if real)
                    time_seq : Tensor of shape T x BS x 2 . Last dimension is [times, intervals]
            Output:
                    log_likelihood_time_seq : T x BS x 1
        """
        past_influence = self.h_influence(h_seq)
        # Slice 1:2 (rather than index 1) keeps the trailing dimension of size 1.
        current_influence = self.wt * time_seq[:, :, 1:2]
        base_intensity = self.base_intensity

        term1 = past_influence + current_influence + base_intensity
        term2 = past_influence + base_intensity

        # After factorizing the formula in the paper
        log_likelihood_time_seq = term1 + (term2.exp() - term1.exp()) / self.wt

        return log_likelihood_time_seq

    #############################################

    def forward(self, marker_seq, time_seq):
        """
        Compute the negative log-likelihood of a batch of event sequences.
            Input:
                    marker_seq : Tensor of shape T x BS x marker_dim
                    time_seq : Tensor of shape T x BS x 2 . Last dimension is [times, intervals]
            Output:
                    NLL : scalar tensor used for optimization
                    [marker_nll, time_nll] : python floats used for logging
        """
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        _, batch_size, _ = marker_seq_emb.shape

        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        # Allocate the initial state from the embedded input so that it shares
        # the input's device and dtype instead of depending on a notebook-level
        # `device` variable.
        h_0 = marker_seq_emb.new_zeros(1, batch_size, self.hidden_layer_dim)
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)
        hidden_combined = torch.cat([h_0, hidden_seq], dim=0)

        # Drop the last state: h_0 is used to generate marker_1 / timestamp_1,
        # h_1 to generate marker_2 / timestamp_2, and so on...
        history = hidden_combined[:-1]
        marker_ll = self.marker_log_likelihood(history, marker_seq)
        time_ll = self.time_log_likelihood(history, time_seq)

        marker_ll_sum = marker_ll.sum()
        time_ll_sum = time_ll.sum()
        NLL = -(marker_ll_sum + time_ll_sum)

        # NLL is used for optimization,
        # the individual LL values are used for logging
        return NLL, [-marker_ll_sum.item(), -time_ll_sum.item()]
        

In [5]:
# _data, _ = generate_mpp(num_sample=200)
# model = RMTPP().to(device)
# model(_data['x'], _data['t'])

# Training

In [6]:
from utils_ import generate_mpp

In [7]:
from trainer import train
from rmtpp import rmtpp

In [8]:
def trainer(model, data=None, val_data=None, lr=1e-3, epoch=500, batch_size=100):
    """Optimise ``model`` with Adam by calling ``train`` once per epoch.

    Args:
        model: module whose ``parameters()`` are optimised.
        data: training data forwarded to ``train``. When ``None``, data is
            generated with ``generate_mpp()``.
        val_data: validation data forwarded to ``train``.
        lr: Adam learning rate.
        epoch: number of times ``train`` is called.
        batch_size: forwarded to ``train`` unchanged.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Identity test: `== None` is not a reliable None check for tensors or
    # arrays, which overload `==`.
    if data is None:
        # NOTE(review): other cells discard the second value returned by
        # generate_mpp (`data, _ = generate_mpp(...)`), yet here it replaces
        # `val_data` -- confirm that it really is a validation set.
        data, val_data = generate_mpp()

    optimizer = Adam(model.parameters(), lr=lr)

    # presumably each `train` call runs one full epoch -- verify against trainer.train
    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model

In [9]:
def main(model, data, val_data):
    """Build a model, print the dataset shapes and launch training.

    Args:
        model: zero-argument callable (e.g. a model class) returning a module;
            the instance is moved to the notebook-level ``device``.
        data: training data; its 't' and 'x' entries must expose ``.shape``.
        val_data: validation data with the same 't' / 'x' entries.
    """
    net = model().to(device)

    # Data is supplied by the caller rather than generated here.
    shape_reports = (
        ("Times: Data Shape: {}, Val Data Shape: {}", 't'),
        ("Markers: Data Shape: {}, Val Data Shape: {}", 'x'),
    )
    for template, key in shape_reports:
        print(template.format(data[key].shape, val_data[key].shape))

    trainer(net, data=data, val_data=val_data)

In [10]:
# Generate the training and validation sets with generate_mpp(type='hawkes').
# Only the first element of each returned pair is kept; the second is discarded.
# In the shapes printed by main() below, num_sample appears as the second dimension of 't' and 'x'.
data, _ = generate_mpp(type='hawkes', num_sample=1000)
val_data, _ = generate_mpp(type='hawkes', num_sample = 200)
# Disabled: the same run using the `rmtpp` model imported from the rmtpp module.
# main(model=rmtpp, data=data, val_data=val_data)

In [11]:
# Train the RMTPP class defined in this notebook on the generated data.
main(model=RMTPP, data=data, val_data=val_data)

Times: Data Shape: torch.Size([100, 1000, 2]), Val Data Shape: torch.Size([100, 200, 2])
Markers: Data Shape: torch.Size([100, 1000, 20]), Val Data Shape: torch.Size([100, 200, 20])
1000
Epoch: 0, NLL Loss: 3936.736125, Val Loss: 6414.34619140625, Time took: 0.1022803783416748
Train loss Meta Info:  [3856.21090625   80.52519922]
Val Loss Meta Info:  [6251.306875, 163.039228515625]

1000
Epoch: 1, NLL Loss: 3236.34525, Val Loss: 6128.5625, Time took: 0.09604430198669434
Train loss Meta Info:  [3150.7855       85.55972949]
Val Loss Meta Info:  [5966.84375, 161.71845703125]

1000
Epoch: 2, NLL Loss: 3078.3601875, Val Loss: 6024.30615234375, Time took: 0.08737063407897949
Train loss Meta Info:  [2993.2793125    85.08089258]
Val Loss Meta Info:  [5876.77, 147.53619140625]

1000
Epoch: 3, NLL Loss: 3023.237625, Val Loss: 5997.83935546875, Time took: 0.08651208877563477
Train loss Meta Info:  [2945.04128125   78.19630371]
Val Loss Meta Info:  [5846.67375, 151.16583984375]

1000
Epoch: 4, NLL 

Epoch: 40, NLL Loss: 2907.49253125, Val Loss: 5788.47412109375, Time took: 0.08697772026062012
Train loss Meta Info:  [2842.5341875   64.958375 ]
Val Loss Meta Info:  [5680.87875, 107.595791015625]

1000
Epoch: 41, NLL Loss: 2899.66696875, Val Loss: 5794.9443359375, Time took: 0.08910226821899414
Train loss Meta Info:  [2841.6070625    58.05992041]
Val Loss Meta Info:  [5680.930625, 114.0137109375]

1000
Epoch: 42, NLL Loss: 2903.47428125, Val Loss: 5798.025390625, Time took: 0.08643412590026855
Train loss Meta Info:  [2841.853125     61.62116846]
Val Loss Meta Info:  [5680.73625, 117.289072265625]

1000
Epoch: 43, NLL Loss: 2905.38528125, Val Loss: 5789.44287109375, Time took: 0.08694028854370117
Train loss Meta Info:  [2841.86875     63.5165127]
Val Loss Meta Info:  [5679.82, 109.623359375]

1000
Epoch: 44, NLL Loss: 2900.79259375, Val Loss: 5785.841796875, Time took: 0.08697104454040527
Train loss Meta Info:  [2841.3361875    59.45643994]
Val Loss Meta Info:  [5680.92, 104.922060546

Epoch: 82, NLL Loss: 2888.72246875, Val Loss: 5769.4443359375, Time took: 0.08829736709594727
Train loss Meta Info:  [2837.66778125   51.054646  ]
Val Loss Meta Info:  [5673.855625, 95.5890234375]

1000
Epoch: 83, NLL Loss: 2889.142875, Val Loss: 5769.94482421875, Time took: 0.08699226379394531
Train loss Meta Info:  [2837.68415625   51.45872559]
Val Loss Meta Info:  [5673.78875, 96.1565234375]

1000
Epoch: 84, NLL Loss: 2889.707875, Val Loss: 5767.26220703125, Time took: 0.09230732917785645
Train loss Meta Info:  [2837.62815625   52.07972705]
Val Loss Meta Info:  [5673.69875, 93.5634375]

1000
Epoch: 85, NLL Loss: 2888.21390625, Val Loss: 5773.72509765625, Time took: 0.0907442569732666
Train loss Meta Info:  [2837.54571875   50.66819727]
Val Loss Meta Info:  [5674.065625, 99.6591015625]

1000
Epoch: 86, NLL Loss: 2891.08921875, Val Loss: 5776.4560546875, Time took: 0.0878152847290039
Train loss Meta Info:  [2837.70715625   53.38208789]
Val Loss Meta Info:  [5674.364375, 102.0920605468

Epoch: 124, NLL Loss: 2885.82278125, Val Loss: 5772.26025390625, Time took: 0.08945155143737793
Train loss Meta Info:  [2836.68209375   49.14066406]
Val Loss Meta Info:  [5673.50125, 98.7591015625]

1000
Epoch: 125, NLL Loss: 2889.86228125, Val Loss: 5765.04052734375, Time took: 0.0892634391784668
Train loss Meta Info:  [2836.98571875   52.87661523]
Val Loss Meta Info:  [5672.86875, 92.1716015625]

1000
Epoch: 126, NLL Loss: 2886.72209375, Val Loss: 5769.63671875, Time took: 0.08978152275085449
Train loss Meta Info:  [2836.67953125   50.04255273]
Val Loss Meta Info:  [5673.106875, 96.53005859375]

1000
Epoch: 127, NLL Loss: 2889.15440625, Val Loss: 5763.119140625, Time took: 0.09087252616882324
Train loss Meta Info:  [2836.826        52.32843262]
Val Loss Meta Info:  [5672.8275, 90.291962890625]

1000
Epoch: 128, NLL Loss: 2885.62203125, Val Loss: 5771.98681640625, Time took: 0.09038186073303223
Train loss Meta Info:  [2836.63128125   48.99076367]
Val Loss Meta Info:  [5673.19625, 98.7

Epoch: 165, NLL Loss: 2886.96425, Val Loss: 5760.61865234375, Time took: 0.08953738212585449
Train loss Meta Info:  [2836.20315625   50.76112012]
Val Loss Meta Info:  [5672.7425, 87.8765625]

1000
Epoch: 166, NLL Loss: 2883.7198125, Val Loss: 5763.7890625, Time took: 0.0901181697845459
Train loss Meta Info:  [2836.0875625    47.63227002]
Val Loss Meta Info:  [5672.72875, 91.06068359375]

1000
Epoch: 167, NLL Loss: 2885.509125, Val Loss: 5760.4921875, Time took: 0.08948516845703125
Train loss Meta Info:  [2836.08784375   49.42132422]
Val Loss Meta Info:  [5672.695625, 87.79677734375]

1000
Epoch: 168, NLL Loss: 2883.61246875, Val Loss: 5765.517578125, Time took: 0.0898289680480957
Train loss Meta Info:  [2836.045625     47.56683203]
Val Loss Meta Info:  [5672.73, 92.78720703125]

1000
Epoch: 169, NLL Loss: 2885.67271875, Val Loss: 5760.572265625, Time took: 0.08936762809753418
Train loss Meta Info:  [2836.01603125   49.65668359]
Val Loss Meta Info:  [5672.67875, 87.89380859375]

1000
Ep

Epoch: 207, NLL Loss: 2886.32509375, Val Loss: 5766.77001953125, Time took: 0.0851600170135498
Train loss Meta Info:  [2837.1385      49.1866001]
Val Loss Meta Info:  [5673.88625, 92.88380859375]

1000
Epoch: 208, NLL Loss: 2886.5539375, Val Loss: 5763.6328125, Time took: 0.08411026000976562
Train loss Meta Info:  [2836.74378125   49.81013086]
Val Loss Meta Info:  [5673.39125, 90.241572265625]

1000
Epoch: 209, NLL Loss: 2885.05721875, Val Loss: 5764.345703125, Time took: 0.08420825004577637
Train loss Meta Info:  [2836.3784375   48.6787876]
Val Loss Meta Info:  [5673.29125, 91.0542578125]

1000
Epoch: 210, NLL Loss: 2885.57990625, Val Loss: 5763.74755859375, Time took: 0.08472466468811035
Train loss Meta Info:  [2836.25453125   49.32538232]
Val Loss Meta Info:  [5673.746875, 90.0009375]

1000
Epoch: 211, NLL Loss: 2885.17459375, Val Loss: 5764.2646484375, Time took: 0.08462929725646973
Train loss Meta Info:  [2836.479875     48.69474365]
Val Loss Meta Info:  [5674.215, 90.04998046875]

Epoch: 249, NLL Loss: 2889.07609375, Val Loss: 5809.40625, Time took: 0.08435535430908203
Train loss Meta Info:  [2835.255        53.82112305]
Val Loss Meta Info:  [5673.5075, 135.89857421875]

1000
Epoch: 250, NLL Loss: 2907.926625, Val Loss: 5830.09912109375, Time took: 0.08388233184814453
Train loss Meta Info:  [2835.73184375   72.19482275]
Val Loss Meta Info:  [5674.245, 155.854140625]

1000
Epoch: 251, NLL Loss: 2918.48178125, Val Loss: 5817.93603515625, Time took: 0.08357024192810059
Train loss Meta Info:  [2836.25590625   82.22585449]
Val Loss Meta Info:  [5675.4025, 142.53388671875]

1000
Epoch: 252, NLL Loss: 2911.95575, Val Loss: 5796.1474609375, Time took: 0.08407998085021973
Train loss Meta Info:  [2836.847125     75.10861328]
Val Loss Meta Info:  [5676.928125, 119.21912109375]

1000
Epoch: 253, NLL Loss: 2900.48834375, Val Loss: 5789.27490234375, Time took: 0.08412718772888184
Train loss Meta Info:  [2837.56140625   62.92694189]
Val Loss Meta Info:  [5676.68125, 112.593945

Epoch: 291, NLL Loss: 2883.1759375, Val Loss: 5762.8115234375, Time took: 0.08505773544311523
Train loss Meta Info:  [2835.6406875    47.53524854]
Val Loss Meta Info:  [5674.5575, 88.254423828125]

1000
Epoch: 292, NLL Loss: 2883.0149375, Val Loss: 5762.48486328125, Time took: 0.08426642417907715
Train loss Meta Info:  [2835.612        47.40293701]
Val Loss Meta Info:  [5674.50875, 87.9764453125]

1000
Epoch: 293, NLL Loss: 2882.8045625, Val Loss: 5762.19482421875, Time took: 0.08382797241210938
Train loss Meta Info:  [2835.53259375   47.27194189]
Val Loss Meta Info:  [5674.418125, 87.7765625]

1000
Epoch: 294, NLL Loss: 2882.59321875, Val Loss: 5761.84912109375, Time took: 0.08666348457336426
Train loss Meta Info:  [2835.45440625   47.13879785]
Val Loss Meta Info:  [5674.32125, 87.527890625]

1000
Epoch: 295, NLL Loss: 2882.4239375, Val Loss: 5761.60986328125, Time took: 0.08593440055847168
Train loss Meta Info:  [2835.3851875    47.03875879]
Val Loss Meta Info:  [5674.31125, 87.29890

Epoch: 334, NLL Loss: 2884.7388125, Val Loss: 5766.36279296875, Time took: 0.08858656883239746
Train loss Meta Info:  [2835.57625      49.16259717]
Val Loss Meta Info:  [5675.0575, 91.305546875]

1000
Epoch: 335, NLL Loss: 2884.403625, Val Loss: 5765.9716796875, Time took: 0.10540175437927246
Train loss Meta Info:  [2835.34859375   49.05502734]
Val Loss Meta Info:  [5674.65125, 91.3208203125]

1000
Epoch: 336, NLL Loss: 2884.10965625, Val Loss: 5765.7578125, Time took: 0.09566640853881836
Train loss Meta Info:  [2835.103375     49.00624463]
Val Loss Meta Info:  [5674.764375, 90.993828125]

1000
Epoch: 337, NLL Loss: 2883.9609375, Val Loss: 5765.52490234375, Time took: 0.09354472160339355
Train loss Meta Info:  [2835.1241875    48.83672266]
Val Loss Meta Info:  [5674.87875, 90.646064453125]

1000
Epoch: 338, NLL Loss: 2883.86215625, Val Loss: 5765.32666015625, Time took: 0.09282302856445312
Train loss Meta Info:  [2835.1421875   48.7199624]
Val Loss Meta Info:  [5674.7475, 90.5796484375

Epoch: 375, NLL Loss: 2879.53759375, Val Loss: 5759.96435546875, Time took: 0.08824396133422852
Train loss Meta Info:  [2834.0618125    45.47579858]
Val Loss Meta Info:  [5675.14875, 84.81578125]

1000
Epoch: 376, NLL Loss: 2879.4778125, Val Loss: 5759.8330078125, Time took: 0.08854055404663086
Train loss Meta Info:  [2834.0484375    45.42938086]
Val Loss Meta Info:  [5675.14, 84.693203125]

1000
Epoch: 377, NLL Loss: 2879.4106875, Val Loss: 5759.794921875, Time took: 0.08893990516662598
Train loss Meta Info:  [2834.0225625    45.38814453]
Val Loss Meta Info:  [5675.146875, 84.648310546875]

1000
Epoch: 378, NLL Loss: 2879.3364375, Val Loss: 5759.72412109375, Time took: 0.09372830390930176
Train loss Meta Info:  [2834.00115625   45.33530469]
Val Loss Meta Info:  [5675.154375, 84.5697265625]

1000
Epoch: 379, NLL Loss: 2879.27221875, Val Loss: 5759.62744140625, Time took: 0.1105043888092041
Train loss Meta Info:  [2833.98121875   45.29102124]
Val Loss Meta Info:  [5675.16375, 84.4638671

Epoch: 416, NLL Loss: 2878.88846875, Val Loss: 5761.03564453125, Time took: 0.08470940589904785
Train loss Meta Info:  [2833.20778125   45.68072852]
Val Loss Meta Info:  [5676.129375, 84.9059765625]

1000
Epoch: 417, NLL Loss: 2878.8114375, Val Loss: 5762.44287109375, Time took: 0.08428287506103516
Train loss Meta Info:  [2833.19278125   45.61868604]
Val Loss Meta Info:  [5676.26875, 86.174208984375]

1000
Epoch: 418, NLL Loss: 2879.07478125, Val Loss: 5760.28369140625, Time took: 0.08430075645446777
Train loss Meta Info:  [2833.19365625   45.88115723]
Val Loss Meta Info:  [5676.223125, 84.060703125]

1000
Epoch: 419, NLL Loss: 2878.2158125, Val Loss: 5760.7822265625, Time took: 0.08466935157775879
Train loss Meta Info:  [2833.14284375   45.0729668 ]
Val Loss Meta Info:  [5676.25375, 84.5287890625]

1000
Epoch: 420, NLL Loss: 2878.5190625, Val Loss: 5760.80029296875, Time took: 0.0847160816192627
Train loss Meta Info:  [2833.1258125   45.3932002]
Val Loss Meta Info:  [5676.379375, 84.4

Epoch: 457, NLL Loss: 2883.15053125, Val Loss: 5768.45068359375, Time took: 0.08549809455871582
Train loss Meta Info:  [2834.95596875   48.19451367]
Val Loss Meta Info:  [5678.956875, 89.493818359375]

1000
Epoch: 458, NLL Loss: 2882.6681875, Val Loss: 5766.93994140625, Time took: 0.0835576057434082
Train loss Meta Info:  [2834.63934375   48.02885889]
Val Loss Meta Info:  [5677.795625, 89.144296875]

1000
Epoch: 459, NLL Loss: 2881.738125, Val Loss: 5766.72998046875, Time took: 0.08517026901245117
Train loss Meta Info:  [2833.89096875   47.8471499 ]
Val Loss Meta Info:  [5678.03625, 88.69345703125]

1000
Epoch: 460, NLL Loss: 2881.5758125, Val Loss: 5767.43359375, Time took: 0.08397865295410156
Train loss Meta Info:  [2833.917375     47.65844824]
Val Loss Meta Info:  [5678.909375, 88.524228515625]

1000
Epoch: 461, NLL Loss: 2881.963875, Val Loss: 5767.2890625, Time took: 0.09560608863830566
Train loss Meta Info:  [2834.31525      47.64864404]
Val Loss Meta Info:  [5678.82375, 88.46569

Epoch: 498, NLL Loss: 2877.37571875, Val Loss: 5763.0556640625, Time took: 0.08554935455322266
Train loss Meta Info:  [2832.31765625   45.05813672]
Val Loss Meta Info:  [5678.96, 84.095361328125]

1000
Epoch: 499, NLL Loss: 2877.32203125, Val Loss: 5763.044921875, Time took: 0.10144877433776855
Train loss Meta Info:  [2832.2974375    45.02455322]
Val Loss Meta Info:  [5678.99625, 84.0486328125]

