In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions import Normal

In [2]:
# Display the installed PyTorch version as this cell's output.
torch.__version__

'1.0.1'

In [3]:
# Report whether CUDA is usable, then select the matching device object that
# the rest of the notebook moves models onto.
print(torch.cuda.is_available())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

True


# Define RMTPP Module

In [4]:
class RMTPP(nn.Module):
    """Recurrent marked temporal point process with Gaussian (real-valued) markers.

    Markers and time features are embedded separately, concatenated, and fed
    through a ReLU ``nn.RNN``. The resulting hidden states parameterise a
    diagonal Gaussian over the markers and a closed-form log-density over the
    inter-event intervals. ``forward`` returns the summed negative
    log-likelihood of both parts.

    Args:
        marker_type: Kind of marker. It is stored on the instance but not
            consulted anywhere in this class; only the Gaussian marker path is
            implemented.
        marker_dim: Size of the last dimension of the marker tensors.
    """

    def __init__(self, marker_type='real', marker_dim=20):
        super().__init__()

        # Kept for callers that pass it; no code path below branches on it.
        self.marker_type = marker_type
        self.marker_dim = marker_dim

        # Dimensions for embedding inputs
        self.marker_embed_dim = 64
        self.time_embed_dim = 64
        # Networks for embedding inputs
        self.create_embedding_nets()

        # This is the layer that encodes the history
        self.hidden_layer_dim = 128
        # Create RNN layer Network
        self.create_rnn()

        # Hidden shared layer size (between mu and logvar)
        # while generating marker from hidden state
        self.marker_shared_dim = 64
        # Create networks that turn hidden states into marker / time likelihoods
        self.create_marker_generation_net()
        self.create_time_likelihood_net()

    ############ UTILITY METHODS #############

    def _one_hot_marker(self, marker_seq):
        """Placeholder for one-hot encoding of markers; not implemented.

            Input:
            marker_seq: Tensor of shape TxBSx1
            Output:
            None (the method body is empty).
        """
        pass

    def _reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = sqrt(exp(logvar))``."""
        epsilon = torch.randn_like(mu)
        return mu + epsilon*logvar.exp().sqrt()


    ############ NETWORKS #############

    def create_rnn(self):
        """Build the recurrent encoder and store it as ``self.rnn``.

        Reads ``marker_embed_dim``, ``time_embed_dim`` and ``hidden_layer_dim``
        from the instance, so ``create_embedding_nets`` sizes and
        ``hidden_layer_dim`` must be set before this is called.
        """
        self.rnn = nn.RNN(
            input_size=self.marker_embed_dim+self.time_embed_dim,
            hidden_size=self.hidden_layer_dim,
            nonlinearity='relu'
        )

    def create_embedding_nets(self):
        """Build the marker embedding (linear) and time embedding (2-layer MLP)."""
        # marker_dim is passed. timeseries_dim is 2
        self.marker_embedding_net = nn.Linear(self.marker_dim, self.marker_embed_dim)

        self.time_embedding_net = nn.Sequential(
            nn.Linear(2, self.time_embed_dim),
            nn.ReLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim)
        )

    def create_marker_generation_net(self):
        """
            Generate network to create marker sufficient statistics using
            rnn's hidden layer: a shared ReLU layer followed by separate
            heads for the mean and the log-variance.
        """
        self.marker_gen_hidden = nn.Sequential(
            nn.Linear(self.hidden_layer_dim, self.marker_shared_dim),
            nn.ReLU()
        )
        self.generated_marker_mu = nn.Linear(self.marker_shared_dim, self.marker_dim)
        # NOTE(review): Softplus makes the log-variance strictly positive, so
        # the marker standard deviation can never drop below 1. Confirm this
        # floor is intended.
        self.generated_marker_logvar = nn.Sequential(
            nn.Linear(self.marker_shared_dim, self.marker_dim),
            nn.Softplus()
        )

    def create_time_likelihood_net(self):
        """Create the parameters of the interval log-density.

        ``h_influence`` projects a hidden state to a scalar, ``base_intensity``
        is a learned offset (initialised to 0) and ``wt`` is the learned
        coefficient on the interval (initialised to 1). Both parameters have
        shape 1x1x1 so they broadcast over T x BS x 1 tensors.
        """
        self.h_influence = nn.Linear(self.hidden_layer_dim, 1, bias=False)
        self.base_intensity = nn.Parameter(torch.zeros(1,1,1))
        self.wt = nn.Parameter(torch.ones(1,1,1))


    ############ METHODS #############

    def _embed_data(self, marker_seq, time_seq):
        """
            Input:
            marker_seq: Tensor of shape TxBSx marker_dim
            time_seq: Tensor of shape TxBSx 2
            Output:
            marker_seq_emb: Tensor of shape T x BS x marker_embed_dim
            time_seq_emb: Tensor of shape T x BS x time_embed_dim
        """
        marker_seq_emb = self.marker_embedding_net(marker_seq)
        time_seq_emb = self.time_embedding_net(time_seq)
        return marker_seq_emb, time_seq_emb



    def marker_log_likelihood(self, h_seq, marker_seq):
        """
        Use the h_seq to generate the distribution for the markers,
        and use that distribution to compute the log likelihood of the marker_seq
            Input:
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim
                    marker_seq : Tensor of shape T x BS x marker_dim
            Output:
                    log_likelihood_marker_seq : T x BS x marker_dim
                    (element-wise Gaussian log-density, not yet summed)
        """
        marker_gen_shared = self.marker_gen_hidden(h_seq)
        mu, logvar = self.generated_marker_mu(marker_gen_shared), self.generated_marker_logvar(marker_gen_shared)

        # Normal takes a standard deviation, hence sqrt(exp(logvar)).
        marker_gen_dist = Normal(mu, logvar.exp().sqrt())
        log_likelihood_marker_seq = marker_gen_dist.log_prob(marker_seq)

        return log_likelihood_marker_seq


    def time_log_likelihood(self, h_seq, time_seq):
        """
        Use the h_seq to compute the log likelihood of the time_seq
        using the formula in the paper
            Input:
                    h_seq   : Tensor of shape T x BS x hidden_layer_dim
                    time_seq : Tensor of shape T x BS x 2 . Last dimension is [times, intervals]
            Output:
                    log_likelihood_time_seq : T x BS x 1
        """
        past_influence = self.h_influence(h_seq)
        # Slice 1:2 (rather than index 1) keeps the trailing singleton dim.
        current_influence = self.wt * time_seq[:,:,1:2]
        base_intensity = self.base_intensity

        term1 = past_influence + current_influence + base_intensity
        term2 = past_influence + base_intensity

        # After factorizing the formula in the paper:
        #   term1 + (exp(term2) - exp(term1)) / wt
        # NOTE(review): the exponentials are unbounded and the expression
        # divides by the learned wt; large intervals or wt near 0 can produce
        # very large magnitudes.
        log_likelihood_time_seq = term1 + (term2.exp() - term1.exp())/self.wt

        return log_likelihood_time_seq

    def forward(self, marker_seq, time_seq):
        """Compute the total negative log-likelihood of a batch of sequences.

            Input:
                    marker_seq : Tensor of shape T x BS x marker_dim
                    time_seq : Tensor of shape T x BS x 2
            Output:
                    NLL : scalar tensor, -(sum of marker LL + sum of time LL)
                    meta : [marker NLL, time NLL] as Python floats
        """
        # Transform markers and timesteps into the embedding spaces
        marker_seq_emb, time_seq_emb = self._embed_data(marker_seq, time_seq)
        _, batch_size, _ = marker_seq_emb.shape

        # Run RNN over the concatenated sequence [marker_seq_emb, time_seq_emb]
        time_marker_combined = torch.cat([marker_seq_emb, time_seq_emb], dim=-1)
        # Allocate the initial state from the input tensor so it always matches
        # the input's device and dtype instead of relying on a global `device`.
        h_0 = time_marker_combined.new_zeros(1, batch_size, self.hidden_layer_dim)
        hidden_seq, _ = self.rnn(time_marker_combined, h_0)

        # compute the marker and time log likelihoods
        # NOTE(review): hidden_seq[t] has already consumed marker_seq[t] and
        # time_seq[t], and is used here to score that same step. Confirm the
        # data is aligned so this is not leaking the target into the input.
        marker_ll = self.marker_log_likelihood(hidden_seq, marker_seq)
        time_ll = self.time_log_likelihood(hidden_seq, time_seq)

        likelihood_loss = marker_ll.sum() + time_ll.sum()
        NLL = -likelihood_loss
        return NLL, [-marker_ll.sum().item(), -time_ll.sum().item() ]
        

In [5]:
# _data, _ = generate_mpp(num_sample=200)
# model = RMTPP().to(device)
# model(_data['x'], _data['t'])

# Training

In [6]:
from utils_ import generate_mpp

In [17]:
from trainer import train
from rmtpp import rmtpp

In [18]:
def trainer(model, data = None, val_data=None, lr= 1e-3, epoch = 500, batch_size = 64):
    """Optimise ``model`` with Adam by calling ``train`` once per epoch.

    Args:
        model: Module whose ``parameters()`` are handed to Adam.
        data: Training data forwarded to ``train``. When ``None``, both
            ``data`` and ``val_data`` are taken from ``generate_mpp()``.
        val_data: Validation data forwarded to ``train``.
        lr: Adam learning rate.
        epoch: Number of epochs to run (``train`` is called this many times).
        batch_size: Batch size forwarded to ``train``.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Identity check: `== None` would dispatch to the object's __eq__, which
    # container and array types are free to override.
    if data is None:
        # NOTE(review): this also replaces any caller-supplied val_data, and
        # other cells discard generate_mpp's second return value. Confirm that
        # value really is a validation set.
        data, val_data = generate_mpp()

    optimizer = Adam(model.parameters(), lr=lr)

    for epoch_number in range(epoch):
        train(model, epoch_number, data, optimizer, batch_size, val_data)
    return model

