In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
from LSTM_construction import CTLSTM
import numpy as np

In [2]:
# Load the synthetic data set produced by the SE_MPP-synthetic notebook:
# event types, event times and the generating parameters (Mu, alpha, Delta).
try:
    Events = np.load("Events.npy")
    Times = np.load("Times.npy")
    Mu = np.load("Mu.npy")
    alpha = np.load("alpha.npy")
    delta = np.load("Delta.npy")
except FileNotFoundError as err:
    # Only a genuinely missing file gets the "run the other notebook" hint.
    # Any other failure (corrupt file, interrupt, ...) propagates unchanged
    # instead of being mislabelled as a missing file.
    raise FileNotFoundError("Required files not found. Please Run the SE_MPP-synthetic notebook first.") from err

In [3]:
# Sanity check of the loaded arrays: per the output below, Times has one more
# column than Events.
Events.shape, Times.shape

((100, 50), (100, 51))

In [4]:
# Let's make a random mask to simulate variable lengths of the sequences.
# A dedicated, seeded generator keeps the drawn lengths identical across
# "Restart Kernel -> Run All", so a failing check further down can be
# reproduced and debugged.
rng = np.random.default_rng(0)

# mask[i, j] is True while event j of sequence i is kept
mask = pt.ones(Events.shape[0], Events.shape[1], dtype=pt.bool)
# Now let's assign some random lengths to them - minimum length 10 (just a choice)
SeqLens = pt.zeros(Events.shape[0]).int()
# tMax[i] is the time at which the truncated sequence i ends
tMax = pt.zeros(Events.shape[0])
for i in range(mask.shape[0]):
    # Upper bound is exclusive, so at least one trailing event is masked out
    rn = int(rng.integers(10, Events.shape[1]))
    SeqLens[i] = rn
    tMax[i] = Times[i, rn]
    mask[i, rn:] = False

In [5]:
# Make the events one hot encoded.
# Indexing the 5 x 5 identity matrix with the integer event types yields, for
# every (sequence, step), the row that has a single 1.0 at the event's index.
# The result has shape Events.shape + (5,) and dtype float64, exactly what the
# explicit double loop over sequences and steps produced, without the Python loop.
Events_one_hot = np.eye(5)[Events]

In [6]:
# Turn the three NumPy arrays into torch tensors (pt.tensor copies the data):
# one-hot events, raw event indices and event times, in that order.
EvTens, EvIndTens, timeTensor = map(pt.tensor, (Events_one_hot, Events, Times))

In [7]:
# Build the continuous-time LSTM with K=5 event types and hD=64 ...
net = CTLSTM(K=5, hD=64)
# ... and switch it to float64 so it matches the float64 batches fed to it.
net = net.double()

In [8]:
# Take the first N_batch sequences as a minibatch to exercise the network.
N_batch = 10
# One-hot events and event times as float64, event indices as int64;
# the mask is sliced as-is.
EvBatch = EvTens[:N_batch].to(pt.float64)
maskBatch = mask[:N_batch]
EvIndBatch = EvIndTens[:N_batch].to(pt.int64)
timeBatch = timeTensor[:N_batch].to(pt.float64)

In [9]:
# Confirm the minibatch shapes of the one-hot events and of the event times.
EvBatch.shape, timeBatch.shape

(torch.Size([10, 50, 5]), torch.Size([10, 51]))

In [10]:
# %%timeit
# 1. Do the forward pass
# Call the module itself instead of .forward(): nn.Module.__call__ runs any
# registered hooks and then dispatches to forward(), whereas a direct
# .forward() call silently skips them.
# (assumes CTLSTM subclasses torch.nn.Module - TODO confirm)
lambOuts, CLows, Cbars, deltas, OutGates = net(EvBatch, maskBatch, timeBatch)

In [11]:
# %%timeit
# 2. Monte Carlo estimate of the integral term.
# Alongside the estimate, MC_Loss hands back the random times it drew and the
# indices it looked up for them, so both can be checked for correctness later.
# Positional order (from the original note):
#   times, tMax, Clows, Cbars, deltas, OutGates, Nsamples=1000
LMC, trandsLMC, t_upLMC = net.MC_Loss(
    timeBatch,
    tMax[:N_batch],
    CLows,
    Cbars,
    deltas,
    OutGates,
    Nsamples=500,
)

