In [1]:
import numpy as np
import datetime
import torch
import torch.utils
from datetime import datetime, timezone


class EventsDataset(torch.utils.data.Dataset):
    '''
    Base class for event datasets.

    Subclasses are expected to set the attributes that the methods below read:
    ``all_events`` (sequence of ``(u, v, rel, time)`` tuples where ``time`` is a
    ``datetime.datetime``), ``n_events``, ``event_types_num`` (mapping from
    ``rel`` to an integer event type), ``FIRST_DATE`` (``datetime.datetime``)
    and ``time_bar`` (per-node array of POSIX timestamps of the last event).
    '''
    def __init__(self, TZ=None):
        self.TZ = TZ  # timezone.utc

    def get_Adjacency(self, multirelations=False):
        '''Placeholder: the base class has no adjacency information.'''
        return None, None, None

    def __len__(self):
        return self.n_events

    def __getitem__(self, index):
        '''
        Return the event at ``index`` together with its time features.

        NOTE: this call is stateful. It overwrites ``self.time_bar`` for both
        nodes of the event, so events have to be read in chronological order.

        Returns:
            u, v: the two nodes of the event
            time_delta_uv: (2, 4) float64 array holding, for u and v, the time
                since the node's previous event split into
                days / hours / minutes / seconds
            k: integer event type, ``self.event_types_num[rel]``
            time_bar: float64 copy of ``self.time_bar`` taken before the update
            time_cur: 1-element double tensor with the event's POSIX timestamp

        Raises:
            ValueError: a node's previous time stamp precedes ``FIRST_DATE``.
        '''
        # Imported locally on purpose: the notebook-level name ``datetime`` is
        # rebound to the *module* by a later ``import datetime``, which would
        # turn ``datetime.fromtimestamp`` into an AttributeError at call time.
        from datetime import datetime as _datetime

        tpl = self.all_events[index]
        u, v, rel, time_cur = tpl

        # Compute time delta in seconds (t_p - \bar{t}_p_j) that will be fed to W_t
        time_delta_uv = np.zeros((2, 4))  # two nodes x 4 values

        # most recent previous time for all nodes (snapshot taken before the update below)
        time_bar = self.time_bar.copy()
        assert u != v, (tpl, rel)

        time_cur_sec = time_cur.timestamp()
        for c, j in enumerate([u, v]):
            t = _datetime.fromtimestamp(self.time_bar[j], tz=self.TZ)
            if t.toordinal() < self.FIRST_DATE.toordinal():  # assume no events before FIRST_DATE
                raise ValueError('unexpected result', t, self.FIRST_DATE)
            td = time_cur - t
            time_delta_uv[c] = np.array([td.days,  # total number of days, still can be a big number
                                         td.seconds // 3600,  # hours, max 24
                                         (td.seconds // 60) % 60,  # minutes, max 60
                                         td.seconds % 60],  # seconds, max 60
                                        np.float64)
            self.time_bar[j] = time_cur_sec  # last time stamp for nodes u and v

        k = self.event_types_num[rel]

        # sanity check: the time stamp must survive the conversion to float64
        assert np.float64(time_cur_sec) == time_cur_sec, (np.float64(time_cur_sec), time_cur_sec)
        time_bar = time_bar.astype(np.float64)
        time_cur = torch.from_numpy(np.array([np.float64(time_cur_sec)])).double()
        # no node may have been updated later than the current event
        assert time_bar.max() <= time_cur, (
            'index=%s, tpl=%s' % (index, tpl), time_bar.max(), time_cur)
        return u, v, time_delta_uv, k, time_bar, time_cur

In [2]:
import matplotlib.pyplot as plt
import os
from os.path import join as pjoin
import numpy as np
import datetime
import pickle
import pandas
import itertools
import torch
import torch.utils


class SocialEvolutionDataset(EventsDataset):
    '''
    Class to load batches for training and testing.

    Merges the communication events (SMS / Proximity / Calls) of one split with
    the association events derived from consecutive survey adjacency matrices
    into a single list ``all_events`` sorted by time.
    '''

    FIRST_DATE = datetime.datetime(2008, 9, 11)  # consider events starting from this time
    EVENT_TYPES =  ['SMS', 'Proximity', 'Calls']

    def __init__(self,
                 subj_features,
                 data,
                 MainAssociation,
                 data_train=None,
                 verbose=False):
        '''
        Args:
            subj_features: per-subject feature matrix; only ``shape[0]`` (number
                of nodes) is read here.
            data: split object exposing ``split``, ``EVENT_TYPES`` (readers with
                a ``tuples`` list) and ``Adj`` (``{date: {relation: matrix}}``).
            MainAssociation: name of the relation used for association events.
            data_train: optional training split; when given, its last survey is
                used as the initial topology instead of the first survey of ``data``.
            verbose: print extra diagnostics.
        '''
        super(SocialEvolutionDataset, self).__init__()

        self.subj_features = subj_features
        self.data = data
        self.verbose = verbose
        self.all_events = []
        self.event_types_num = {}
        self.MainAssociation = MainAssociation
        self.TEST_TIMESLOTS = [datetime.datetime(2009, 5, 10), datetime.datetime(2009, 5, 20), datetime.datetime(2009, 5, 31),
                               datetime.datetime(2009, 6, 10), datetime.datetime(2009, 6, 20), datetime.datetime(2009, 6, 30)]
        self.FIRST_DATE = SocialEvolutionDataset.FIRST_DATE
        self.event_types = SocialEvolutionDataset.EVENT_TYPES
        first_day = self.FIRST_DATE.toordinal()

        k = 1  # k >= 1 for communication events
        print(data.split.upper())
        for t in self.event_types:
            print('Event type={}, k={}, number of events={}'.format(t, k, len(data.EVENT_TYPES[t].tuples)))

            # keep only events on or after FIRST_DATE
            events = [e for e in data.EVENT_TYPES[t].tuples if e[3].toordinal() >= first_day]
            self.all_events.extend(events)
            self.event_types_num[t] = k
            k += 1

        n = len(self.all_events)
        self.N_nodes = subj_features.shape[0]
        # Default size of the event-count matrix; replaced by the survey matrix
        # size below when surveys exist. Without this default, N was undefined
        # (NameError) whenever data.Adj was empty.
        N = self.N_nodes

        # The adjacency matrices are only needed for the diagnostic print.
        if data.split == 'train' and self.verbose:
            Adj_all, _, Adj_all_last = self.get_Adjacency()
            print('initial and final associations', self.MainAssociation, Adj_all.sum(), Adj_all_last.sum(),
                  np.allclose(Adj_all, Adj_all_last))

        # Initial topology
        if len(data.Adj) > 0:
            survey_dates = sorted(list(data.Adj.keys()))

            keys = sorted(list(data.Adj[list(data.Adj.keys())[0]].keys()))  # relation keys
            keys.remove(MainAssociation)
            keys = [MainAssociation] + keys  # to make sure CloseFriend goes first

            k = 0  # k <= 0 for association events
            for rel in keys:

                if rel != MainAssociation:
                    continue  # only the main association generates events
                if data_train is None:
                    date = survey_dates[0]  # first date
                    Adj_prev = data.Adj[date][rel]
                else:
                    date = sorted(list(data_train.Adj.keys()))[-1]  # last date of the training set
                    Adj_prev = data_train.Adj[date][rel]
                self.event_types_num[rel] = k

                N = Adj_prev.shape[0]

                # Associative events
                for date in survey_dates:
                    if date.toordinal() >= first_day:
                        assert data.Adj[date][rel].shape[0] == N
                        for u in range(N):
                            for v in range(u + 1, N):
                                # if two nodes become friends, add the event
                                if data.Adj[date][rel][u, v] > 0 and Adj_prev[u, v] == 0:
                                    assert u != v, (u, v, k)
                                    self.all_events.append((u, v, rel, date))

                    Adj_prev = data.Adj[date][rel]

                print('Event type={}, k={}, number of events={}'.format(rel, k, len(self.all_events) - n))
                n = len(self.all_events)
                k -= 1

        # chronological order, compared at whole-second resolution
        self.all_events = sorted(self.all_events, key=lambda x: int(x[3].timestamp()))

        if self.verbose:
            print('%d events' % len(self.all_events))
            print('last 10 events:')
            for event in self.all_events[-10:]:
                print(event)

        self.n_events = len(self.all_events)

        # symmetric matrix counting the events between every pair of nodes
        H_train = np.zeros((N, N))
        for e in self.all_events:
            H_train[e[0], e[1]] += 1
            H_train[e[1], e[0]] += 1
        if self.verbose:
            print('H_train', len(self.all_events), H_train.max(), H_train.min(), H_train.std())
        self.H_train = H_train

        # every node starts with FIRST_DATE as the time of its last event
        self.time_bar = np.full(self.N_nodes, self.FIRST_DATE.timestamp())


    @staticmethod
    def load_data(data_dir, prob, dump=True):
        '''
        Load the preprocessed dataset from ``data_dir/data_prob<prob>.pkl`` or
        build it from the csv files (and cache it there when ``dump`` is true).

        Returns a dict with the keys 'initial_embeddings', 'train' and 'test'.
        '''
        data_file = pjoin(data_dir, 'data_prob%s.pkl' % prob)
        if os.path.isfile(data_file):
            print('loading data from %s' % data_file)
            # NOTE(security): pickle.load executes arbitrary code; only load
            # cache files that were produced locally by this function.
            with open(data_file, 'rb') as f:
                data = pickle.load(f)
        else:
            data = {'initial_embeddings': SubjectsReader(pjoin(data_dir, 'Subjects.csv')).features_onehot}
            for split in ['train', 'test']:
                data.update(
                    {split: SocialEvolution(data_dir, split=split, MIN_EVENT_PROB=prob)})
            if dump:
                # dump data files to avoid their generation again
                print('saving data to %s' % data_file)
                with open(data_file, 'wb') as f:
                    pickle.dump(data, f, protocol=2)  # for compatibility
        return data

    def get_Adjacency(self, multirelations=False):
        '''
        Return ``(Adj_first, keys, Adj_last)``: copies of the adjacency of the
        first and the last survey date of ``self.data``.

        With ``multirelations`` all relations are stacked along axis 2 with the
        main association first; otherwise only the main association is returned.
        '''
        dates = sorted(list(self.data.Adj.keys()))
        Adj_all = self.data.Adj[dates[0]]
        Adj_all_last = self.data.Adj[dates[-1]]
        if multirelations:
            keys = sorted(list(Adj_all.keys()))
            keys.remove(self.MainAssociation)
            keys = [self.MainAssociation] + keys  # to make sure CloseFriend goes first
            Adj_all = np.stack([Adj_all[rel].copy() for rel in keys], axis=2)
            Adj_all_last = np.stack([Adj_all_last[rel].copy() for rel in keys], axis=2)
        else:
            keys = [self.MainAssociation]
            Adj_all = Adj_all[self.MainAssociation].copy()
            Adj_all_last = Adj_all_last[self.MainAssociation].copy()

        return Adj_all, keys, Adj_all_last


    def time_to_onehot(self, d):
        '''Encode weekday, hour, minute and second of ``d`` as concatenated one-hot vectors (7+24+60+60).'''
        x = []
        for t, max_t in [(d.weekday(), 7), (d.hour, 24), (d.minute, 60), (d.second, 60)]:
            x_t = np.zeros(max_t)
            x_t[t] = 1
            x.append(x_t)
        return np.concatenate(x)

class CSVReader:
    '''
    General class to read any relationship csv in this dataset.

    After construction the instance exposes:
      data        -- {column: parsed values} plus {column + '_unique': unique valid values}
      valid_ids   -- row indices that passed all filters, sorted by time stamp
      subjects, time_stamps -- values of both user columns / the time column for valid rows
      tuples      -- de-duplicated, time-sorted (u, v, event_type, time) events with
                     0-based node ids (only filled when N_subjects is given)
      Adj         -- (N_subjects, N_subjects) event counts (only when N_subjects is given)
    '''

    def __init__(self,
                 csv_path,
                 split,  # 'train', 'test', 'all'
                 MIN_EVENT_PROB,
                 event_type=None,
                 N_subjects=None,
                 test_slot=1):
        self.csv_path = csv_path
        print(os.path.basename(csv_path))

        # inclusive range of day ordinals kept for this split
        if split == 'train':
            time_start = 0
            time_end = datetime.datetime(2009, 4, 30).toordinal()
        elif split == 'test':
            if test_slot != 1:
                raise NotImplementedError('test on time slot 1 for now')
            time_start = datetime.datetime(2009, 5, 1).toordinal()
            time_end = datetime.datetime(2009, 6, 30).toordinal()
        else:
            time_start = 0
            time_end = np.inf  # np.Inf was removed in NumPy 2.0

        csv = pandas.read_csv(csv_path)
        columns = list(csv.keys())

        # columns are identified by substrings of their names
        user_columns = [c for c in columns if 'user' in c or 'id' in c]
        assert len(user_columns) == 2, (columns, user_columns)
        time_columns = [c for c in columns if 'time' in c or 'date' in c]
        assert len(time_columns) == 1, (columns, time_columns)
        self.time_column = time_columns[0]
        self.prob_column = [c for c in columns if 'prob' in c]

        # parse every column exactly once
        self.data = {column: self._parse_column(csv[column].tolist()) for column in columns}

        n_rows = len(self.data[self.time_column])

        # builtin int/bool: the np.int / np.bool aliases were removed in NumPy 1.24
        time_stamp_days = np.array([d.toordinal() for d in self.data[self.time_column]], dtype=int)

        # skip data where one of users is missing (nan) or interacting with itself or timestamp not in range
        conditions = [~np.isnan(self.data[user_columns[0]]),
                      ~np.isnan(self.data[user_columns[1]]),
                      np.array(self.data[user_columns[0]]) != np.array(self.data[user_columns[1]]),
                      time_stamp_days >= time_start,
                      time_stamp_days <= time_end]

        if len(self.prob_column) == 1:
            print(split, event_type, self.prob_column)
            # skip data if the probability of event is 0 or nan (available for some event types)
            conditions.append(np.nan_to_num(np.array(self.data[self.prob_column[0]])) > MIN_EVENT_PROB)

        valid_ids = np.ones(n_rows, dtype=bool)
        for cond in conditions:
            valid_ids = valid_ids & cond

        self.valid_ids = np.where(valid_ids)[0]

        # order the valid rows chronologically
        time_stamps_sec = [self.data[self.time_column][i].timestamp() for i in self.valid_ids]
        self.valid_ids = self.valid_ids[np.argsort(time_stamps_sec)]

        print(split, len(self.valid_ids), n_rows)

        # unique values of each column, restricted to the valid rows
        for column in columns:
            values = self.data[column]
            key = column + '_unique'
            values_valid = [values[i] for i in self.valid_ids]
            self.data[key] = np.unique(values_valid)
            print(key, type(values[0]), len(self.data[key]), self.data[key])

        self.subjects, self.time_stamps = [], []
        for usr_col in user_columns:
            self.subjects.extend([self.data[usr_col][i] for i in self.valid_ids])
            self.time_stamps.extend([self.data[self.time_column][i] for i in self.valid_ids])

        # set O={(u, v, k, t)}
        self.tuples = []
        if N_subjects is not None:
            # Compute frequency of communication between users
            print('user_columns', user_columns)
            self.Adj = np.zeros((N_subjects, N_subjects))
            for row in self.valid_ids:
                subj1 = self.data[user_columns[0]][row]
                subj2 = self.data[user_columns[1]][row]

                assert subj1 != subj2, (subj1, subj2)
                assert subj1 > 0 and subj2 > 0, (subj1, subj2)
                # ids in the csv are 1-based, matrix indices are 0-based
                u, v = int(subj1) - 1, int(subj2) - 1
                try:
                    self.Adj[u, v] += 1
                    self.Adj[v, u] += 1
                except Exception:
                    # show the offending pair (e.g. id larger than N_subjects) and re-raise
                    print(subj1, subj2)
                    raise

                self.tuples.append((u, v, event_type, self.data[self.time_column][row]))

        n1 = len(self.tuples)
        self.tuples = sorted(set(self.tuples), key=lambda t: t[3].timestamp())
        n2 = len(self.tuples)
        print('%d/%d duplicates removed' % (n1 - n2, n1))

    @staticmethod
    def _parse_column(values):
        '''
        Convert a column with the first of int, float, '%Y-%m-%d' date or
        '%Y-%m-%d %H:%M:%S' datetime that succeeds for *all* values; return the
        values unchanged when none applies.
        '''
        to_date1 = lambda s: datetime.datetime.strptime(s, '%Y-%m-%d')
        to_date2 = lambda s: datetime.datetime.strptime(s, '%Y-%m-%d %H:%M:%S')
        for fn in (int, float, to_date1, to_date2):
            try:
                return list(map(fn, values))
            except Exception:
                continue
        return values


class SubjectsReader:
    '''
    Class to read Subjects.csv in this dataset.

    Every column whose name does not contain 'user' is treated as a categorical
    feature (values compared as strings) and one-hot encoded; the encodings of
    all columns are concatenated into ``features_onehot`` (one row per subject).
    '''

    def __init__(self,
                 csv_path):
        self.csv_path = csv_path
        print(os.path.basename(csv_path))

        table = pandas.read_csv(csv_path)
        column_names = list(table.keys())

        # the first column mentioning 'user' defines the list of subjects
        user_column = [name for name in column_names if name.find('user') >= 0][0]
        n_subjects = len(table[user_column].tolist())
        print('Number of subjects', n_subjects)

        blocks = []
        for name in column_names:
            if name.find('user') >= 0:
                continue
            labels = [str(value) for value in table[name].tolist()]
            # codes[i] is the position of labels[i] among the sorted unique labels
            categories, codes = np.unique(labels, return_inverse=True)
            onehot = np.zeros((n_subjects, len(categories)))
            onehot[np.arange(len(labels)), codes] = 1
            blocks.append(onehot)

        encoded = np.concatenate(blocks, axis=1)
        print('features', encoded.shape)
        self.features_onehot = encoded


class SocialEvolution():
    '''
    Class to read all csv in this dataset.

    Holds, for one split, the survey relations reader (``relations``), one
    reader per communication event type (``EVENT_TYPES``) and the symmetric
    0/1 adjacency matrix of every relation at every survey date
    (``Adj[date][relation]``).
    '''

    def __init__(self,
                 data_dir,
                 split,
                 MIN_EVENT_PROB):
        self.data_dir = data_dir
        self.split = split
        self.MIN_EVENT_PROB = MIN_EVENT_PROB

        survey_csv = pjoin(data_dir, 'RelationshipsFromSurveys.csv')
        self.relations = CSVReader(survey_csv, split=split, MIN_EVENT_PROB=MIN_EVENT_PROB)
        survey = self.relations.data
        # subjects are all ids occurring on either side of a survey relation
        self.relations.subject_ids = np.unique(survey['id.A'] + survey['id.B'])
        self.N_subjects = len(self.relations.subject_ids)
        print('Number of subjects', self.N_subjects)

        # Read communicative events
        self.EVENT_TYPES = {}
        for event_name in SocialEvolutionDataset.EVENT_TYPES:
            reader = CSVReader(pjoin(data_dir, '%s.csv' % event_name),
                               split=split,
                               MIN_EVENT_PROB=MIN_EVENT_PROB,
                               event_type=event_name,
                               N_subjects=self.N_subjects)
            self.EVENT_TYPES[event_name] = reader

        # Compute adjacency matrices for associative relationship data
        self.Adj = {}
        survey_dates = survey['survey.date']
        relation_names = survey['relationship']
        ids_a, ids_b = survey['id.A'], survey['id.B']
        for date in survey['survey.date_unique']:
            self.Adj[date] = {}
            rows_on_date = [row for row, d in enumerate(survey_dates) if d == date]
            for rel in survey['relationship_unique']:
                rows = [row for row in rows_on_date if relation_names[row] == rel]
                A = np.zeros((self.N_subjects, self.N_subjects))
                for row in rows:
                    # csv ids are 1-based; the relation is stored symmetrically
                    a, b = ids_a[row] - 1, ids_b[row] - 1
                    A[a, b] = 1
                    A[b, a] = 1
                self.Adj[date][rel] = A
                # sanity check: every survey row of this (date, relation) is present in both directions
                for row in rows:
                    a, b = ids_a[row] - 1, ids_b[row] - 1
                    assert self.Adj[date][rel][a, b] == 1
                    assert self.Adj[date][rel][b, a] == 1

In [3]:
# Paths to the dataset files.
# The location can be overridden with the SOCIAL_EVOLUTION_DIR environment
# variable so the notebook also runs outside the original author's machine;
# the default is the original hard-coded path.
data_dir = os.environ.get('SOCIAL_EVOLUTION_DIR',
                          '/Users/amberrrrrr/Desktop/trials/dyrep_torch-main/SocialEvolution/')
prob = 0.8  # minimum event probability, forwarded to the csv readers as MIN_EVENT_PROB
association = 'CloseFriend'  # relation used for association events

# Load the data
data = SocialEvolutionDataset.load_data(data_dir, prob)

# Initialize train and test sets
train_set = SocialEvolutionDataset(data['initial_embeddings'], data['train'], association, verbose=False)
test_set = SocialEvolutionDataset(data['initial_embeddings'], data['test'], association, data_train=data['train'], verbose=False)

# Preview the first few lines of the train set and test set
print("Train set preview (first 5 events):")
for event in train_set.all_events[:5]:
    print(event)

print("\nTest set preview (first 5 events):")
for event in test_set.all_events[:5]:
    print(event)

loading data from /Users/amberrrrrr/Desktop/trials/dyrep_torch-main/SocialEvolution/data_prob0.8.pkl
TRAIN
Event type=SMS, k=1, number of events=4319
Event type=Proximity, k=2, number of events=31011
Event type=Calls, k=3, number of events=8187
Event type=CloseFriend, k=0, number of events=365
TEST
Event type=SMS, k=1, number of events=288
Event type=Proximity, k=2, number of events=9094
Event type=Calls, k=3, number of events=1080
Event type=CloseFriend, k=0, number of events=73
Train set preview (first 5 events):
(42, 50, 'Calls', datetime.datetime(2008, 9, 11, 3, 16, 14))
(42, 50, 'Calls', datetime.datetime(2008, 9, 19, 0, 31, 33))
(42, 21, 'Calls', datetime.datetime(2008, 9, 19, 0, 58, 2))
(42, 54, 'Calls', datetime.datetime(2008, 9, 19, 1, 21, 4))
(42, 50, 'Calls', datetime.datetime(2008, 9, 19, 18, 20, 43))

Test set preview (first 5 events):
(0, 60, 'Proximity', datetime.datetime(2009, 5, 1, 0, 3, 29))
(60, 0, 'Proximity', datetime.datetime(2009, 5, 1, 0, 3, 51))
(59, 66, 'Proxi

In [4]:
# Work on a copy so the cached dataset dict is not modified by later cells.
initial_embeddings = data['initial_embeddings'].copy()
# First element of get_Adjacency(): adjacency of the main association at the first survey date.
A_initial = train_set.get_Adjacency()[0]

In [5]:
# Initialise the dimensions shared by A and z.
# The None check now comes before the first use of A_initial (previously
# A_initial.shape was dereferenced first, which made the check pointless).
if A_initial is not None:
    if len(A_initial.shape) == 2:
        A_initial = A_initial[:, :, None]  # add a trailing relation axis: (N, N) -> (N, N, 1)
    N_nodes = A_initial.shape[0]
else:
    N_nodes = initial_embeddings.shape[0]  # one embedding row per node
n_assoc_types, n_event_types = 1, 3
n_relations = n_assoc_types + n_event_types

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def generate_S_from_A(A, N_nodes, n_assoc_types):
    """Build the attention matrix S from the adjacency tensor A.

    For every relation ``rel`` and every node ``v`` whose row sum
    ``D[v] = sum(A[v, :, rel])`` is non-zero, ``S[v, u, rel]`` is set to
    ``1 / D[v]`` for all ``u`` with ``A[v, u, rel] != 0``; all other entries
    are 0. For a 0/1 adjacency each such row of S therefore sums to 1.

    Args:
        A: (N_nodes, N_nodes, n_assoc_types) numpy array or torch tensor.
        N_nodes: number of nodes.
        n_assoc_types: number of association relations (size of the last axis).

    Returns:
        Floating point tensor S of shape (N_nodes, N_nodes, n_assoc_types).
    """
    if isinstance(A, np.ndarray):
        A = torch.tensor(A, dtype=torch.float32)  # Convert A to a tensor if it's a numpy array
    elif not torch.is_floating_point(A):
        A = A.float()  # an integer S would truncate 1 / D to 0
    S = A.new_zeros(N_nodes, N_nodes, n_assoc_types)
    for rel in range(n_assoc_types):
        A_rel = A[:, :, rel]
        D = torch.sum(A_rel, dim=1).float()
        # flatten() (not squeeze()) keeps the index list 1-D even when exactly
        # one row is non-zero; a 0-d tensor cannot be iterated.
        for v in torch.nonzero(D, as_tuple=False).flatten():
            u = torch.nonzero(A_rel[v], as_tuple=False).flatten()
            S[v, u, rel] = 1. / D[v]
    # Check that S has no mass where A has no edge
    for rel in range(n_assoc_types):
        S1 = S[:, :, rel]
        assert torch.sum(S1[A[:, :, rel] == 0]) < 1e-5, torch.sum(S1[A[:, :, rel] == 0])
    return S

In [7]:
# Smoke test: display S built from the initial adjacency (the result is shown, not stored).
generate_S_from_A(A_initial,N_nodes,n_assoc_types)

tensor([[[0.0000],
         [0.0000],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.3333],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]],

        ...,

        [[0.0000],
         [0.0000],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]],

        [[0.0000],
         [0.0000],
         [0.0000],
         ...,
         [0.0000],
         [0.0000],
         [0.0000]]])