In [21]:
def main(model, data, val_data):
    """Instantiate a model, report the dataset shapes, and train it.

    Args:
        model: Zero-argument callable (e.g. a model class) that returns the
            module to train.
        data: Training data; must expose tensors under keys ``'t'`` and ``'x'``.
        val_data: Validation data with the same keys.

    Returns:
        The trained model, as returned by ``trainer``.
    """
    # Use a separate name so the factory argument is not rebound to an instance.
    net = model().to(device)
    print("Times: Data Shape: {}, Val Data Shape: {}".format(data['t'].shape, val_data['t'].shape))
    print("Markers: Data Shape: {}, Val Data Shape: {}".format(data['x'].shape, val_data['x'].shape))
    # Hand the trained model back so it stays available after training.
    return trainer(net, data=data, val_data=val_data)

In [None]:
# Generate the training and validation sets; only the first value returned by
# generate_mpp is kept in each case.
data, _ = generate_mpp(type='hawkes', num_sample=1000)
val_data, _ = generate_mpp(type='hawkes', num_sample=200)

# Build the imported `rmtpp` model and train it on the generated data.
main(
    model=rmtpp,
    data=data,
    val_data=val_data,
)

Times: Data Shape: torch.Size([100, 1000, 2]), Val Data Shape: torch.Size([100, 200, 2])
Markers: Data Shape: torch.Size([100, 1000, 20]), Val Data Shape: torch.Size([100, 200, 20])
Epoch: 0, NLL Loss: 18633445881.36, Val Loss: 77549176.0, Time took: 1.2529633045196533
Train loss Meta Info:  [114.342859375, 73973.088]
Val Loss Meta Info:  [5703.70625, 77543480.32]

