In [None]:
import sys
# Put the parent directory on the module search path so modules located there
# can be imported (presumably where LSTM_construction.py lives — TODO confirm).
sys.path.append("../")

In [None]:
import numpy as np
import torch as pt
import h5py
import torch.nn as nn
import torch.nn.functional as F
from LSTM_construction import CTLSTM
from tqdm import tqdm

# First, we load the data

In [None]:
# Read the four training datasets from the HDF5 file into in-memory NumPy arrays.
# The generator is consumed by the tuple unpacking while the file is still open.
with h5py.File("RetweetTrainData.h5", "r") as fl:
    EventsData, timesData, timeMaxData, SeqLengthData = (
        np.array(fl[key])
        for key in ("EventsData", "TimesData", "TimeMaxData", "SeqLengthData")
    )

In [None]:
# One-hot encode the event types.
# Events_one_hot[seq, step, k] is 1.0 when step `step` of sequence `seq` has
# event type k. Only the first SeqLengthData[seq] steps of each sequence are
# encoded; every later (padded) position stays all-zero.
N_train = EventsData.shape[0]
N_seq_Max = EventsData.shape[1]
N_types = 3
Events_one_hot = np.zeros((N_train, N_seq_Max, N_types))

# A sequence cannot be longer than the padded width of EventsData.
assert np.all(SeqLengthData <= N_seq_Max), "sequence length exceeds padded width"

# Boolean (N_train, N_seq_Max) grid marking the real (non-padded) steps.
valid_steps = np.arange(N_seq_Max)[None, :] < SeqLengthData[:, None]
seq_idx, step_idx = np.nonzero(valid_steps)
# Vectorised scatter: one assignment instead of a Python loop over every event.
Events_one_hot[seq_idx, step_idx, EventsData[seq_idx, step_idx]] = 1.0

In [None]:
# Show the largest value in timeMaxData to judge whether times need rescaling
timeMaxData.max()

604798.0

In [None]:
# Convert the NumPy arrays to torch tensors (float64 for values, int64 for indices).
# Both time arrays are divided by timeScale; because the maximum times are very
# large, a timeScale > 1 can be used to shrink them so that gradients don't explode.
# Scaling does not affect final results as long as testing times are divided by
# the same amount. With timeScale = 1.0 the times are left unscaled.
timeScale = 1.0
EvTens = pt.tensor(Events_one_hot, dtype=pt.float64)
EvIndTens = pt.tensor(EventsData, dtype=pt.int64)
timeTensor = pt.tensor(timesData / timeScale, dtype=pt.float64)
tMaxTensor = pt.tensor(timeMaxData / timeScale, dtype=pt.float64)

In [None]:
# Sanity check: the first time entry of every sequence must be exactly zero
assert bool((timeTensor[:, 0] == 0).all())

In [None]:
# Build the boolean mask: True wherever the event index is non-negative.
# EvIndTens holds integers, so ">= 0" selects the same entries as ">= -1 + 0.001".
mask = EvIndTens >= 0

# Now, we make the training loop

In [None]:
# Create the network.
# Checkpoint convention: 'RetweetNets/TrainNet_{k}ep.pt' holds the network after
# k completed epochs (it is written at the start of epoch k, before any update).
reload = True # Set to False if starting from scratch
if reload: # load the last saved network
    N_saved_last = 21 # this has to be changed when we reload from a certain epoch index
    try:
        # NOTE: pt.load unpickles a full module object. Only load checkpoints
        # written by this notebook; unpickling untrusted files can run arbitrary code.
        net = pt.load('RetweetNets/TrainNet_{}ep.pt'.format(N_saved_last))
    except FileNotFoundError:
        # Only a missing checkpoint falls back to a fresh network. Any other
        # failure (corrupt file, unpickling error, Ctrl-C) is allowed to surface
        # instead of silently discarding the previous training progress.
        print("No saved network found. Starting from scratch")
        N_saved_last = 0
        net = CTLSTM(K=N_types, hD=64).double()

else: # initialize from scratch
    N_saved_last = 0
    net = CTLSTM(K=N_types, hD=64).double()

# Train on at most the first 4000 sequences (never more than are available).
N_train = min(4000, EvTens.shape[0])
N_epoch = 50 - N_saved_last
BatchSize = 10
# Initiate Adam optimizer
optimizer = pt.optim.Adam(net.parameters(), lr=0.001)