In [8]:
def initialize(node_embeddings, A_initial,N_nodes, n_hidden,n_assoc_types, sparse=False, keepS=False):
    """Create the initial model state.

    Args:
        node_embeddings: (N_nodes, d) numpy array with d <= n_hidden, zero-padded
            on the right to n_hidden columns; if None, all-zero embeddings are used.
        A_initial: (N, N) or (N, N, R) adjacency (numpy array or tensor); if None
            a random adjacency is sampled with numpy's global RNG.
        N_nodes: number of nodes.
        n_hidden: embedding size.
        n_assoc_types: number of association relations.
        sparse: when sampling a random adjacency, draw mostly "no edge".
        keepS: skip the generate_S_from_A consistency run.

    Returns:
        (z, A, Lambda_dict, time_keys, t_p): float embeddings (N_nodes, n_hidden),
        float adjacency with a trailing relation axis, a zero tensor of size 5000,
        an empty list and the iteration counter 0.
    """
    # NOTE: 'model''s' was two adjacent literals and printed "models"
    print("initialize model's node embeddings and adjacency matrices for %d nodes" % N_nodes)
    # Initial embeddings
    if node_embeddings is not None:
        z = np.pad(node_embeddings, ((0, 0), (0, n_hidden - node_embeddings.shape[1])), 'constant')
        z = torch.from_numpy(z).float()
    else:
        # previously z stayed undefined here and the return raised UnboundLocalError
        z = torch.zeros(N_nodes, n_hidden)

    if A_initial is None:
        print('initial random prediction of A')
        A = torch.zeros(N_nodes, N_nodes, n_assoc_types + int(sparse))

        # probabilities of [no edge, relation 1, relation 2, ...] in the sparse case
        sparse_pvals = {1: [0.95, 0.05],
                        2: [0.9, 0.05, 0.05],
                        3: [0.91, 0.03, 0.03, 0.03],
                        4: [0.9, 0.025, 0.025, 0.025, 0.025]}

        for i in range(N_nodes):
            for j in range(i + 1, N_nodes):
                if sparse:
                    if n_assoc_types not in sparse_pvals:
                        raise NotImplementedError(n_assoc_types)
                    ind = np.nonzero(np.random.multinomial(1, sparse_pvals[n_assoc_types]))[0][0]
                else:
                    ind = np.random.randint(0, n_assoc_types, size=1)
                A[i, j, ind] = 1
                A[j, i, ind] = 1
        assert torch.sum(torch.isnan(A)) == 0, (torch.sum(torch.isnan(A)), A)
        if sparse:
            A = A[:, :, 1:]  # drop the "no edge" channel

    else:
        print('A_initial', A_initial.shape)
        A = torch.as_tensor(A_initial).float()  # accepts numpy arrays and tensors
        if len(A.shape) == 2:
            A = A.unsqueeze(2)

    if not keepS:
        # NOTE(review): the result is discarded; the call only runs the internal
        # consistency assert. Callers that need S must call generate_S_from_A themselves.
        generate_S_from_A(A,N_nodes,n_assoc_types)

    Lambda_dict = torch.zeros(5000)
    time_keys = []

    t_p = 0  # global counter of iterations
    return z, A, Lambda_dict, time_keys, t_p