Epoch: 1, NLL Loss: 17090740647.76, Val Loss: 71676648.0, Time took: 1.2832868099212646
Train loss Meta Info:  [114.31271875, 69776.216]
Val Loss Meta Info:  [5702.495, 71670958.08]

Epoch: 2, NLL Loss: 15706279251.156, Val Loss: 66108732.0, Time took: 1.1665372848510742
Train loss Meta Info:  [114.287203125, 65687.616]
Val Loss Meta Info:  [5700.2225, 66103040.0]

Epoch: 3, NLL Loss: 14370354843.888, Val Loss: 60820172.0, Time took: 1.4386906623840332
Train loss Meta Info:  [114.240140625, 61854.432]
Val Loss Meta Info:  [5698.41625, 60814484.48]

Epoch: 4, NLL Loss: 13163647470.644, Val Loss: 55654144.0, Time took: 1.51972

Epoch: 44, NLL Loss: 70089975.363125, Val Loss: 521759.28125, Time took: 1.000394344329834
Train loss Meta Info:  [113.634453125, 609.01625]
Val Loss Meta Info:  [5674.523125, 516084.8]

Epoch: 45, NLL Loss: 61167920.16453125, Val Loss: 462509.1875, Time took: 1.2740373611450195
Train loss Meta Info:  [113.6175703125, 538.796625]
Val Loss Meta Info:  [5673.7925, 456835.44]

Epoch: 46, NLL Loss: 53525516.587875, Val Loss: 410513.4375, Time took: 1.3330769538879395
Train loss Meta Info:  [113.6046015625, 477.9215]
Val Loss Meta Info:  [5673.340625, 404840.0]

Epoch: 47, NLL Loss: 46973814.953875, Val Loss: 365185.5625, Time took: 1.2974839210510254
Train loss Meta Info:  [113.6016875, 425.1186875]
Val Loss Meta Info:  [5673.30375, 359512.16]

Epoch: 48, NLL Loss: 41349963.09453125, Val Loss: 325865.8125, Time took: 1.5790579319000244
Train loss Meta Info:  [113.60975, 379.3118125]
Val Loss Meta Info:  [5673.5875, 320192.22]

Epoch: 49, NLL Loss: 36506859.68409375, Val Loss: 291748.84375,

Epoch: 87, NLL Loss: 1431104.223234375, Val Loss: 26480.173828125, Time took: 0.7957284450531006
Train loss Meta Info:  [113.573703125, 43.47289453125]
Val Loss Meta Info:  [5671.495625, 20808.68]

Epoch: 88, NLL Loss: 1357731.061875, Val Loss: 25793.78515625, Time took: 0.7976751327514648
Train loss Meta Info:  [113.573046875, 42.96184375]
Val Loss Meta Info:  [5671.48625, 20122.3025]

Epoch: 89, NLL Loss: 1289783.8505625, Val Loss: 25154.779296875, Time took: 0.8050529956817627
Train loss Meta Info:  [113.5711796875, 42.49659375]
Val Loss Meta Info:  [5671.49125, 19483.29]

Epoch: 90, NLL Loss: 1226789.1778125, Val Loss: 24559.1015625, Time took: 0.7849588394165039
Train loss Meta Info:  [113.570640625, 42.07246484375]
Val Loss Meta Info:  [5671.466875, 18887.635]

Epoch: 91, NLL Loss: 1168302.4654375, Val Loss: 24003.173828125, Time took: 1.0338473320007324
Train loss Meta Info:  [113.5697578125, 41.68563671875]
Val Loss Meta Info:  [5671.458125, 18331.715]

Epoch: 92, NLL Loss: 111