In [None]:
import pickle
import numpy as np
import torch as pt
from tqdm import tqdm

# First, let's load the data

In [None]:
# Read the pickled test split from disk.
# encoding='bytes' makes pickle return 8-bit string instances as bytes objects,
# which is why the keys used further down are bytes literals (b'...').
# NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
with open("test_SO.pkl", mode="rb") as pkl_file:
    dat = pickle.load(pkl_file, encoding='bytes')

In [None]:
# Select the test sequences stored under the b'test' key of the loaded object.
# The key is a bytes literal, presumably because the pickle is loaded with
# encoding='bytes' -- confirm against the file's producer.
test = dat[b'test']

In [None]:
# Maximum number of events kept per sequence.
# Instead of using the length of the longest sequence, the Stack Overflow
# sequences are truncated to at most 250 events so that training finishes
# in a manageable amount of time.
lMax = 250

# Now, let's store everything in NumPy arrays

In [None]:
# Pre-allocate one row per test sequence.
# Event types: integer array, initialised to 1, one column per kept event.
EventsData = np.ones(shape=(len(test), lMax), dtype=int)
# Event times: one extra leading column, which is left at zero.
timesData = np.zeros(shape=(len(test), lMax + 1))
# Per-sequence scalars: last recorded event time and number of kept events.
timeMaxData = np.zeros(shape=len(test))
SeqLengthData = np.zeros(shape=len(test), dtype=int)

In [None]:
print("Starting Data Processing", flush=True)
# Convert every raw sequence (indexable collection of per-event dicts) into
# fixed-width rows of the pre-allocated arrays:
#   EventsData[seq, k]    - b'type_event' of the k-th event, -1 for padding
#   timesData[seq, k + 1] - b'time_since_start' of the k-th event
#                           (column 0 is left at zero)
#   timeMaxData[seq]      - time of the last stored event
#   SeqLengthData[seq]    - number of stored (non-padding) events
for seq in tqdm(range(len(test)), position=0, leave=True):
    # Keep at most lMax events per sequence
    Up = min(len(test[seq]), lMax)

    for step in range(Up):
        dct = test[seq][step]
        EventsData[seq, step] = dct[b'type_event']
        timesData[seq, step + 1] = dct[b'time_since_start']  # column 0 stays zero

    # Time of the last stored event. Index with Up rather than the loop
    # variable `step`: for an empty sequence the inner loop never runs, so
    # `step` would be undefined (NameError) or stale from the previous sequence.
    # With Up == 0 this reads column 0, i.e. 0.0.
    timeMaxData[seq] = timesData[seq, Up]
    SeqLengthData[seq] = Up

    # Fill the remaining slots with -1, indicating that no event occurred.
    EventsData[seq, Up:] = -1

    # Pad the times with strictly increasing values (last time + 1, + 2, ...)
    # so that the row stays sorted; this helps when searching for the
    # interval of a random time in the MC simulation.
    timesData[seq, Up + 1:] = timeMaxData[seq] + np.arange(1, lMax - Up + 1)

Starting Data Processing


100%|██████████| 1326/1326 [00:00<00:00, 5243.82it/s]


In [None]:
# Write the processed arrays to a single HDF5 file, one dataset per array,
# so they are easier to handle in later steps.
import h5py

with h5py.File("SOTestData.h5", "w") as h5_out:
    for dataset_name, array in (
        ("EventsData", EventsData),
        ("TimesData", timesData),
        ("TimeMaxData", timeMaxData),
        ("SeqLengthData", SeqLengthData),
    ):
        h5_out.create_dataset(dataset_name, data=array)