In [None]:
# Add the parent directory (relative to the current working directory) to the
# module search path so that modules stored one level up can be imported.
# NOTE(review): presumably this is where LSTM_construction lives -- confirm.
import sys
sys.path.append("../")

In [1]:
import numpy as np
import torch as pt
import h5py
import torch.nn as nn
import torch.nn.functional as F
from LSTM_construction import CTLSTM
from tqdm import tqdm

# First, we load the data

In [2]:
# Read the four training datasets from the HDF5 file into in-memory NumPy
# arrays; the file handle is closed as soon as the `with` block exits.
with h5py.File("SOTrainData.h5", "r") as fl:
    EventsData, timesData, timeMaxData, SeqLengthData = [
        np.array(fl[key])
        for key in ("EventsData", "TimesData", "TimeMaxData", "SeqLengthData")
    ]

In [3]:
# One-hot encode the events.
# Events_one_hot[seq, step, ev] is 1.0 when sequence `seq` has event type `ev`
# at position `step`; every other entry stays 0.0. Only the first
# SeqLengthData[seq] steps of a sequence are encoded, so the remaining
# (padding) steps are all-zero rows.
N_train = EventsData.shape[0]
N_seq_Max = EventsData.shape[1]
N_types = np.max(EventsData) + 1
Events_one_hot = np.zeros((N_train, N_seq_Max, N_types))

# Boolean (N_train, N_seq_Max) grid marking the steps that lie inside each
# sequence's recorded length.
valid_steps = (
    np.arange(N_seq_Max)[None, :]
    < np.asarray(SeqLengthData)[:N_train].reshape(-1, 1)
)

# Set every valid (seq, step, event) entry in one vectorized assignment
# instead of a Python-level double loop.
seq_idx, step_idx = np.nonzero(valid_steps)
Events_one_hot[seq_idx, step_idx, EventsData[seq_idx, step_idx]] = 1.0

In [4]:
# Display the largest value stored in timeMaxData.
timeMaxData.max()

1388.5341796875

In [5]:
# Convert the NumPy arrays to torch tensors.
# Both time arrays are divided by `timeScale` first: when the maximum times are
# very large they can be scaled down so that gradients don't explode. This does
# not affect the final results as long as testing times are divided by the
# same factor.
# NOTE: timeScale is currently 1.0, so the times are left unscaled.
timeScale = 1.0
EvTens = pt.tensor(Events_one_hot).to(pt.float64)
EvIndTens = pt.tensor(EventsData).to(pt.int64)
timeTensor = pt.tensor(timesData / timeScale).to(pt.float64)
tMaxTensor = pt.tensor(timeMaxData / timeScale).to(pt.float64)

In [6]:
# Sanity check: the first time entry (index 0 along dim 1) of every row is zero.
assert (timeTensor[:, 0] == 0).all()

In [8]:
# Then produce the mask: True wherever the event index is at least -1 + 0.001.
# EvIndTens holds integers, so this is the same as "index >= 0".
# NOTE(review): presumably -1 is the padding value in EventsData -- confirm.
mask = pt.ge(EvIndTens, -1 + 0.001)

# Now, we make the training loop

In [14]:
# Create the network
reload = False # Set to False if starting from scratch
if reload: # load the last saved network
    # Number of epochs already completed by the checkpoint to resume from;
    # this has to be changed if we reload from a different checkpoint.
    N_saved_last = 21
    try:
        # NOTE(review): pt.load unpickles a complete module object, which can
        # execute arbitrary code -- only load checkpoints you created yourself.
        net = pt.load('SO_Nets/TrainNet_SO_{}ep.pt'.format(N_saved_last))
    except FileNotFoundError:
        # Only a missing file triggers a fresh start. Any other failure (e.g. a
        # corrupt checkpoint) is raised, because silently restarting from epoch
        # 0 would overwrite the existing checkpoints.
        print("No saved network found. Starting from scratch")
        N_saved_last = 0
        net = CTLSTM(K=N_types, hD=64).double()

else: # initialize from scratch
    N_saved_last = 0
    net = CTLSTM(K=N_types, hD=64).double()

# Train on (at most) the first 4000 sequences; capped at the number of
# sequences actually loaded so that the permutation below can never index past
# the end of the data tensors.
N_train = min(4000, EvTens.shape[0])
N_epoch = 45 - N_saved_last
BatchSize = 10
# Initiate Adam optimizer
# Note: only the network is checkpointed, so the optimizer state always starts
# fresh here, even when the network was reloaded.
optimizer = pt.optim.Adam(net.parameters(), lr=0.001)