In [12]:
# 3. Evaluate the log likelihood
# Batched value computed by the network from the event indices, the mask and
# the intensities of the forward pass; it is the quantity that the explicit
# per-sequence loop below is compared against.
LogLikeLoss = net.logLoss(EvIndBatch, maskBatch, lambOuts)

# We test the three components - log-likelihood calculation, MC integration and forward pass

## First, the log likelihood - the easiest one
## The goal is to go through each sequence individually and check that the batch-processing results (computed with advanced indexing) match the per-sequence results.

In [13]:
# Reference computation of the log-likelihood loss, one sequence at a time.
loss_total = 0

for b in range(N_batch):
    # Event types (0, 1, 2, 3 or 4) that survive the mask for this sequence
    occurred = Events[b][mask[b]]
    # Log-intensity of the observed event type at every step; summed from a
    # float zero, left to right, to obtain this sequence's log-likelihood
    log_intensities = (
        pt.log(lambOuts[b, step, ev_type])
        for step, ev_type in enumerate(occurred)
    )
    # The loss is the negative log-likelihood, accumulated over the batch
    loss_total -= sum(log_intensities, 0.)

In [14]:
# The batch-averaged reference loss has to agree with the network's batched
# result (tolerances are torch.allclose's defaults, written out explicitly).
assert pt.allclose(loss_total/N_batch, LogLikeLoss, rtol=1e-05, atol=1e-08)
print("Log likelihood part correct")

Log likelihood part correct


## Second, the MC integration part

In [15]:
# For every sampled time, find the index of the first event time that is not
# smaller than it (left-sided search, the default) ...
t_up = pt.searchsorted(timeBatch, trandsLMC, right=False)
# ... and require an exact match with the indices reported by MC_Loss.
assert pt.equal(t_up, t_upLMC)

In [16]:
# Sampled times returned by MC_Loss: per the output below, one row per batch
# sequence and one column per drawn sample.
trandsLMC.shape

torch.Size([10, 500])

In [17]:
# Now, we'll evaluate the integral individually for each sequence and sample,
# re-deriving the intensity at every drawn time from the quantities returned
# by the forward pass.

# Get the no. of random time samples that were drawn
Nsamples = trandsLMC.shape[1]
lamb_total = 0.

# This is a pure verification pass: no gradients are needed, so autograd
# bookkeeping is switched off for the many small per-sample operations.
with pt.no_grad():
    # Go through each sequence
    for i in range(N_batch):
        # Get the times for this sequence (SeqLens[i] events -> SeqLens[i] + 1 times)
        times_seq = timeBatch[i, :SeqLens[i] + 1]
        # Go through the randomly selected times
        lamb_tot_seq = 0.
        for sample in range(Nsamples):
            t_drawn = trandsLMC[i, sample]
            # Index of the first time stamp that is not smaller than t_drawn
            idx = pt.searchsorted(times_seq, t_drawn)
            # Check if the index is correct
            assert idx == t_up[i, sample]

            # Next, we need to evaluate the total rate at this time.
            # To do this, we need the LSTM values of the interval that
            # contains t_drawn, i.e. those stored at position idx - 1.
            tlow = times_seq[idx - 1]
            clow = CLows[i, idx - 1]
            cbar = Cbars[i, idx - 1]
            o = OutGates[i, idx - 1]
            # Named `decay` so that the `delta` array loaded from Delta.npy
            # at the top of the notebook is not overwritten.
            decay = deltas[i, idx - 1]

            # We then compute c(t): exponential relaxation from clow to cbar
            ct = cbar + (clow - cbar)*pt.exp(-(t_drawn - tlow)*decay)

            # Then h(t)
            ht = o * (2*net.sigma(2*ct) - 1)

            # Then lambda_tilde
            lamb_til = net.L_lamb_til(ht)

            # Then lambda
            # evaluate this element-wise too - good for checking correctness.
            # zeros_like keeps the dtype of lamb_til (no silent down-cast of
            # the reference values) and takes the number of event types from
            # the network output instead of hard-coding it.
            lamb = pt.zeros_like(lamb_til)
            for k in range(lamb_til.shape[0]):
                s_k = net.scale[k]
                lamb[k] = s_k * pt.log(1 + pt.exp(lamb_til[k] / s_k))

            # Add the total intensity at this random time to the total for this sequence
            lamb_tot_seq += pt.sum(lamb)

        # Add the lambda*End_time for this sequence to the total
        # (mean intensity over the samples times the length of the time window)
        lamb_total += lamb_tot_seq*times_seq[-1]/Nsamples