# start training
for epoch in range(N_epoch):

    print("Epoch:{}".format(epoch+N_saved_last), flush=True)

    # Checkpoint the network as it stands before this epoch's updates
    pt.save(net, 'RetweetNets/TrainNet_{}ep.pt'.format(epoch+N_saved_last))

    # permutation for this epoch
    perm = pt.randperm(N_train)
    for batchInd in tqdm(range(0, N_train, BatchSize), position=0, leave=True):
        # Indices of the sequences in this mini-batch. Indexing with the sliced
        # permutation selects the same rows as X[perm][a:b] without first
        # materialising a permuted copy of every full tensor on each step.
        batchSel = perm[batchInd:batchInd+BatchSize]

        # Gather the necessary inputs
        BatchEventsHot = EvTens[batchSel]
        BatchEventsInd = EvIndTens[batchSel]
        BatchTimes = timeTensor[batchSel]
        BatchTMax = tMaxTensor[batchSel]
        BatchMask = mask[batchSel]

        optimizer.zero_grad()

        # Do the forward pass
        lambOuts, CLows, Cbars, deltas, OutGates = net.forward(BatchEventsHot, BatchMask, BatchTimes)
        # Calculate the MC loss
        LMC, trandsLMC, t_upLMC = net.MC_Loss(BatchTimes, BatchTMax,
                                      CLows, Cbars, deltas, OutGates,
                                      Nsamples=N_seq_Max)
        # Calculate the likelihood
        LogLikeLoss = net.logLoss(BatchEventsInd, BatchMask, lambOuts)

        # Total objective: log-likelihood term plus the Monte Carlo term
        loss = LogLikeLoss + LMC

        loss.backward()
        optimizer.step()

# Save the fully trained network. The in-loop checkpoint is written before each
# epoch's updates, so without this the result of the final epoch would be lost.
if N_epoch > 0:
    pt.save(net, 'RetweetNets/TrainNet_{}ep.pt'.format(N_epoch+N_saved_last))

Epoch:21


100%|██████████| 400/400 [10:48<00:00,  1.62s/it]

Epoch:22



100%|██████████| 400/400 [10:36<00:00,  1.59s/it]

Epoch:23



100%|██████████| 400/400 [10:42<00:00,  1.61s/it]

Epoch:24



100%|██████████| 400/400 [10:42<00:00,  1.61s/it]

Epoch:25



100%|██████████| 400/400 [10:39<00:00,  1.60s/it]

Epoch:26



100%|██████████| 400/400 [10:39<00:00,  1.60s/it]

Epoch:27



100%|██████████| 400/400 [10:36<00:00,  1.59s/it]

Epoch:28



100%|██████████| 400/400 [10:40<00:00,  1.60s/it]

Epoch:29



100%|██████████| 400/400 [10:41<00:00,  1.60s/it]

Epoch:30



100%|██████████| 400/400 [10:38<00:00,  1.60s/it]

Epoch:31



100%|██████████| 400/400 [10:37<00:00,  1.59s/it]

Epoch:32



100%|██████████| 400/400 [10:36<00:00,  1.59s/it]

Epoch:33



100%|██████████| 400/400 [10:37<00:00,  1.59s/it]

Epoch:34



100%|██████████| 400/400 [10:32<00:00,  1.58s/it]

Epoch:35



100%|██████████| 400/400 [10:34<00:00,  1.59s/it]

Epoch:36



100%|██████████| 400/400 [10:32<00:00,  1.58s/it]

Epoch:37



100%|██████████| 400/400 [10:33<00:00,  1.58s/it]

Epoch:38



100%|██████████| 400/400 [10:35<00:00,  1.59s/it]

Epoch:39



100%|██████████| 400/400 [10:32<00:00,  1.58s/it]

Epoch:40



100%|██████████| 400/400 [10:35<00:00,  1.59s/it]

Epoch:41



100%|██████████| 400/400 [10:32<00:00,  1.58s/it]

Epoch:42



100%|██████████| 400/400 [10:34<00:00,  1.59s/it]

Epoch:43



100%|██████████| 400/400 [10:39<00:00,  1.60s/it]

Epoch:44



100%|██████████| 400/400 [10:35<00:00,  1.59s/it]

Epoch:45



100%|██████████| 400/400 [10:33<00:00,  1.58s/it]

Epoch:46



100%|██████████| 400/400 [10:55<00:00,  1.64s/it]

Epoch:47



100%|██████████| 400/400 [11:33<00:00,  1.73s/it]

Epoch:48



100%|██████████| 400/400 [11:40<00:00,  1.75s/it]

Epoch:49



100%|██████████| 400/400 [11:37<00:00,  1.74s/it]