# start training
for epoch in range(N_epoch):

    print("Epoch:{}".format(epoch+N_saved_last), flush=True)

    # Checkpoint: the file numbered k holds the network after k completed epochs
    pt.save(net, 'SO_Nets/TrainNet_SO_{}ep.pt'.format(epoch+N_saved_last))

    # permutation for this epoch
    perm = pt.randperm(N_train)
    for batchInd in tqdm(range(0, N_train, BatchSize), position=0, leave=True):
        # Gather the necessary inputs.
        # Index with this batch's slice of the permutation only; X[perm][a:b]
        # gives the same rows but copies the whole permuted dataset each time.
        batchIdx = perm[batchInd:batchInd+BatchSize]
        BatchEventsHot = EvTens[batchIdx]
        BatchEventsInd = EvIndTens[batchIdx]
        BatchTimes = timeTensor[batchIdx]
        BatchTMax = tMaxTensor[batchIdx]
        BatchMask = mask[batchIdx]

        optimizer.zero_grad()

        # Do the forward pass (calling the module dispatches to forward() and
        # also runs any registered hooks)
        lambOuts, CLows, Cbars, deltas, OutGates = net(BatchEventsHot, BatchMask, BatchTimes)
        # Calculate the MC loss
        LMC, trandsLMC, t_upLMC = net.MC_Loss(BatchTimes, BatchTMax,
                                      CLows, Cbars, deltas, OutGates,
                                      Nsamples=N_seq_Max)
        # Calculate the likelihood
        LogLikeLoss = net.logLoss(BatchEventsInd, BatchMask, lambOuts)

        loss = LogLikeLoss + LMC

        loss.backward()
        optimizer.step()

# Save the fully trained network as well: the in-loop checkpoint is written
# *before* each epoch, so without this the result of the last epoch is lost.
if N_epoch > 0:
    pt.save(net, 'SO_Nets/TrainNet_SO_{}ep.pt'.format(N_saved_last + N_epoch))

Epoch:0


100%|██████████| 400/400 [10:15<00:00,  1.54s/it]

Epoch:1



100%|██████████| 400/400 [10:17<00:00,  1.54s/it]

Epoch:2



100%|██████████| 400/400 [10:29<00:00,  1.57s/it]

Epoch:3



100%|██████████| 400/400 [11:57<00:00,  1.79s/it]

Epoch:4



100%|██████████| 400/400 [11:55<00:00,  1.79s/it]

Epoch:5



100%|██████████| 400/400 [12:15<00:00,  1.84s/it]

Epoch:6



100%|██████████| 400/400 [11:05<00:00,  1.66s/it]

Epoch:7



100%|██████████| 400/400 [10:33<00:00,  1.58s/it]

Epoch:8



100%|██████████| 400/400 [11:47<00:00,  1.77s/it]

Epoch:9



100%|██████████| 400/400 [11:33<00:00,  1.73s/it]

Epoch:10



100%|██████████| 400/400 [12:15<00:00,  1.84s/it]

Epoch:11



100%|██████████| 400/400 [12:54<00:00,  1.94s/it]

Epoch:12



100%|██████████| 400/400 [12:44<00:00,  1.91s/it]

Epoch:13



100%|██████████| 400/400 [12:48<00:00,  1.92s/it]

Epoch:14



100%|██████████| 400/400 [12:39<00:00,  1.90s/it]

Epoch:15



100%|██████████| 400/400 [13:09<00:00,  1.97s/it]

Epoch:16



100%|██████████| 400/400 [12:45<00:00,  1.91s/it]

Epoch:17



100%|██████████| 400/400 [12:49<00:00,  1.92s/it]

Epoch:18



100%|██████████| 400/400 [13:10<00:00,  1.98s/it]

Epoch:19



100%|██████████| 400/400 [12:23<00:00,  1.86s/it]

Epoch:20



100%|██████████| 400/400 [10:50<00:00,  1.63s/it]

Epoch:21



100%|██████████| 400/400 [11:08<00:00,  1.67s/it]

Epoch:22



100%|██████████| 400/400 [10:23<00:00,  1.56s/it]

Epoch:23



100%|██████████| 400/400 [10:20<00:00,  1.55s/it]

Epoch:24



100%|██████████| 400/400 [12:40<00:00,  1.90s/it]

Epoch:25



100%|██████████| 400/400 [13:21<00:00,  2.00s/it]

Epoch:26



100%|██████████| 400/400 [13:06<00:00,  1.97s/it]

Epoch:27



100%|██████████| 400/400 [12:13<00:00,  1.83s/it]

Epoch:28



100%|██████████| 400/400 [12:22<00:00,  1.86s/it]

Epoch:29



100%|██████████| 400/400 [12:22<00:00,  1.86s/it]

Epoch:30



100%|██████████| 400/400 [12:41<00:00,  1.90s/it]

Epoch:31



100%|██████████| 400/400 [13:45<00:00,  2.06s/it]

Epoch:32



100%|██████████| 400/400 [12:42<00:00,  1.91s/it]

Epoch:33



100%|██████████| 400/400 [11:53<00:00,  1.78s/it]

Epoch:34



100%|██████████| 400/400 [10:17<00:00,  1.54s/it]

Epoch:35



100%|██████████| 400/400 [10:15<00:00,  1.54s/it]

Epoch:36



100%|██████████| 400/400 [10:22<00:00,  1.56s/it]

Epoch:37



100%|██████████| 400/400 [10:21<00:00,  1.55s/it]

Epoch:38



100%|██████████| 400/400 [11:34<00:00,  1.74s/it]

Epoch:39



100%|██████████| 400/400 [13:07<00:00,  1.97s/it]

Epoch:40



100%|██████████| 400/400 [13:09<00:00,  1.97s/it]

Epoch:41



100%|██████████| 400/400 [13:29<00:00,  2.02s/it]

Epoch:42



100%|██████████| 400/400 [13:24<00:00,  2.01s/it]

Epoch:43



100%|██████████| 400/400 [11:16<00:00,  1.69s/it]

Epoch:44



100%|██████████| 400/400 [13:14<00:00,  1.99s/it]