In [9]:
# Build the initial embeddings z and adjacency A_ini from the training data.
z, A_ini, Lambda_dict, time_keys, t_p = initialize(
    node_embeddings=initial_embeddings,
    A_initial=A_initial,
    N_nodes=N_nodes,
    n_hidden=32,
    n_assoc_types=n_assoc_types,
    sparse=False,
    keepS=False,
)

initialize models node embeddings and adjacency matrices for 84 nodes
A_initial (84, 84, 1)


In [10]:
# Embeddings are (N_nodes, n_hidden) after zero-padding.
z.shape

torch.Size([84, 32])

In [11]:
# time_keys is returned as an empty list by initialize().
time_keys

[]

In [12]:
def init_weights(modules):
    """Re-initialise, in place, the weight of every nn.Linear / nn.Bilinear
    module in ``modules`` with Xavier-normal values; other modules and all
    biases are left untouched."""
    for module in modules:
        if not isinstance(module, (nn.Linear, nn.Bilinear)):
            continue
        nn.init.xavier_normal_(module.weight.data)

In [13]:
# Trainable parameters of the model (creation order is kept because each
# nn.Linear draws its initial weights from the global RNG).
n_hidden = 32
W_h = nn.Linear(in_features=n_hidden, out_features=n_hidden)
W_struct = nn.Linear(n_hidden * n_assoc_types, n_hidden)
W_rec = nn.Linear(n_hidden, n_hidden)
# 4 inputs: separate parameters for days, hours, minutes, seconds; a single
# "seconds" value can be a huge number confusing the network
W_t = nn.Linear(4, n_hidden)