In [18]:
# Batch average of the per-sequence reference integrals vs. the batched Monte
# Carlo estimate (tolerances are torch.allclose's defaults, written out).
assert pt.allclose(lamb_total/N_batch, LMC.double(), rtol=1e-05, atol=1e-08)
print("MC integral result correct")

MC integral result correct


## So far we have shown that, if the forward pass is correct, then the log-likelihood and the Monte Carlo integral estimate are also correct.
## Third, we now confirm that the forward pass itself is correct.
## As before, we'll do the forward pass explicitly, sample by sample, and confirm that our results match the batch-processing results of the neural network, which uses advanced PyTorch indexing for faster computation.

In [19]:
# Let's do the forward pass sample by sample and compare every intermediate
# quantity with what the batched forward pass returned.
N_events = Events.shape[1]
hD = net.hD

# Verification only: gradients are not needed, so skip autograd bookkeeping.
with pt.no_grad():
    for seq in range(N_batch):
        # Get the one-hot encoded sequence and its time stamps
        Events_seq = EvBatch[seq, :SeqLens[seq]]
        times_seq = timeBatch[seq, :SeqLens[seq] + 1]

        # Initialize ht, ct and cbar to zero
        ht = pt.zeros(hD).double()
        ct = pt.zeros(hD).double()
        cbar = pt.zeros(hD).double()
        # Now go through this sequence
        for evInd in range(SeqLens[seq]):

            # Get the one-hot encoded event
            event = Events_seq[evInd]
            # get the linear output; it is consumed below in seven
            # consecutive chunks of size hD
            Linout = net.L_U(event) + net.L_V(ht)

            # get the non linear outputs: input/forget gates ...
            in_gate = net.sigma(Linout[:hD])
            forget_gate = net.sigma(Linout[hD:2*hD])

            # ... and their counterparts for the target cell cbar
            iBar, fBar = net.sigma(Linout[2*hD:3*hD]), net.sigma(Linout[3*hD:4*hD])

            # Remember to multiply "z" by 2
            # NOTE(review): the Neural Hawkes formulation usually writes
            # z = 2*sigma(.) - 1; this line mirrors what the assertions below
            # accept from the network - confirm against CTLSTM.
            z, o = 2*net.sigma(Linout[4*hD:5*hD]), net.sigma(Linout[5*hD:6*hD])

            # Named `decay` so that the `delta` array loaded from Delta.npy
            # at the top of the notebook is not overwritten.
            decay = F.softplus(Linout[6*hD:7*hD])

            # Cell value right after the event and the value it relaxes to
            clow = forget_gate * ct + in_gate * z
            cbar = fBar * cbar + iBar * z

            tnow = times_seq[evInd]
            tnext = times_seq[evInd + 1]

            # Let the cell relax from clow towards cbar until the next time stamp
            ct = cbar + (clow - cbar)*pt.exp(-(tnext - tnow)*decay)

            ht = o * (2*net.sigma(2*ct) - 1)

            lamb_til = net.L_lamb_til(ht)

            # Element-wise intensity; the number of event types is taken from
            # the network output instead of being hard-coded.
            lamb = pt.zeros_like(lamb_til)
            for k in range(lamb_til.shape[0]):
                s_k = net.scale[k]
                lamb[k] = s_k * pt.log(1 + pt.exp(lamb_til[k] / s_k))

            # Now check correctness; a failure reports "<sequence> <step>"
            where = "{} {}".format(seq, evInd)
            assert pt.allclose(decay, deltas[seq, evInd]), where
            assert pt.allclose(clow, CLows[seq, evInd]), where
            assert pt.allclose(cbar, Cbars[seq, evInd]), where
            assert pt.allclose(o, OutGates[seq, evInd]), where
            assert pt.allclose(lamb, lambOuts[seq, evInd]), where

print("Forward pass sample by sample and step by step checks complete")

Forward pass sample by sample and step by step checks complete
