In [1]:
import pickle
import numpy as np
import torch as pt
from tqdm import tqdm

# First, let's load the data

In [2]:
# SECURITY NOTE: pickle.load can execute arbitrary code -- only load train.pkl from a trusted source.
# encoding='bytes' makes pickle decode 8-bit strings pickled by Python 2 as bytes objects,
# which is why every key used below is a bytes literal (b'train', b'type_event', ...).
with open("train.pkl", mode="rb") as pickle_file:
    dat = pickle.load(pickle_file, encoding='bytes')

In [3]:
# The pickle was loaded with encoding='bytes', so its keys are bytes rather than str.
train_key = b'train'
train = dat[train_key]

In [4]:
# Length of the longest sequence; every padded row built below is this wide.
# default=0 keeps an empty dataset well-defined instead of raising ValueError.
lMax = max((len(sequence) for sequence in train), default=0)

In [5]:
# Quick summary: (number of sequences, longest sequence length).
(len(train), lMax)

(20000, 264)

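# Now, let's store everything in NumPy arrays

Each sequence is padded to the longest length `lMax`. Padded event slots are set to -1. Padded times keep increasing past the sequence's last event, so every row of `timesData` stays sorted.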

In [6]:
# Fixed-width storage for all sequences, one row per sequence.
#
# EventsData[s, k]  : event type of event k of sequence s; padding slots hold -1.
#   Pre-filled with the -1 sentinel (not 1, which is a real event type in this data),
#   so a slot the fill loop never writes still reads as padding, not as a fake event.
# timesData[s, k+1] : time of event k of sequence s; column 0 is the t=0 origin,
#   hence the extra column.
# timeMaxData[s]    : time of the last real event of sequence s.
# SeqLengthData[s]  : number of real (unpadded) events in sequence s.
EventsData = np.full((len(train), lMax), -1, dtype=int)
timesData = np.zeros((len(train), lMax + 1))
timeMaxData = np.zeros(len(train))
SeqLengthData = np.zeros(len(train), dtype=int)

In [9]:
print("Starting Data Processing", flush=True)
for seq in tqdm(range(len(train)), position=0, leave=True):
    events = train[seq]
    n_events = len(events)

    # Real events: event k goes to EventsData[seq, k] and its time to
    # timesData[seq, k + 1] (column 0 stays at the t=0 origin).
    EventsData[seq, :n_events] = [ev[b'type_event'] for ev in events]
    timesData[seq, 1:n_events + 1] = [ev[b'time_since_start'] for ev in events]

    # Time of the last event (the maximum, assuming events are stored in time order).
    # Indexing by n_events avoids relying on a loop variable leaking out of an inner
    # loop, which would be unbound if the very first sequence were empty; for an empty
    # sequence this yields the t=0 origin.
    timeMaxData[seq] = timesData[seq, n_events]
    SeqLengthData[seq] = n_events

    # Pad the remaining slots: events get -1, and times get last_time + 1, + 2, ...
    # so each row of timesData stays sorted. This helps when searching for the
    # interval that contains a random time.
    n_pad = lMax - n_events
    EventsData[seq, n_events:] = -1
    timesData[seq, n_events + 1:] = timeMaxData[seq] + np.arange(1, n_pad + 1)

Starting Data Processing


100%|██████████| 20000/20000 [00:03<00:00, 5201.83it/s]


In [13]:
# Spot-check the one-slot offset between events and times. Event k's time is stored in
# timesData[seq, k + 1], so the second-to-last event (index len - 2) of sequence 0 is
# at column len - 1. Deriving the column from SeqLengthData avoids a hard-coded index
# that only fits one particular sequence length.
second_to_last_col = SeqLengthData[0] - 1
assert timesData[0, second_to_last_col] == train[0][-2][b'time_since_start']
timesData[0, second_to_last_col], train[0][-2]

(39270.0,
 {b'time_since_start': 39270.0,
  b'time_since_last_event': 45.0,
  b'type_event': 1})

In [15]:
# Invariants of the padded arrays. Column 0 is never written (it comes from np.zeros),
# so checking it alone proves little; the checks below also cover the padding logic.
# pad_mask[s, k] is True where column k lies past the real events of sequence s.
pad_mask = np.arange(lMax)[None, :] >= SeqLengthData[:, None]

# Column 0 of timesData is the t=0 origin for every sequence.
assert np.allclose(timesData[:, 0], 0)
# Every padded event slot holds the -1 sentinel.
assert (EventsData[pad_mask] == -1).all()
# timeMaxData matches the stored time of each sequence's last real event.
assert np.array_equal(timeMaxData, timesData[np.arange(len(train)), SeqLengthData])
# Rows are non-decreasing, which the padding scheme is meant to preserve
# (assumes the raw events are already in time order).
assert (np.diff(timesData, axis=1) >= 0).all()

In [14]:
# Persist the padded arrays to a single HDF5 file, so later notebooks can load them
# directly instead of re-parsing the pickle.
import h5py

arrays_to_save = {
    "EventsData": EventsData,
    "TimesData": timesData,
    "TimeMaxData": timeMaxData,
    "SeqLengthData": SeqLengthData,
}
with h5py.File("RetweetTrainData.h5", "w") as h5_out:
    for dataset_name, array in arrays_to_save.items():
        h5_out.create_dataset(dataset_name, data=array)