n_types = 2  # associative and communicative
# omega scores the concatenation of two node embeddings, hence 2 * n_hidden inputs
d1 = 2 * n_hidden
d2 = 2 * n_hidden
omega = nn.ModuleList([nn.Linear(d1, 1), nn.Linear(d2, 1)])
psi = nn.Parameter(0.5 * torch.ones(n_types))

# Initialize weights
init_weights([W_h, W_struct, W_rec, W_t] + list(omega))
for layer in [W_h, W_struct, W_rec, W_t] + list(omega):
    print(f'Layer: {layer.__class__.__name__}, Weights shape: {layer.weight.shape}')


Layer: Linear, Weights shape: torch.Size([32, 32])
Layer: Linear, Weights shape: torch.Size([32, 32])
Layer: Linear, Weights shape: torch.Size([32, 32])
Layer: Linear, Weights shape: torch.Size([32, 4])
Layer: Linear, Weights shape: torch.Size([1, 64])
Layer: Linear, Weights shape: torch.Size([1, 64])


In [14]:
def check_S(A, S, n_assoc_types):
    """Assert that, for every relation, each row of S belonging to a node with
    non-zero row sum in A sums to 1 within a tolerance of 0.1."""
    for rel in range(n_assoc_types):
        degree = torch.sum(A[:, :, rel], dim=1).float()
        rows = torch.nonzero(degree)
        row_sums = torch.sum(S[:, :, rel], dim=1)
        # check that the sum in all rows equal 1
        deviation = torch.abs(row_sums[rows] - 1)
        assert torch.all(deviation < 1e-1), deviation
# initialize() discards its S, so the initial attention matrix is built here.
S_ini = generate_S_from_A(A_initial,N_nodes,n_assoc_types)
# check that the sum in all rows equal 1
check_S(A_ini, S_ini, n_assoc_types)

In [15]:
# Print the shapes of S and A; both carry a trailing relation axis.
print(f'S shape: {S_ini.shape}')
print(f'A shape: {A_ini.shape}')

S shape: torch.Size([84, 84, 1])
A shape: torch.Size([84, 84, 1])


In [16]:
# define functions to construct lambda
def g_fn(z_cat, k, n_assoc_types, omega, edge_type=None, z2=None):
    """Score node pairs with the layer matching their event class.

    ``z_cat`` and ``z2`` are concatenated along dim 1. Pairs with ``k <= 0``
    are scored by ``omega[0]``, pairs with ``k > 0`` by ``omega[1]``. When
    ``edge_type`` is given, its first ``n_assoc_types`` columns (for k <= 0)
    or its remaining columns (for k > 0) are appended to the layer input.

    Returns the flattened scores. Raises NotImplementedError when ``z2`` is None.
    """
    if z2 is None:
        raise NotImplementedError('')
    pair_embed = torch.cat((z_cat, z2), dim=1)
    scores = pair_embed.new_zeros(len(pair_embed), 1)

    # (row mask, index of the omega layer, edge_type columns used by that layer)
    branches = (
        (k <= 0, 0, slice(None, n_assoc_types)),
        (k > 0, 1, slice(n_assoc_types, None)),
    )
    for mask, layer_id, cols in branches:
        if torch.sum(mask) > 0:
            feats = pair_embed[mask]
            if edge_type is not None:
                feats = torch.cat((feats, edge_type[mask, cols]), dim=1)
            scores[mask] = omega[layer_id](feats)

    return scores.flatten()

In [17]:
def intensity_rate_lambda(z_u, z_v, k,n_hidden,psi,n_assoc_types,omega,edge_type=None):
    """Conditional intensity of an event of type ``k`` between nodes u and v.

    The pair score g is the average of g_fn(u, v) and g_fn(v, u), and the
    intensity is ``psi_k * (log(1 + exp(-g / psi_k)) + g / psi_k)`` where
    ``psi_k`` is ``psi[0]`` for k <= 0 and ``psi[1]`` for k > 0.
    """
    z_u = z_u.view(-1, n_hidden).contiguous()
    z_v = z_v.view(-1, n_hidden).contiguous()
    # NOTE(review): the edge_type argument is overridden here, so whatever the
    # caller passes is ignored -- confirm whether that is intended.
    edge_type = None
    event_class = (k > 0).long()  # 0: association, 1: communication

    # make it symmetric, because most events are symmetric
    g_uv = g_fn(z_u, event_class, n_assoc_types, omega, edge_type=edge_type, z2=z_v)
    g_vu = g_fn(z_v, event_class, n_assoc_types, omega, edge_type=edge_type, z2=z_u)
    g = 0.5 * (g_uv + g_vu)

    psi = psi[event_class]
    g_psi = torch.clamp(g / (psi + 1e-7), -75, 75)  # to prevent overflow
    return psi * (torch.log(1 + torch.exp(-g_psi)) + g_psi)

In [18]:
# Smoke test: the intensity function runs on a single node pair.
z_u = z[0]
z_v = z[1]
k = torch.tensor(1)  # k > 0 selects the communication branch (omega[1], psi[1])
lmd= intensity_rate_lambda(z_u, z_v, k,n_hidden,psi,n_assoc_types,omega)
print(lmd)

tensor([0.5158], grad_fn=<MulBackward0>)


In [19]:
# define the update rules for Z, A and S, Omega(the weights of the model)
# 1) for embedding
def update_node_embed(prev_embed, node1, node2, time_delta_uv, n_hidden,n_assoc_types, S, A, W_h, W_struct, W_rec, W_t):
    """Compute the new embeddings of the two nodes involved in an event.

    ``prev_embed`` holds all node embeddings of the previous time step; ``S``
    also corresponds to the previous time step because it is not updated yet
    for this event. Neither is modified.

    Returns:
        node_degree: {node: numpy array with its neighbour count per relation}
            for node1 and node2 (needed later to update S).
        z_new: clone of ``prev_embed`` whose rows node1 and node2 are replaced
            by ``sigmoid(W_struct(h_struct) + W_rec(z) + W_t(time_delta_uv))``.
    """
    event_nodes = [node1, node2]
    node_degree = {}  # we need degrees to update S
    z_new = prev_embed.clone()  # to allow in place changes while keeping gradients
    h_struct = prev_embed.new_zeros((2, n_hidden, n_assoc_types))

    # Slot 0 (node1's update) aggregates over the neighbours of node2 and
    # slot 1 over the neighbours of node1. Zipping with time_delta_uv limits
    # the loop to the number of rows it provides.
    for slot, (other, _) in enumerate(zip([node2, node1], time_delta_uv)):
        degrees = np.zeros(n_assoc_types)
        node_degree[other] = degrees
        for rel in range(n_assoc_types):
            neighb_mask = A[other, :, rel] > 0
            n_neighb = torch.sum(neighb_mask).item()
            degrees[rel] = n_neighb
            if n_neighb == 0:
                continue  # no neighbours: the structural term stays zero
            h_neighb = W_h(prev_embed[neighb_mask]).view(n_neighb, n_hidden)
            # attention over neighbors
            attn = torch.exp(S[other, neighb_mask, rel]).view(n_neighb, 1)
            attn = attn / (torch.sum(attn) + 1e-7)
            # element-wise max over the attention-weighted neighbours
            pooled = torch.max(torch.sigmoid(attn * h_neighb), dim=0)[0]
            h_struct[slot, :, rel] = pooled.view(1, n_hidden)

    h1 = W_struct(h_struct.view(2, n_hidden * n_assoc_types))
    h2 = W_rec(prev_embed[event_nodes, :].view(2, -1))
    h3 = W_t(time_delta_uv.float()).view(2, n_hidden)

    z_new[event_nodes, :] = torch.sigmoid(h1 + h2 + h3)

    return node_degree, z_new

In [20]:
# Smoke test of update_node_embed on nodes 0 and 1 (the result is displayed, not stored).
# Each row of time_delta_uv holds (days, hours, minutes, seconds) for one node.
time_delta_uv = torch.tensor([[1.0, 2.0, 0.0, 0.0], [2.0, 1.0, 0.0, 0.0]])
update_node_embed(prev_embed =z, node1=0, node2=1, time_delta_uv=time_delta_uv, n_hidden=32, n_assoc_types=1, S=S_ini, A=A_ini, W_h=W_h, W_struct=W_struct, W_rec=W_rec, W_t=W_t)

({1: array([3.]), 0: array([2.])},
 tensor([[0.3025, 0.2218, 0.3775,  ..., 0.3073, 0.3382, 0.4307],
         [0.2332, 0.3270, 0.4074,  ..., 0.4338, 0.3797, 0.5263],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
        grad_fn=<CopySlices>))

In [21]:
# Same call as the smoke test above, this time keeping the degrees and the updated embeddings.
node_deg,z_cur = update_node_embed(prev_embed =z, node1=0, node2=1, time_delta_uv=time_delta_uv, n_hidden=32, n_assoc_types=1, S=S_ini, A=A_ini, W_h=W_h, W_struct=W_struct, W_rec=W_rec, W_t=W_t)

In [22]:
# Node degrees taken from the adjacency returned by get_Adjacency()[0].
Adj_all = train_set.get_Adjacency()[0]

# Wrap a single matrix in a list so the loop below handles one or several relations.
if not isinstance(Adj_all, list):
    Adj_all = [Adj_all]

# node_degree_global[rel][u] is the row sum of node u in relation rel.
# NOTE: this loop rebinds the notebook-level names `rel`, `A` and `u`.
node_degree_global = []
for rel, A in enumerate(Adj_all):
    node_degree_global.append(np.zeros(A.shape[0]))
    for u in range(A.shape[0]):
        node_degree_global[rel][u] = np.sum(A[u])

Adj_all = Adj_all[0]

print('Adj_all', Adj_all.shape, len(node_degree_global), node_degree_global[0].min(), node_degree_global[0].max())
# Column vector (N_nodes, 1) filled with the FIRST_DATE timestamp.
time_bar = np.zeros((N_nodes, 1)) + train_set.FIRST_DATE.timestamp()

Adj_all (84, 84) 1 0.0 21.0


In [23]:
# Inspect the shape of time_bar and the per-node global degrees.
print('time_bar', time_bar.shape)
print(node_degree_global)

time_bar (84, 1)
[array([ 2.,  3.,  0., 12.,  4., 18.,  5.,  3., 19.,  5., 18.,  0.,  0.,
       11.,  0.,  9.,  3., 19., 16.,  8., 13., 12., 21., 11.,  7.,  9.,
       10., 20.,  5., 12., 13.,  6.,  6.,  8.,  4.,  9.,  1.,  1.,  7.,
        3.,  0.,  6.,  2., 18.,  3., 13., 11.,  7., 13.,  3., 11.,  4.,
        9.,  3.,  8.,  5.,  8.,  7.,  0.,  4.,  2.,  2.,  9.,  3.,  7.,
        5., 10.,  6.,  3., 10., 11.,  9.,  8.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  7.,  3.,  5.,  7.])]


In [24]:
# 2) for S and A
def update_S_A(A,S, u, v, k, node_degree, lambda_uv_t, N_nodes,n_assoc_types,node_degree_global):
    """Update S (and, for association events, A) in place after an event between u and v.

    Args:
        A: adjacency tensor indexed as ``A[node, node, relation]``.
        S: tensor indexed as ``S[node, node, relation]``; the rows of ``u`` and ``v``
            are rewritten and re-normalised.
        u, v: the two nodes involved in the event.
        k: event type; ``k <= 0`` is treated as an association event (relation
            ``abs(k)``), ``k > 0`` as a communication event.
        node_degree: mapping node -> per-relation degree, as seen right after the event.
        lambda_uv_t: intensity of the event, added to the base weight of the pair.
        N_nodes: number of nodes (length of a row of S).
        n_assoc_types: number of relations iterated over.
        node_degree_global: per-relation degree vectors from before the event.

    Returns:
        The same ``(S, A)`` objects that were passed in.
    """
    association_event = k <= 0
    if association_event:
        # do not update in case of latent graph
        # relation index is abs(k): 0 for the first relation, k = -1 for the second one
        A[u, v, np.abs(k)] = A[v, u, np.abs(k)] = 1

    node_ids = torch.arange(N_nodes)
    for rel in range(n_assoc_types):
        # Communication between nodes that are not associated: leave S untouched.
        if k > 0 and A[u, v, rel] == 0:
            continue

        for node, other in ((u, v), (v, u)):
            # `other` is the second participant of the event, seen from `node`.
            try:
                degree = node_degree[node]
            except:
                print(list(node_degree.keys()))
                raise

            row = S[node, :, rel]  # a view: edits below write straight into S
            if degree[rel] == 0:
                b = 0
            else:
                b = 1. / (float(degree[rel]) + 1e-7)

            if A[u, v, rel] > 0:
                if association_event:
                    # A new link changes the uniform share 1/degree; shift every
                    # other non-zero entry of the row by that difference.
                    if node_degree_global[rel][node] == 0:
                        b_prime = 0
                    else:
                        b_prime = 1. / (float(node_degree_global[rel][node]) + 1e-7)
                    shift = b_prime - b
                    row[other] = b + lambda_uv_t
                    others_nonzero = (row != 0) & (node_ids != int(other))
                    row[others_nonzero] = row[others_nonzero] - shift
                else:
                    # Communication event on an existing association.
                    row[other] = b + lambda_uv_t

            row /= (torch.sum(row) + 1e-7)  # normalize
            S[node, :, rel] = row
    return S, A

In [25]:
# 2) for S and A
def update_S_A_new(A,S, u, v, k, node_degree, lambda_uv_t, N_nodes,n_assoc_types,node_degree_global,decay_rate=0.1,threshold=0.5):
    """Variant of ``update_S_A`` that decays the rows of ``u`` and ``v`` before updating them.

    Unlike ``update_S_A`` the rows of S are always processed (there is no early
    ``continue`` for communication events without an association): each row is first
    multiplied by ``1 - decay_rate``, then updated and re-normalised.

    Args:
        A: adjacency tensor indexed as ``A[node, node, relation]``; modified in place.
        S: tensor indexed as ``S[node, node, relation]``; modified in place.
        u, v: the two nodes involved in the event.
        k: event type; ``k <= 0`` is an association event (relation ``abs(k)``),
            ``k > 0`` a communication event.
        node_degree: mapping node -> per-relation degree after the event.
        lambda_uv_t: intensity of the event, added to the base weight of the pair.
        N_nodes: number of nodes (length of a row of S).
        n_assoc_types: number of relations iterated over.
        node_degree_global: per-relation degree vectors from before the event.
        decay_rate: fraction of the previous row weight that is forgotten.
        threshold: pair strength at or above which the link in A is (re)set to 1.

    Returns:
        The same ``(S, A)`` objects that were passed in.

    Raises:
        KeyError: if ``u`` or ``v`` is missing from ``node_degree`` (the available
            keys are printed first).
    """
    if k <= 0 :  # Association event
        # do not update in case of latent graph
        A[u, v, np.abs(k)] = A[v, u, np.abs(k)] = 1  # 0 for CloseFriends, k = -1 for the second relation, so it's abs(k) matrix in self.A
    indices = torch.arange(N_nodes)
    for rel in range(n_assoc_types):
        for j, i in zip([u, v], [v, u]):
            # i is the "other node involved in the event"
            try:
                degree = node_degree[j]
            except KeyError:
                print(list(node_degree.keys()))
                raise
            # The multiplication yields a new tensor, so S is only changed by the
            # explicit write-back at the end of this iteration.
            y = S[j, :, rel] * (1 - decay_rate)
            b = 0 if degree[rel] == 0 else 1. / (float(degree[rel]) + 1e-7)
            if k > 0 and A[u, v, rel] > 0:  # Communication event, Association exists
                y[i] = b + lambda_uv_t
            elif k <= 0 and A[u, v, rel] > 0:  # Association event
                if node_degree_global[rel][j] == 0:
                    b_prime = 0
                else:
                    b_prime = 1. / (float(node_degree_global[rel][j]) + 1e-7)
                x = b_prime - b
                y[i] = b + lambda_uv_t
                w = (y != 0) & (indices != int(i))
                y[w] = y[w] - x
            y /= (torch.sum(y) + 1e-7)  # normalize
            S[j, :, rel] = y
        if k > 0 and A[u, v, rel] != 0:
            # The threshold test must be on a scalar: comparing the whole row
            # (`y >= threshold`) raises "Boolean value of Tensor with more than one
            # value is ambiguous". Use the strength of the (u, v) pair in either direction.
            # NOTE(review): the `A[u, v, rel] != 0` gate is kept from the original, so
            # this only re-asserts an existing link - confirm whether `== 0` was intended.
            pair_strength = max(float(S[u, v, rel]), float(S[v, u, rel]))
            if pair_strength >= threshold:
                A[u, v, rel] = A[v, u, rel] = 1
    return S, A

In [26]:
# Smoke test for update_S_A: a communication event (k=1) between nodes 0 and 1.
lambda_uv_t = torch.tensor(0.5)
update_S_A(
    A_ini,
    S_ini,
    u=0,
    v=1,
    k=1,
    node_degree=node_deg,
    lambda_uv_t=lambda_uv_t,
    N_nodes=N_nodes,
    n_assoc_types=n_assoc_types,
    node_degree_global=node_degree_global,
)

(tensor([[[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.3333],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         ...,
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]]]),
 tensor([[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [1.],
          

In [27]:
# Smoke test for update_S_A_new: a communication event (k=1) between nodes 0 and 1.
lambda_uv_t = torch.tensor(0.5)
update_S_A_new(
    A_ini,
    S_ini,
    u=0,
    v=1,
    k=1,
    node_degree=node_deg,
    lambda_uv_t=lambda_uv_t,
    N_nodes=N_nodes,
    n_assoc_types=n_assoc_types,
    node_degree_global=node_degree_global,
    decay_rate=0.1,
    threshold=0.5,
)

(tensor([[[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.3333],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         ...,
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]],
 
         [[0.0000],
          [0.0000],
          [0.0000],
          ...,
          [0.0000],
          [0.0000],
          [0.0000]]]),
 tensor([[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [1.],
          

In [28]:
# conditional density calculation to predict the next event (the probability of the next event for each pair of nodes)
def cond_density(time_bar, u, v,N_nodes, Lambda_dict, time_keys):
    """Accumulated non-event intensity for every pair (u, i) and (v, i).

    For each pair the most recent timestamp of its two nodes is looked up in
    ``time_keys``; the value written is the sum of ``Lambda_dict`` from that position
    to the end, divided by ``len(Lambda_dict)``.

    Args:
        time_bar: tensor of per-node most recent event times; it is reshaped with
            ``view``/``repeat`` below, which works for an ``(N_nodes, 1)`` tensor.
        u, v: the two nodes of the current event.
        N_nodes: number of nodes.
        Lambda_dict: 1-D tensor of stored intensities, aligned with ``time_keys``.
        time_keys: list of timestamps, oldest first, for the entries of ``Lambda_dict``.

    Returns:
        Tensor of shape ``(2, N_nodes)``: row 0 is for ``u``, row 1 for ``v``. Entries
        without usable history (including the diagonal pairs) stay 0.

    Raises:
        ValueError: if a timestamp inside the stored range is not present in ``time_keys``.
    """
    N = N_nodes
    if not time_keys:  # Checks if time_keys is empty
        print("Warning: time_keys is empty. No operations performed.")
        return torch.zeros((2, N_nodes))
    s = Lambda_dict.new_zeros((2, N))
    # Suffix sums of the stored intensities, normalised by the number of stored entries.
    Lambda_sum = torch.cumsum(Lambda_dict.flip(0), 0).flip(0)  / len(Lambda_dict)
    time_keys_min = time_keys[0]
    time_keys_max = time_keys[-1]

    t_bar_min = torch.min(time_bar[[u, v]]).item()
    if t_bar_min < time_keys_min:
        start_ind_min = 0
    elif t_bar_min > time_keys_max:
        # it means t_bar will always be larger, so there is no history for these nodes
        return s
    else:
        start_ind_min = time_keys.index(int(t_bar_min))

    # For every pair (u, i) and (v, i): the later of the two nodes' last event times.
    expanded_time_bar = time_bar[[u, v]].view(1, 2).expand(N, -1).t().contiguous().view(2 * N, 1)
    adjusted_repeated_time_bar = time_bar.repeat(2, 1).view(2 * N, 1)
    max_pairs = torch.max(torch.cat((expanded_time_bar, adjusted_repeated_time_bar), dim=1), dim=1)[0].view(2, N).long()

    # Collect the target cells of `s` and the matching start position in Lambda_sum.
    rows, cols, starts = [], [], []
    for c, j in enumerate([u, v]):
        for i in range(N):
            if i == j:
                continue
            # most recent timestamp of either node of the pair, as a plain int
            t_bar = int(max_pairs[c, i])

            if t_bar < time_keys_min:
                start_ind = 0  # it means t_bar is beyond the history we kept, so use maximum period saved
            elif t_bar > time_keys_max:
                continue  # it means t_bar is current event, so there is no history for this pair of nodes
            else:
                # t_bar is somewhere in between time_keys_min and time_keys_max
                start_ind = time_keys.index(t_bar, start_ind_min)

            rows.append(c)
            cols.append(i)
            starts.append(start_ind)

    if not rows:
        # Every pair was skipped: nothing to write. Without this guard the fancy
        # indexing below fails on an empty index array.
        return s

    rows = torch.tensor(rows, dtype=torch.long)
    cols = torch.tensor(cols, dtype=torch.long)
    starts = torch.tensor(starts, dtype=torch.long)
    s[rows, cols] = Lambda_sum[starts]

    return s

In [29]:
# Smoke test for cond_density with an all-zero key history of length 4.
t_k = [0.0] * 4
t_b = torch.tensor(time_bar)
cond_density(
    time_bar=t_b,
    u=0,
    v=1,
    N_nodes=N_nodes,
    Lambda_dict=Lambda_dict,
    time_keys=t_k,
)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [30]:
# forward pass
def forward(initial_embeddings, A_initial, data, N_nodes,psi, omega, n_assoc_types, node_degree_global, n_hidden, W_h, W_struct, W_rec, W_t, N_surv_samples=5, rnd= np.random.RandomState(111),opt=True):
    """Run one batch of events through the model and collect intensities and predictions.

    For every event of the batch, in order: the intensity of the event is computed from
    the embeddings of the previous step, the embeddings of the two nodes are updated,
    intensities of sampled non-events are collected, and the intensity of every pair
    involving ``u`` or ``v`` plus the matching survival term is stored.

    Args:
        initial_embeddings, A_initial: passed to ``initialize`` / ``generate_S_from_A``
            to build the starting embeddings, adjacency and S.
        data: indexable batch; ``data[0]``/``data[1]`` are the node tensors ``u``/``v``,
            ``data[2]`` the time deltas, ``data[3]`` the event types, ``data[4]`` the
            per-event ``time_bar`` and ``data[5]`` the event times. ``data[2]``,
            ``data[4]`` and ``data[5]`` are overwritten in place with their cast versions.
        N_nodes: number of nodes.
        psi, omega: parameters forwarded to ``intensity_rate_lambda``.
        n_assoc_types: number of association relations.
        node_degree_global: per-relation degree vectors; updated in place for ``u``/``v``.
        n_hidden, W_h, W_struct, W_rec, W_t: forwarded to ``update_node_embed``.
        N_surv_samples: number of non-event samples drawn per node of every event.
        rnd: random state used for the non-event sampling. The default is created once,
            when the function is defined, so calls that rely on it share one RNG stream.
        opt: if True, intensities and the S/A update are computed once for the whole
            batch; otherwise they are computed event by event.

    Returns:
        ``(lambda_uv_t, lambda_uv_t_non_events / N_surv_samples, [A_pred, Surv], reg)``.
        ``reg`` is never appended to in this function, so it is returned as an empty list.
    """
    z,A,Lambda_dict, time_keys, t_p = initialize(node_embeddings=initial_embeddings, A_initial=A_initial,N_nodes=N_nodes, n_hidden=n_hidden,n_assoc_types=n_assoc_types, sparse=False, keepS=False)
    S = generate_S_from_A(A=A_initial,N_nodes=N_nodes,n_assoc_types=n_assoc_types)
    # opt is batch_update
    data[2] = data[2].float()
    data[4] = data[4].double()
    data[5] = data[5].double()
    u, v, k = data[0], data[1], data[3]
    time_delta_uv = data[2]
    time_bar = data[4]
    time_cur = data[5]
    event_types = k
    B = len(u)
    assert len(event_types) == B, (len(event_types), B)
    N = N_nodes

    A_pred, Surv = None, None
    A_pred = A.new_zeros(B, N, N).fill_(0)
    Surv = A.new_zeros(B, N, N).fill_(0)

    if opt:
        embeddings1, embeddings2, node_degrees = [], [], []
        embeddings_non1, embeddings_non2 = [], []
    else:
        lambda_uv_t, lambda_uv_t_non_events = [], []

    assert torch.min(time_delta_uv) >= 0, ('events must be in chronological order', torch.min(time_delta_uv))

    # Fixed per-component shift and scale applied to the four time-delta components.
    time_mn = torch.from_numpy(np.array([0, 0, 0, 0])).float().view(1, 1, 4)
    time_sd = torch.from_numpy(np.array([50, 7, 15, 15])).float().view(1, 1, 4)
    time_delta_uv = (time_delta_uv - time_mn) / time_sd

    reg = []

    S_batch = []

    z_all = []

    u_all = u.data.cpu().numpy()
    v_all = v.data.cpu().numpy()


    for it, k in enumerate(event_types):
        # k = 0: association event (rare)
        # k = 1,2,3: communication event (frequent)

        u_it, v_it = u_all[it], v_all[it]
        z_prev = z if it == 0 else z_all[it - 1]

        # 1. Compute intensity rate lambda based on node embeddings at previous time step (Eq. 1)
        if opt:
            # store node embeddings, compute lambda and S,A later based on the entire batch
            embeddings1.append(z_prev[u_it])
            embeddings2.append(z_prev[v_it])
        else:
            # accumulate intensity rate of events for this batch based on new embeddings
            # (all model arguments belong to intensity_rate_lambda, not to list.append)
            lambda_uv_t.append(intensity_rate_lambda(z_prev[u_it], z_prev[v_it], torch.zeros(1).long() + k, n_hidden, psi, n_assoc_types, omega, edge_type=None))


        # 2. Update node embeddings
        node_degree, z_new = update_node_embed(z_prev, u_it, v_it, time_delta_uv[it], n_hidden, n_assoc_types, S, A, W_h, W_struct, W_rec, W_t)  # / 3600.)  # hours
        if opt:
            node_degrees.append(node_degree)


        # 3. Update S and A
        if not opt:
            # we can update S and A based on current pair of nodes even during test time,
            # because S, A are not used in further steps for this iteration
            update_S_A(A, S, u_it, v_it, k.item(), node_degree, lambda_uv_t[it], N_nodes, n_assoc_types, node_degree_global)  #

        # update most recent degrees of nodes used to update S
        assert node_degree_global is not None
        for j in [u_it, v_it]:
            for rel in range(n_assoc_types):
                node_degree_global[rel][j] = node_degree[j][rel]


        # Non events loss
        # this is not important for test time, but we still compute these losses for debugging purposes
        # get random nodes except for u_it, v_it
        # 4. compute lambda for sampled events that do not happen -> to compute survival probability in loss
        uv_others = rnd.choice(np.delete(np.arange(N), [u_it, v_it]), size= N_surv_samples * 2, replace=False)
        for q in range(N_surv_samples):
            assert u_it != uv_others[q], (u_it, uv_others[q])
            assert v_it != uv_others[N_surv_samples + q], (v_it, uv_others[N_surv_samples + q])
            if opt:
                embeddings_non1.extend([z_prev[u_it], z_prev[uv_others[N_surv_samples + q]]])
                embeddings_non2.extend([z_prev[uv_others[q]], z_prev[v_it]])
            else:
                for k_ in range(2):
                    lambda_uv_t_non_events.append(
                        intensity_rate_lambda(z_prev[u_it],
                                                    z_prev[uv_others[q]], torch.zeros(1).long() + k_, n_hidden,psi,n_assoc_types,omega))
                    lambda_uv_t_non_events.append(
                        intensity_rate_lambda(z_prev[uv_others[N_surv_samples + q]],
                                                    z_prev[v_it],
                                                    torch.zeros(1).long() + k_, n_hidden,psi,n_assoc_types,omega))


        # 5. compute conditional density for all possible pairs
        # here it's important NOT to use any information that the event between nodes u,v has happened
        # so, we use node embeddings of the previous time step: z_prev
        with torch.no_grad():
            z_cat = torch.cat((z_prev[u_it].detach().unsqueeze(0).expand(N, -1),
                                z_prev[v_it].detach().unsqueeze(0).expand(N, -1)), dim=0)
            Lambda = intensity_rate_lambda(z_cat, z_prev.detach().repeat(2, 1),
                                                torch.zeros(len(z_cat)).long() + k, n_hidden, psi, n_assoc_types, omega).detach()

            # first N entries: pairs (u, i); last N entries: pairs (v, i)
            A_pred[it, u_it, :] = Lambda[:N]
            A_pred[it, v_it, :] = Lambda[N:]

            assert torch.sum(torch.isnan(A_pred[it])) == 0, (it, torch.sum(torch.isnan(A_pred[it])))
            # Compute the survival term (See page 3 in the paper)
            # we only need to compute the term for rows u_it and v_it in our matrix s to save time
            # because we will compute rank only for nodes u_it and v_it
            s1 = cond_density(time_bar[it], u_it, v_it,N_nodes, Lambda_dict, time_keys)
            Surv[it, [u_it, v_it], :] = s1

            time_key = int(time_cur[it].item())
            idx = np.delete(np.arange(N), [u_it, v_it])  # nonevents for node u
            idx = np.concatenate((idx, idx + N))   # concat with nonevents for node v

            if len(time_keys) >= len(Lambda_dict):
                # shift in time (remove the oldest record)
                time_keys = np.array(time_keys)
                time_keys[:-1] = time_keys[1:]
                time_keys = list(time_keys[:-1])  # remove last
                Lambda_dict[:-1] = Lambda_dict.clone()[1:]
                Lambda_dict[-1] = 0

            Lambda_dict[len(time_keys)] = Lambda[idx].sum().detach()  # total intensity of non events for the current time step
            time_keys.append(time_key)

        # Once we made predictions for the training and test sample, we can update node embeddings
        z_all.append(z_new)
        # update S

        # NOTE(review): from the first event on, A is the same object as S, so later
        # update_node_embed / update_S_A calls receive S as their adjacency - confirm
        # this is intended.
        A = S
        S_batch.append(S.data.cpu().numpy())

        t_p += 1

    z = z_new  # update node embeddings (local only: z is not part of the return value)

    # Batch update
    if opt:
        lambda_uv_t = intensity_rate_lambda(torch.stack(embeddings1, dim=0),
                                                    torch.stack(embeddings2, dim=0), event_types, n_hidden, psi, n_assoc_types, omega)
        non_events = len(embeddings_non1)
        n_types = 2
        lambda_uv_t_non_events = torch.zeros(non_events * n_types)
        embeddings_non1 = torch.stack(embeddings_non1, dim=0)
        embeddings_non2 = torch.stack(embeddings_non2, dim=0)
        idx = None
        empty_t = torch.zeros(non_events, dtype=torch.long)
        types_lst = torch.arange(n_types)
        # One contiguous slice of lambda_uv_t_non_events per event type.
        for k in types_lst:
            if idx is None:
                idx = np.arange(non_events)
            else:
                idx += non_events
            lambda_uv_t_non_events[idx] = intensity_rate_lambda(embeddings_non1, embeddings_non2, empty_t + k, n_hidden, psi, n_assoc_types, omega)

        # update only once per batch
        for it, k in enumerate(event_types):
            u_it, v_it = u_all[it], v_all[it]
            update_S_A(A, S, u_it, v_it, k.item(), node_degrees[it], lambda_uv_t[it].item(), N_nodes, n_assoc_types, node_degree_global)

    else:
        lambda_uv_t = torch.cat(lambda_uv_t)
        lambda_uv_t_non_events = torch.cat(lambda_uv_t_non_events)


    if len(reg) > 1:
        reg = [torch.stack(reg).mean()]

    return lambda_uv_t, lambda_uv_t_non_events / N_surv_samples, [A_pred, Surv], reg

In [31]:
from torch.utils.data import DataLoader
# shuffle=False keeps the dataset's own event order inside and across batches.
EVENT_BATCH_SIZE = 200
train_loader = DataLoader(train_set, batch_size=EVENT_BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_set, batch_size=EVENT_BATCH_SIZE, shuffle=False)

In [32]:
import datetime
from datetime import datetime,timezone
# Run the forward pass on every test batch and print the raw outputs for inspection.
for batch_idx, data in enumerate(test_loader):
    lambda_uv_t, lambda_uv_t_non_events, (A_pred, Surv), reg = forward(
        initial_embeddings=initial_embeddings,
        A_initial=A_initial,
        data=data,
        N_nodes=N_nodes,
        psi=psi,
        omega=omega,
        n_assoc_types=n_assoc_types,
        node_degree_global=node_degree_global,
        n_hidden=n_hidden,
        W_h=W_h,
        W_struct=W_struct,
        W_rec=W_rec,
        W_t=W_t,
        N_surv_samples=5,
        rnd=np.random.RandomState(111),  # fresh, identically seeded sampler for every batch
        opt=True,
    )
    print('lambda_uv_t', lambda_uv_t)
    print('lambda_uv_t_non_events', lambda_uv_t_non_events)
    print('A_pred', A_pred)
    print('Surv', Surv)
    print('reg', reg)

initialize models node embeddings and adjacency matrices for 84 nodes
A_initial (84, 84, 1)


lambda_uv_t tensor([0.5457, 0.5983, 0.3689, 0.4311, 0.5656, 0.6209, 0.3981, 0.5467, 0.6049,
        0.5794, 0.7153, 0.5885, 0.5452, 0.5675, 0.4389, 0.5737, 0.5833, 0.5678,
        0.5745, 0.5815, 0.5699, 0.5749, 0.5572, 0.5722, 0.5591, 0.5725, 0.5596,
        0.5725, 0.5601, 0.5726, 0.5608, 0.6564, 0.6170, 0.6659, 0.6163, 0.6648,
        0.6170, 0.6637, 0.6176, 0.6627, 0.6183, 0.6617, 0.6191, 0.6606, 0.6185,
        0.6605, 0.6211, 0.6587, 0.6200, 0.6584, 0.6227, 0.6567, 0.6202, 0.6573,
        0.4869, 0.6703, 0.6360, 0.6493, 0.6304, 0.6549, 0.6300, 0.6552, 0.6292,
        0.5812, 0.6186, 0.6176, 0.5657, 0.5709, 0.5591, 0.6348, 0.6223, 0.5692,
        0.5700, 0.5590, 0.3987, 0.5714, 0.5663, 0.7269, 0.5770, 0.5654, 0.5789,
        0.5764, 0.5650, 0.6115, 0.6298, 0.5959, 0.5913, 0.5794, 0.5653, 0.6547,
        0.5761, 0.5660, 0.5759, 0.6214, 0.5605, 0.3287, 0.6261, 0.5566, 0.5922,
        0.5913, 0.6330, 0.7431, 0.6014, 0.5840, 0.5790, 0.5668, 0.5855, 0.5757,
        0.5665, 0.5856, 0.57