In [1]:
import numpy as np
import datetime
import torch
import torch.utils
from datetime import timezone


class EventsDataset(torch.utils.data.Dataset):
    '''
    Base class for event datasets
    '''
    def __init__(self, TZ=None):
        self.TZ = TZ  # timezone.utc

        # Implement here these fields (see examples in actual datasets):
        # self.FIRST_DATE = datetime()
        # self.TEST_TIMESLOTS = []
        # self.N_nodes = 100
        # self.A_initial = np.random.randint(0, 2, size=(self.N_nodes, self.N_nodes))
        # self.A_last = np.random.randint(0, 2, size=(self.N_nodes, self.N_nodes))
        #
        # self.all_events = []
        # self.n_events = len(self.all_events)
        #
        # self.event_types = ['communication event']
        # self.event_types_num = {'association event': 0}
        # k = 1  # k >= 1 for communication events
        # for t in self.event_types:
        #     self.event_types_num[t] = k
        #     k += 1


    def get_Adjacency(self, multirelations=False):
        return None, None, None

    def __len__(self):
        return self.n_events

    def __getitem__(self, index):

        tpl = self.all_events[index]
        u, v, rel, time_cur = tpl

        # Compute time delta in seconds (t_p - \bar{t}_p_j) that will be fed to W_t
        time_delta_uv = np.zeros((2, 4))  # two nodes x 4 values

        # most recent previous time for all nodes
        time_bar = self.time_bar.copy()
        assert u != v, (tpl, rel)

        u = int(u)
        v = int(v)
        
        for c, j in enumerate([u, v]):
            t = datetime.datetime.fromtimestamp(self.time_bar[j][0], tz=self.TZ)
            if t.toordinal() >= self.FIRST_DATE.toordinal():  # assume no events before FIRST_DATE
                td = time_cur - t
                time_delta_uv[c] = np.array([td.days,  # total number of days, still can be a big number
                                             td.seconds // 3600,  # hours, max 24
                                             (td.seconds // 60) % 60,  # minutes, max 60
                                             td.seconds % 60],  # seconds, max 60
                                            np.float64)
                # assert time_delta_uv.min() >= 0, (index, tpl, time_delta_uv[c], node_global_time[j])
            else:
                raise ValueError('unexpected result', t, self.FIRST_DATE)
            self.time_bar[j] = time_cur.timestamp()  # last time stamp for nodes u and v

        k = self.event_types_num[rel]

        # sanity checks
        assert np.float64(time_cur.timestamp()) == time_cur.timestamp(), (
        np.float64(time_cur.timestamp()), time_cur.timestamp())
        time_cur = np.float64(time_cur.timestamp())
        time_bar = time_bar.astype(np.float64)
        time_cur = torch.from_numpy(np.array([time_cur])).double()
        assert time_bar.max() <= time_cur, (time_bar.max(), time_cur)
        return u, v, time_delta_uv, k, time_bar, time_cur

In [2]:
import os
import numpy as np
import datetime
import pickle
from datetime import timezone
import dateutil.parser


def iso_parse(dt):
    # return datetime.fromisoformat(dt)  # python >= 3.7
    return dateutil.parser.isoparse(dt)

class GithubDataset(EventsDataset):

    def __init__(self, split, data_dir='./Github'):
        super(GithubDataset, self).__init__()

        if split == 'train':
            time_start = 0
            time_end = datetime.datetime(2013, 8, 31, tzinfo=self.TZ).toordinal()
        elif split == 'test':
            time_start = datetime.datetime(2013, 9, 1, tzinfo=self.TZ).toordinal()
            time_end = datetime.datetime(2014, 1, 1, tzinfo=self.TZ).toordinal()
        else:
            raise ValueError('invalid split', split)

        self.FIRST_DATE = datetime.datetime(2012, 12, 28, tzinfo=self.TZ)

        self.TEST_TIMESLOTS = [datetime.datetime(2013, 9, 1, tzinfo=self.TZ),
                               datetime.datetime(2013, 9, 25, tzinfo=self.TZ),
                               datetime.datetime(2013, 10, 20, tzinfo=self.TZ),
                               datetime.datetime(2013, 11, 15, tzinfo=self.TZ),
                               datetime.datetime(2013, 12, 10, tzinfo=self.TZ),
                               datetime.datetime(2014, 1, 1, tzinfo=self.TZ)]



        with open(os.path.join(data_dir, 'github_284users_events_2013.pkl'), 'rb') as f:
            users_events, event_types = pickle.load(f)

        with open(os.path.join(data_dir, 'github_284users_follow_2011_2012.pkl'), 'rb') as f:
            users_follow = pickle.load(f)

        print(event_types)

        self.events2name = {}
        for e in event_types:
            self.events2name[event_types[e]] = e
        print(self.events2name)

        self.event_types = ['ForkEvent', 'PushEvent', 'WatchEvent', 'IssuesEvent', 'IssueCommentEvent',
                           'PullRequestEvent', 'CommitCommentEvent']
        self.assoc_types = ['FollowEvent']
        self.is_comm = lambda d: self.events2name[d['type']] in self.event_types
        self.is_assoc = lambda d: self.events2name[d['type']] in self.assoc_types

        user_ids = {}
        for id, user in enumerate(sorted(users_events)):
            user_ids[user] = id

        self.N_nodes = len(user_ids)

        self.A_initial = np.zeros((self.N_nodes, self.N_nodes))
        for user in users_follow:
            for e in users_follow[user]:
                assert e['type'] in self.assoc_types, e['type']
                if e['login'] in users_events:
                    self.A_initial[user_ids[user], user_ids[e['login']]] = 1

        self.A_last = np.zeros((self.N_nodes, self.N_nodes))
        for user in users_events:
            for e in users_events[user]:
                if self.events2name[e['type']] in self.assoc_types:
                    self.A_last[user_ids[user], user_ids[e['login']]] = 1
        self.time_bar = np.full(self.N_nodes, self.FIRST_DATE.timestamp())


        print('\nA_initial', np.sum(self.A_initial))
        print('A_last', np.sum(self.A_last), '\n')

        all_events = []
        for user in users_events:
            if user not in user_ids:
                continue
            user_id = user_ids[user]
            for ind, event in enumerate(users_events[user]):
                event['created_at'] = datetime.datetime.fromtimestamp(event['created_at'])
                if event['created_at'].toordinal() >= time_start and event['created_at'].toordinal() <= time_end:
                    if 'owner' in event:
                        if event['owner'] not in user_ids:
                            continue
                        user_id2 = user_ids[event['owner']]
                    elif 'login' in event:
                        if event['login'] not in user_ids:
                            continue
                        user_id2 = user_ids[event['login']]
                    else:
                        raise ValueError('invalid event', event)
                    if user_id != user_id2:
                        all_events.append((user_id, user_id2,
                                           self.events2name[event['type']], event['created_at']))

        self.all_events = sorted(all_events, key=lambda t: t[3].timestamp())
        print('\n%s' % split.upper())
        print('%d events between %d users loaded' % (len(self.all_events), self.N_nodes))
        print('%d communication events' % (len([t for t in self.all_events if t[2] == 1])))
        print('%d assocition events' % (len([t for t in self.all_events if t[2] == 0])))

        self.event_types_num = {self.assoc_types[0]: 0}
        k = 1  # k >= 1 for communication events
        for t in self.event_types:
            self.event_types_num[t] = k
            k += 1

        self.n_events = len(self.all_events)


    def get_Adjacency(self, multirelations=False):
        if multirelations:
            print('warning: Github has only one relation type (FollowEvent), so multirelations are ignored')
        return self.A_initial, self.assoc_types, self.A_last

In [3]:
datdir = '/Users/amberrrrrr/Desktop/huozhe/Github'
train_set = GithubDataset('train', data_dir=datdir)
test_set = GithubDataset('test', data_dir=datdir)

# Preview the first few lines of the train set and test set
print("Train set preview (first 5 events):")
for event in train_set.all_events[:5]:
    print(event)

print("\nTest set preview (first 5 events):")
for event in test_set.all_events[:5]:
    print(event)

{'PushEvent': 0, 'WatchEvent': 1, 'ForkEvent': 2, 'IssuesEvent': 3, 'FollowEvent': 4, 'PullRequestEvent': 5, 'IssueCommentEvent': 6, 'CommitCommentEvent': 7}
{0: 'PushEvent', 1: 'WatchEvent', 2: 'ForkEvent', 3: 'IssuesEvent', 4: 'FollowEvent', 5: 'PullRequestEvent', 6: 'IssueCommentEvent', 7: 'CommitCommentEvent'}

A_initial 298.0
A_last 1420.0 


TRAIN
11627 events between 284 users loaded
0 communication events
0 assocition events
{'PushEvent': 0, 'WatchEvent': 1, 'ForkEvent': 2, 'IssuesEvent': 3, 'FollowEvent': 4, 'PullRequestEvent': 5, 'IssueCommentEvent': 6, 'CommitCommentEvent': 7}
{0: 'PushEvent', 1: 'WatchEvent', 2: 'ForkEvent', 3: 'IssuesEvent', 4: 'FollowEvent', 5: 'PullRequestEvent', 6: 'IssueCommentEvent', 7: 'CommitCommentEvent'}

A_initial 298.0
A_last 1420.0 


TEST
9099 events between 284 users loaded
0 communication events
0 assocition events
Train set preview (first 5 events):
(74, 6, 'PushEvent', datetime.datetime(2013, 1, 1, 8, 53, 4))
(103, 268, 'IssueCommentEvent'

In [4]:
initial_embeddings = np.random.randn(train_set.N_nodes, 32)
A_initial = train_set.get_Adjacency()[0]

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class DyRep(nn.Module):
    def __init__(self,
                 node_embeddings,
                 A_initial=None,
                 N_surv_samples=5,
                 n_hidden=32,
                 N_hops=2,
                 sparse=False,
                 node_degree_global=None,
                 rnd=np.random.RandomState(111)):
        super(DyRep, self).__init__()
    
        # initialisations
        self.opt = True
        self.exp = True
        self.rnd = rnd
        self.n_hidden = n_hidden
        self.sparse = sparse
        self.N_surv_samples = N_surv_samples
        self.node_degree_global = node_degree_global
        self.N_nodes = A_initial.shape[0]
        if A_initial is not None and len(A_initial.shape) == 2:
            A_initial = A_initial[:, :, None]
        self.n_assoc_types = 1
        self.N_hops = N_hops

        self.initialize(node_embeddings, A_initial)
        self.W_h = nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.W_struct = nn.Linear(n_hidden * self.n_assoc_types, n_hidden)
        self.W_rec = nn.Linear(n_hidden, n_hidden)
        self.W_t = nn.Linear(4, n_hidden)

        n_types = 2  # associative and communicative
        d1 = self.n_hidden + (0)
        d2 = self.n_hidden + (0)

        d1 += self.n_hidden
        d2 += self.n_hidden
        self.omega = nn.ModuleList([nn.Linear(d1, 1), nn.Linear(d2, 1)])

        self.psi = nn.Parameter(0.5 * torch.ones(n_types)) 

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # print('before Xavier', m.weight.data.shape, m.weight.data.min(), m.weight.data.max())
                nn.init.xavier_normal_(m.weight.data)

    def generate_S_from_A(self):
        if isinstance(self.A, np.ndarray):
            self.A = torch.tensor(self.A, dtype=torch.float32)  # Convert A to a tensor if it's a numpy array
        S = self.A.new_empty(self.N_nodes, self.N_nodes, self.n_assoc_types).fill_(0)
        for rel in range(self.n_assoc_types):
            D = torch.sum(self.A[:, :, rel], dim=1).float()
            for i, v in enumerate(torch.nonzero(D, as_tuple=False).squeeze()):
                u = torch.nonzero(self.A[v, :, rel].squeeze(), as_tuple=False).squeeze()
                S[v, u, rel] = 1. / D[v]
        self.S = S
        # Check that values in each row of S add up to 1
        for rel in range(self.n_assoc_types):
            S = self.S[:, :, rel]
            assert torch.sum(S[self.A[:, :, rel] == 0]) < 1e-5, torch.sum(S[self.A[:, :, rel] == 0])

    def initialize(self,node_embeddings, A_initial,keepS=False):
        print('initialize model''s node embeddings and adjacency matrices for %d nodes' % self.N_nodes)
        # Initial embeddings
        if node_embeddings is not None:
            z = np.pad(node_embeddings, ((0, 0), (0, self.n_hidden - node_embeddings.shape[1])), 'constant')
            z = torch.from_numpy(z).float()

        if A_initial is None:
            print('initial random prediction of A')
            A = torch.zeros(self.N_nodes, self.N_nodes, self.n_assoc_types + int(self.sparse))

            for i in range(self.N_nodes):
                for j in range(i + 1, self.N_nodes):
                    if self.sparse:
                        if self.n_assoc_types == 1:
                            pvals = [0.95, 0.05]
                        elif self.n_assoc_types == 2:
                            pvals = [0.9, 0.05, 0.05]
                        elif self.n_assoc_types == 3:
                            pvals = [0.91, 0.03, 0.03, 0.03]
                        elif self.n_assoc_types == 4:
                            pvals = [0.9, 0.025, 0.025, 0.025, 0.025]
                        else:
                            raise NotImplementedError(self.n_assoc_types)
                        ind = np.nonzero(np.random.multinomial(1, pvals))[0][0]
                    else:
                        ind = np.random.randint(0, self.n_assoc_types, size=1)
                    A[i, j, ind] = 1
                    A[j, i, ind] = 1
            assert torch.sum(torch.isnan(A)) == 0, (torch.sum(torch.isnan(A)), A)
            if self.sparse:
                A = A[:, :, 1:]

        else:
            print('A_initial', A_initial.shape)
            A = torch.from_numpy(A_initial).float()
            if len(A.shape) == 2:
                A = A.unsqueeze(2)

        # make these variables part of the model
        self.register_buffer('z', z)
        self.register_buffer('A', A)

        self.A = A  
        if not keepS:
            self.generate_S_from_A()

        self.Lambda_dict = torch.zeros(5000)
        self.time_keys = []

        self.t_p = 0  # global counter of iterations
    
    def check_S(self):
        for rel in range(self.n_assoc_types):
            rows = torch.nonzero(torch.sum(self.A[:, :, rel], dim=1).float())
            # check that the sum in all rows equal 1
            assert torch.all(torch.abs(torch.sum(self.S[:, :, rel], dim=1)[rows] - 1) < 1e-1), torch.abs(torch.sum(self.S[:, :, rel], dim=1)[rows] - 1)

    
    def g_fn(self,z_cat, k, edge_type=None, z2=None):
        if z2 is not None:
            z_cat = torch.cat((z_cat, z2), dim=1)
        else:
            raise NotImplementedError('')
        g = z_cat.new(len(z_cat), 1).fill_(0)
        idx = k <= 0
        if torch.sum(idx) > 0:
            if edge_type is not None:
                z_cat1 = torch.cat((z_cat[idx], edge_type[idx, :self.n_assoc_types]), dim=1)
            else:
                z_cat1 = z_cat[idx]
            g[idx] = self.omega[0](z_cat1)
        idx = k > 0
        if torch.sum(idx) > 0:
            if edge_type is not None:
                z_cat1 = torch.cat((z_cat[idx], edge_type[idx, self.n_assoc_types:]), dim=1)
            else:
                z_cat1 = z_cat[idx]
            g[idx] = self.omega[1](z_cat1)

        g = g.flatten()
        return g
    
    def intensity_rate_lambda(self,z_u, z_v, k):
        z_u = z_u.view(-1, self.n_hidden).contiguous()
        z_v = z_v.view(-1, self.n_hidden).contiguous()
        edge_type = None
        g = 0.5 * (self.g_fn(z_u, (k > 0).long(), edge_type=edge_type, z2=z_v) + self.g_fn(z_v, (k > 0).long(),edge_type=edge_type, z2=z_u))  # make it symmetric, because most events are symmetric
        psi = self.psi[(k > 0).long()]
        g_psi = torch.clamp(g / (psi + 1e-7), -75, 75)  # to prevent overflow
        Lambda = psi * (torch.log(1 + torch.exp(-g_psi)) + g_psi)
        return Lambda
    
    def update_node_embed(self,prev_embed, node1, node2, time_delta_uv):
        # z contains all node embeddings of previous time \bar{t}
        # S also corresponds to previous time stamp, because it's not updated yet based on this event

        node_embed = prev_embed

        node_degree = {} # we need degrees to update S
        z_new = prev_embed.clone()  # to allow in place changes while keeping gradients
        
        #precompute the N-hop neighbors
        A_float = self.A.squeeze(-1).float()
        A_power = torch.eye(A_float.shape[0])
        extended_neighbors = [A_power.clone()]
        for _ in range(self.N_hops):
            A_power = torch.mm(A_power, A_float)
            extended_neighbors.append((A_power>0).clone())
        h_u_struct = prev_embed.new_zeros((2, self.n_hidden, self.n_assoc_types))
        for c, (v, u, delta_t) in enumerate(zip([node1, node2], [node2, node1], time_delta_uv)):  # i is the other node involved in the event
            node_degree[u] = np.zeros(self.n_assoc_types)
            for rel in range(self.n_assoc_types):
                Neighb_u = torch.zeros(self.A.shape[1],dtype=torch.bool)
                for i in range(1,self.N_hops+1):
                    Neighb_u |=  extended_neighbors[i][u,:] >0
                
                N_neighb = torch.sum(Neighb_u).item()  # number of neighbors for node u
                node_degree[u][rel] = N_neighb
                if N_neighb > 0:  # node has no neighbors
                    h_prev_i = self.W_h(node_embed[Neighb_u]).view(N_neighb, self.n_hidden)
                    # attention over neighbors
                    q_ui = torch.exp(self.S[u, Neighb_u, rel]).view(N_neighb, 1)
                    q_ui = q_ui / (torch.sum(q_ui) + 1e-7)
                    h_u_struct[c, :, rel] = torch.max(torch.sigmoid(q_ui * h_prev_i), dim=0)[0].view(1, self.n_hidden)

        h1 = self.W_struct(h_u_struct.view(2, self.n_hidden * self.n_assoc_types))

        h2 = self.W_rec(node_embed[[node1, node2], :].view(2, -1))
        h3 = self.W_t(time_delta_uv.float()).view(2, self.n_hidden)

        z_new[[node1, node2], :] = torch.sigmoid(h1 + h2 + h3)
        return node_degree, z_new
    
    def update_S_A(self, u, v, k, node_degree, lambda_uv_t):
        if k <= 0 :  # Association event
            # do not update in case of latent graph
            self.A[u, v, np.abs(k)] = self.A[v, u, np.abs(k)] = 1  # 0 for CloseFriends, k = -1 for the second relation, so it's abs(k) matrix in self.A
        A = self.A
        indices = torch.arange(self.N_nodes)
        for rel in range(self.n_assoc_types):
            if k > 0 and A[u, v, rel] == 0:  # Communication event, no Association exists
                continue  # do not update S and A
            else:
                for j, i in zip([u, v], [v, u]):
                    # i is the "other node involved in the event"
                    try:
                        degree = node_degree[j]
                    except:
                        print(list(node_degree.keys()))
                        raise
                    y = self.S[j, :, rel]
                    # assert torch.sum(torch.isnan(y)) == 0, ('b', j, degree[rel], node_degree_global[rel][j.item()], y)
                    b = 0 if degree[rel] == 0 else 1. / (float(degree[rel]) + 1e-7)
                    if k > 0 and A[u, v, rel] > 0:  # Communication event, Association exists
                        y[i] = b + lambda_uv_t
                    elif k <= 0 and A[u, v, rel] > 0:  # Association event
                        if self.node_degree_global[rel][j] == 0:
                            b_prime = 0
                        else:
                            b_prime = 1. / (float(self.node_degree_global[rel][j]) + 1e-7)
                        x = b_prime - b
                        y[i] = b + lambda_uv_t
                        w = (y != 0) & (indices != int(i))
                        y[w] = y[w] - x
                    y /= (torch.sum(y) + 1e-7)  # normalize
                    self.S[j, :, rel] = y
        return 
    
    # conditional density calculation to predict the next event (the probability of the next event for each pair of nodes)
    def cond_density(self,time_bar,u, v):
        N = self.N_nodes
        if not self.time_keys:  # Checks if time_keys is empty
            print("Warning: time_keys is empty. No operations performed.")
            return torch.zeros((2, self.N_nodes)) 
        s = self.Lambda_dict.new_zeros((2, N))
        #normalize lambda values by dividing by the number of events
        Lambda_sum = torch.cumsum(self.Lambda_dict.flip(0), 0).flip(0)  / len(self.Lambda_dict)
        time_keys_min = self.time_keys[0]
        time_keys_max = self.time_keys[-1]

        indices = []
        l_indices = []
        t_bar_min = torch.min(time_bar[[u, v]]).item()
        if t_bar_min < time_keys_min:
            start_ind_min = 0
        elif t_bar_min > time_keys_max:
            # it means t_bar will always be larger, so there is no history for these nodes
            return s
        else:
            start_ind_min = self.time_keys.index(int(t_bar_min))

        # print("time_bar shape:", time_bar.shape)
        # print("Expanded and reshaped time_bar shape:", time_bar[[u, v]].view(1, 2).expand(N, -1).t().contiguous().view(2 * N, 1).shape)
        # print("Repeated time_bar shape:", time_bar.repeat(2, 1).shape)
        # Reshape expanded and reshaped time_bar
        expanded_time_bar = time_bar[[u, v]].view(1, 2).expand(N, -1).t().contiguous().view(2 * N, 1)
        # Adjust repeated time_bar to match the expanded shape
        adjusted_repeated_time_bar = time_bar.repeat(2, 1).view(2 * N, 1)
        # Now concatenate along dimension 1 (should work as both tensors are (168, 1))
        max_pairs = torch.max(torch.cat((expanded_time_bar, adjusted_repeated_time_bar), dim=1), dim=1)[0].view(2, N).long()
        # max_pairs = torch.max(torch.cat((time_bar[[u, v]].view(1, 2).expand(N, -1).t().contiguous().view(2 * N, 1),
        #                                     time_bar.repeat(2, 1)), dim=1), dim=1)[0].view(2, N).long().data.cpu().numpy()  # 2,N

        # compute cond density for all pairs of u and some i, then of v and some i
        c1, c2 = 0, 0
        for c, j in enumerate([u, v]):  # range(i + 1, N):
            for i in range(N):
                if i == j:
                    continue
                # most recent timestamp of either u or v
                t_bar = max_pairs[c, i]
                c2 += 1

                if t_bar < time_keys_min:
                    start_ind = 0  # it means t_bar is beyond the history we kept, so use maximum period saved
                elif t_bar > time_keys_max:
                    continue  # it means t_bar is current event, so there is no history for this pair of nodes
                else:
                    # t_bar is somewhere in between time_keys_min and time_keys_min
                    start_ind = self.time_keys.index(t_bar, start_ind_min)

                indices.append((c, i))
                l_indices.append(start_ind)

        indices = np.array(indices)
        l_indices = np.array(l_indices)
        s[indices[:, 0], indices[:, 1]] = Lambda_sum[l_indices]

        return s
    
    # forward pass
    def forward(self,data):
        # opt is batch_update
        data[2] = data[2].float()
        data[4] = data[4].double()
        data[5] = data[5].double()
        u, v, k = data[0], data[1], data[3]
        time_delta_uv = data[2]
        time_bar = data[4]
        time_cur = data[5]
        event_types = k
        # u, v, time_delta_uv, event_types, time_bar, time_cur = data
        B = len(u)
        assert len(event_types) == B, (len(event_types), B)
        N = self.N_nodes

        A_pred, Surv = None, None
        A_pred = self.A.new_zeros(B, N, N).fill_(0)
        Surv = self.A.new_zeros(B, N, N).fill_(0)

        if self.opt:
            embeddings1, embeddings2, node_degrees = [], [], []
            embeddings_non1, embeddings_non2 = [], []
        else:
            lambda_uv_t, lambda_uv_t_non_events = [], []

        assert torch.min(time_delta_uv) >= 0, ('events must be in chronological order', torch.min(time_delta_uv))

        time_mn = torch.from_numpy(np.array([0, 0, 0, 0])).float().view(1, 1, 4)
        time_sd = torch.from_numpy(np.array([50, 7, 15, 15])).float().view(1, 1, 4)
        time_delta_uv = (time_delta_uv - time_mn) / time_sd

        reg = []

        S_batch = []

        z_all = []

        u_all = u.data.cpu().numpy()
        v_all = v.data.cpu().numpy()


        for it, k in enumerate(event_types):
            # k = 0: association event (rare)
            # k = 1,2,3: communication event (frequent)

            u_it, v_it = u_all[it], v_all[it]
            z_prev = self.z if it == 0 else z_all[it - 1]

            # 1. Compute intensity rate lambda based on node embeddings at previous time step (Eq. 1)
            if self.opt:
                # store node embeddings, compute lambda and S,A later based on the entire batch
                embeddings1.append(z_prev[u_it])
                embeddings2.append(z_prev[v_it])
            else:
                # accumulate intensity rate of events for this batch based on new embeddings
                lambda_uv_t.append(self.intensity_rate_lambda(z_prev[u_it], z_prev[v_it], torch.zeros(1).long() + k))
                # intensity_rate_lambda(z_u, z_v, k,n_hidden,psi,n_assoc_types,omega,edge_type=None)


            # 2. Update node embeddings
            node_degree, z_new = self.update_node_embed(z_prev, u_it, v_it, time_delta_uv[it])  # / 3600.)  # hours
            # update_node_embed(prev_embed, node1, node2, time_delta_uv, n_hidden,n_assoc_types, S, A, W_h, W_struct, W_rec, W_t)
            if self.opt:
                node_degrees.append(node_degree)


            # 3. Update S and A
            if not self.opt:
                # we can update S and A based on current pair of nodes even during test time,
                # because S, A are not used in further steps for this iteration
                self.update_S_A(u_it, v_it, k.item(), node_degree, lambda_uv_t[it])  #
                # update_S_A(A,S, u,v, k, node_degree, lambda_uv_t, N_nodes,n_assoc_types,node_degree_global)

            # update most recent degrees of nodes used to update S
            assert self.node_degree_global is not None
            for j in [u_it, v_it]:
                for rel in range(self.n_assoc_types):
                    self.node_degree_global[rel][j] = node_degree[j][rel]


            # Non events loss
            # this is not important for test time, but we still compute these losses for debugging purposes
            # get random nodes except for u_it, v_it
            # 4. compute lambda for sampled events that do not happen -> to compute survival probability in loss
            uv_others = self.rnd.choice(np.delete(np.arange(N), [u_it, v_it]), size= self.N_surv_samples * 2, replace=False)
                # assert len(np.unique(uv_others)) == len(uv_others), ('nodes must be unique', uv_others)
            for q in range(self.N_surv_samples):
                assert u_it != uv_others[q], (u_it, uv_others[q])
                assert v_it != uv_others[self.N_surv_samples + q], (v_it, uv_others[self.N_surv_samples + q])
                if self.opt:
                    embeddings_non1.extend([z_prev[u_it], z_prev[uv_others[self.N_surv_samples + q]]])
                    embeddings_non2.extend([z_prev[uv_others[q]], z_prev[v_it]])
                else:
                    for k_ in range(2):
                        lambda_uv_t_non_events.append(
                            self.intensity_rate_lambda(z_prev[u_it],
                                                        z_prev[uv_others[q]], torch.zeros(1).long() + k_))
                        lambda_uv_t_non_events.append(
                            self.intensity_rate_lambda(z_prev[uv_others[self.N_surv_samples + q]],
                                                        z_prev[v_it],
                                                        torch.zeros(1).long() + k_))


            # 5. compute conditional density for all possible pairs
            # here it's important NOT to use any information that the event between nodes u,v has happened
            # so, we use node embeddings of the previous time step: z_prev
            with torch.no_grad():
                z_cat = torch.cat((z_prev[u_it].detach().unsqueeze(0).expand(N, -1),
                                    z_prev[v_it].detach().unsqueeze(0).expand(N, -1)), dim=0)
                Lambda = self.intensity_rate_lambda(z_cat, z_prev.detach().repeat(2, 1),
                                                    torch.zeros(len(z_cat)).long() + k).detach()
                
                A_pred[it, u_it, :] = Lambda[:N]
                A_pred[it, v_it, :] = Lambda[N:]

                assert torch.sum(torch.isnan(A_pred[it])) == 0, (it, torch.sum(torch.isnan(A_pred[it])))
                # Compute the survival term (See page 3 in the paper)
                # we only need to compute the term for rows u_it and v_it in our matrix s to save time
                # because we will compute rank only for nodes u_it and v_it
                s1 = self.cond_density(time_bar[it], u_it, v_it)
                # cond_density(time_bar, u, v, N_nodes, Lambda_dict, time_keys)
                Surv[it, [u_it, v_it], :] = s1

                time_key = int(time_cur[it].item())
                idx = np.delete(np.arange(N), [u_it, v_it])  # nonevents for node u
                idx = np.concatenate((idx, idx + N))   # concat with nonevents for node v

                if len(self.time_keys) >= len(self.Lambda_dict):
                    # shift in time (remove the oldest record)
                    time_keys = np.array(self.time_keys)
                    time_keys[:-1] = time_keys[1:]
                    self.time_keys = list(time_keys[:-1])  # remove last
                    self.Lambda_dict[:-1] = self.Lambda_dict.clone()[1:]
                    self.Lambda_dict[-1] = 0

                self.Lambda_dict[len(self.time_keys)] = Lambda[idx].sum().detach()  # total intensity of non events for the current time step
                self.time_keys.append(time_key)

            # Once we made predictions for the training and test sample, we can update node embeddings
            z_all.append(z_new)
            # update S

            self.A = self.S
            S_batch.append(self.S.data.cpu().numpy())

            self.t_p += 1

        self.z = z_new  # update node embeddings

        # Batch update
        if self.opt:
            lambda_uv_t = self.intensity_rate_lambda(torch.stack(embeddings1, dim=0),
                                                        torch.stack(embeddings2, dim=0), event_types)
            non_events = len(embeddings_non1)
            n_types = 2
            lambda_uv_t_non_events = torch.zeros(non_events * n_types)
            embeddings_non1 = torch.stack(embeddings_non1, dim=0)
            embeddings_non2 = torch.stack(embeddings_non2, dim=0)
            idx = None
            empty_t = torch.zeros(non_events, dtype=torch.long)
            types_lst = torch.arange(n_types)
            for k in types_lst:
                if idx is None:
                    idx = np.arange(non_events)
                else:
                    idx += non_events
                lambda_uv_t_non_events[idx] = self.intensity_rate_lambda(embeddings_non1, embeddings_non2, empty_t + k)

            # update only once per batch
            for it, k in enumerate(event_types):
                u_it, v_it = u_all[it], v_all[it]
                self.update_S_A(u_it, v_it, k.item(), node_degrees[it], lambda_uv_t[it].item())
                # def update_S_A(A,S, u, v, k, node_degree, lambda_uv_t, N_nodes,n_assoc_types,node_degree_global)

        else:
            lambda_uv_t = torch.cat(lambda_uv_t)
            lambda_uv_t_non_events = torch.cat(lambda_uv_t_non_events)


        if len(reg) > 1:
            reg = [torch.stack(reg).mean()]

        return lambda_uv_t, lambda_uv_t_non_events / self.N_surv_samples, [A_pred, Surv], reg
 

In [6]:
import torch
from torch.utils.data import DataLoader

#initialise the A and z 
N_nodes = A_initial.shape[0]
if A_initial is not None and len(A_initial.shape) == 2:
    A_initial = A_initial[:, :, None]
n_assoc_types,n_event_types = 1, 3
n_relations = n_assoc_types + n_event_types

Adj_all = train_set.get_Adjacency()[0]

if not isinstance(Adj_all, list):
    Adj_all = [Adj_all]

node_degree_global = []
for rel, A in enumerate(Adj_all):
    node_degree_global.append(np.zeros(A.shape[0]))
    for u in range(A.shape[0]):
        node_degree_global[rel][u] = np.sum(A[u])

Adj_all = Adj_all[0]

# Instantiate the model
model = DyRep(
    node_embeddings=initial_embeddings,
    A_initial=A_initial,
    n_hidden=32,
    node_degree_global=node_degree_global,
    N_hops=2
)

initialize models node embeddings and adjacency matrices for 284 nodes
A_initial (284, 284, 1)


In [7]:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_set, batch_size=200, shuffle=False)
test_loader = DataLoader(test_set, batch_size=200, shuffle=False)

In [8]:
# import datetime
# from datetime import datetime,timezone
# for batch_idx, data in enumerate(test_loader):
#     lambda_uv_t, lambda_uv_t_non_events, [A_pred, Surv], reg = model(data)
#     print('lambda_uv_t', lambda_uv_t)
#     print('lambda_uv_t_non_events', lambda_uv_t_non_events)
#     print('A_pred', A_pred)
#     print('Surv', Surv)
#     print('reg', reg)

In [9]:
def MAR(A_pred, u, v, k, Survival_term):
    '''Computes mean average ranking for a batch of events'''
    ranks = []
    hits_10 = []
    N = len(A_pred)
    Survival_term = torch.exp(-Survival_term)
    A_pred *= Survival_term
    assert torch.sum(torch.isnan(A_pred)) == 0, (torch.sum(torch.isnan(A_pred)), Survival_term.min(), Survival_term.max())

    A_pred = A_pred.data.cpu().numpy()


    assert N == len(u) == len(v) == len(k), (N, len(u), len(v), len(k))
    for b in range(N):
        u_it, v_it = u[b].item(), v[b].item()
        assert u_it != v_it, (u_it, v_it, k[b])
        A = A_pred[b].squeeze()
        # remove same node
        idx1 = list(np.argsort(A[u_it])[::-1])
        idx1.remove(u_it)
        idx2 = list(np.argsort(A[v_it])[::-1])
        idx2.remove(v_it)
        rank1 = np.where(np.array(idx1) == v_it) # get nodes most likely connected to u[b] and find out the rank of v[b] among those nodes
        rank2 = np.where(np.array(idx2) == u_it)  # get nodes most likely connected to v[b] and find out the rank of u[b] among those nodes
        assert len(rank1) == len(rank2) == 1, (len(rank1), len(rank2))
        hits_10.append(np.mean([float(rank1[0] <= 9), float(rank2[0] <= 9)]))
        rank = np.mean([rank1[0], rank2[0]])
        assert isinstance(rank, np.float64), (rank, rank1, rank2, u_it, v_it, idx1, idx2)
        ranks.append(rank)
    return ranks, hits_10

In [10]:
import time
def test(model, n_test_batches=10, epoch=0):
    model.eval()
    loss = 0
    losses =[ [np.Inf, 0], [np.Inf, 0] ]
    n_samples = 0
    # Time slots with 10 days intervals as in the DyRep paper
    timeslots = [t.toordinal() for t in test_loader.dataset.TEST_TIMESLOTS]
    event_types = list(test_loader.dataset.event_types_num.keys()) #['comm', 'assoc']
    # sort it by k
    for event_t in test_loader.dataset.event_types_num:
        event_types[test_loader.dataset.event_types_num[event_t]] = event_t

    event_types += ['Com']

    mar, hits_10 = {}, {}
    for event_t in event_types:
        mar[event_t] = []
        hits_10[event_t] = []
        for c, slot in enumerate(timeslots):
            mar[event_t].append([])
            hits_10[event_t].append([])


    start = time.time()
    with torch.no_grad():
        
        import datetime
        #from datetime import datetime, timezone 
        for batch_idx, data in enumerate(test_loader):
            data[2] = data[2].float()
            data[4] = data[4].double()
            data[5] = data[5].double()
            output = model(data)
            loss += (-torch.sum(torch.log(output[0]) + 1e-10) + torch.sum(output[1])).item()
            for i in range(len(losses)):
                m1 = output[i].min()
                m2 = output[i].max()
                if m1 < losses[i][0]:
                    losses[i][0] = m1
                if m2 > losses[i][1]:
                    losses[i][1] = m2
            n_samples += 1
            A_pred, Survival_term = output[2]
            u, v, k = data[0], data[1], data[3]

            time_cur = data[5]
            m, h = MAR(A_pred, u, v, k, Survival_term=Survival_term)
            assert len(time_cur) == len(m) == len(h) == len(k)
            for t, m, h, k_ in zip(time_cur, m, h, k):
                d = datetime.datetime.fromtimestamp(t.item()).toordinal()
                event_t = event_types[k_.item()]
                for c, slot in enumerate(timeslots):
                    if d <= slot:
                        mar[event_t][c].append(m)
                        hits_10[event_t][c].append(h)
                        if k_ > 0:
                            mar['Com'][c].append(m)
                            hits_10['Com'][c].append(h)
                        if c > 0:
                            assert slot > timeslots[c-1] and d > timeslots[c-1], (d, slot, timeslots[c-1])
                        break

            if batch_idx % 10 == 0:
                print('test', batch_idx)

            if n_test_batches is not None and batch_idx >= n_test_batches - 1:
                break

    time_iter = time.time() - start

    print('\nTEST batch={}/{}, loss={:.3f}, psi={}, loss1 min/max={:.4f}/{:.4f}, '
          'loss2 min/max={:.4f}/{:.4f}, integral time stamps={}, sec/iter={:.4f}'.
          format(batch_idx + 1, len(test_loader), (loss / n_samples),
                 [model.psi[c].item() for c in range(len(model.psi))],
                 losses[0][0], losses[0][1], losses[1][0], losses[1][1],
                 len(model.Lambda_dict), time_iter / (batch_idx + 1)))

    # Report results for different time slots in the test set
    for c, slot in enumerate(timeslots):
        s = 'Slot {}: '.format(c)
        for event_t in event_types:
            sfx = '' if event_t == event_types[-1] else ', '
            if len(mar[event_t][c]) > 0:
                s += '{} ({} events): MAR={:.2f}+-{:.2f}, HITS_10={:.3f}+-{:.3f}'.\
                    format(event_t, len(mar[event_t][c]), np.mean(mar[event_t][c]), np.std(mar[event_t][c]),
                            np.mean(hits_10[event_t][c]), np.std(hits_10[event_t][c]))
            else:
                s += '{} (no events)'.format(event_t)
            s += sfx
        print(s)

    mar_all, hits_10_all = {}, {}
    for event_t in event_types:
        mar_all[event_t] = []
        hits_10_all[event_t] = []
        for c, slot in enumerate(timeslots):
            mar_all[event_t].extend(mar[event_t][c])
            hits_10_all[event_t].extend(hits_10[event_t][c])

    s = 'Epoch {}: results per event type for all test time slots: \n'.format(epoch)
    print(''.join(['-']*100))
    for event_t in event_types:
        if len(mar_all[event_t]) > 0:
            s += '====== {:10s}\t ({:7s} events): \tMAR={:.2f}+-{:.2f}\t HITS_10={:.3f}+-{:.3f}'.\
                format(event_t, str(len(mar_all[event_t])), np.mean(mar_all[event_t]), np.std(mar_all[event_t]),
                       np.mean(hits_10_all[event_t]), np.std(hits_10_all[event_t]))
        else:
            s += '====== {:10s}\t (no events)'.format(event_t)
        if event_t != event_types[-1]:
            s += '\n'
    print(s)
    print(''.join(['-'] * 100))

    return mar_all, hits_10_all, loss / n_samples

In [11]:
print('model', model)

model DyRep(
  (W_h): Linear(in_features=32, out_features=32, bias=True)
  (W_struct): Linear(in_features=32, out_features=32, bias=True)
  (W_rec): Linear(in_features=32, out_features=32, bias=True)
  (W_t): Linear(in_features=4, out_features=32, bias=True)
  (omega): ModuleList(
    (0-1): 2 x Linear(in_features=64, out_features=1, bias=True)
  )
)


In [12]:
params_main = []
for name, param in model.named_parameters():
    if param.requires_grad:
        params_main.append(param)

In [13]:
optimizer = torch.optim.Adam([{"params": params_main, "weight_decay":0}], lr=0.0002, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, '10', gamma=0.5)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3,4], gamma=0.5)

In [14]:
time_bar = np.zeros((N_nodes, 1)) + train_set.FIRST_DATE.timestamp()

In [15]:
import copy
def get_temporal_variables():
    variables = {}
    variables['time_bar'] = copy.deepcopy(time_bar)
    variables['node_degree_global'] = copy.deepcopy(node_degree_global)
    variables['time_keys'] = copy.deepcopy(model.time_keys)
    variables['z'] = model.z.clone()
    variables['S'] = model.S.clone()
    variables['A'] = model.A.clone()
    variables['Lambda_dict'] = model.Lambda_dict.clone()
    return variables

In [16]:
def set_temporal_variables(variables, model, train_loader, test_loader):
    time_bar = copy.deepcopy(variables['time_bar'])
    train_loader.dataset.time_bar = time_bar
    test_loader.dataset.time_bar = time_bar
    model.node_degree_global = copy.deepcopy(variables['node_degree_global'])
    model.time_keys = copy.deepcopy(variables['time_keys'])
    model.z = variables['z'].clone()
    model.S = variables['S'].clone()
    model.A = variables['A'].clone()
    model.Lambda_dict = variables['Lambda_dict'].clone()
    return time_bar

In [17]:
# import datetime
# from datetime import datetime, timezone 
# for batch, dat in enumerate(train_loader):
#     dat[2] = dat[2].float()
#     dat[4] = dat[4].double()
#     dat[5] = dat[5].double()

In [18]:
def initalize_state(dataset, keepS=False):
        '''Initializes node embeddings and the graph to the original state after every epoch'''
        Adj_all = dataset.get_Adjacency()[0]

        if not isinstance(Adj_all, list):
            Adj_all = [Adj_all]

        node_degree_global = []
        for rel, A in enumerate(Adj_all):
            node_degree_global.append(np.zeros(A.shape[0]))
            for u in range(A.shape[0]):
                node_degree_global[rel][u] = np.sum(A[u])

        Adj_all = Adj_all[0]
        print('Adj_all', Adj_all.shape, len(node_degree_global), node_degree_global[0].min(), node_degree_global[0].max())
        time_bar = np.zeros((dataset.N_nodes, 1)) + dataset.FIRST_DATE.timestamp()

        model.initialize(node_embeddings=initial_embeddings,
                         A_initial=Adj_all, keepS=keepS)  # train_loader.dataset.H_train

        return time_bar, node_degree_global

In [19]:
epoch_start = 1
# 记得改回5
epochs = 1
batch_start = 0
batch_size = 200
weight = 1
log_interval = 20
losses_events, losses_nonevents, losses_KL, losses_sum = [], [], [], []
test_MAR, test_HITS10, test_loss = [], [], []

In [20]:
# import datetime
# from datetime import datetime, timezone
# for batch_idx, data_batch in enumerate(train_loader):
#         # if batch_idx <= batch_start:
#         #     continue

#         model.train()
#         optimizer.zero_grad()

#         # Ensure the data is in the correct format
#         data_batch[2] = data_batch[2].float()
#         data_batch[4] = data_batch[4].double()
#         data_batch[5] = data_batch[5].double()

#         output = model(data_batch)

In [21]:
import datetime
for epoch in range(epoch_start, epochs + 1):
    # Setting the global time_bar for the datasets
    if not (epoch == epoch_start):
        # Reinitialize node embeddings and adjacency matrices, but keep the model parameters intact
        time_bar, node_degree_global = initalize_state(train_loader.dataset, keepS=epoch > 1)
        model.node_degree_global = node_degree_global
        
    # Setting the global time_bar for the datasets
    train_loader.dataset.time_bar = time_bar
    test_loader.dataset.time_bar = time_bar

    start = time.time()

    for batch_idx, data_batch in enumerate(train_loader):
        # if batch_idx <= batch_start:
        #   continue

        model.train()
        optimizer.zero_grad()

        # Ensure the data is in the correct format
        data_batch[2] = data_batch[2].float()
        data_batch[4] = data_batch[4].double()
        data_batch[5] = data_batch[5].double()

        output = model(data_batch)
        losses = [-torch.sum(torch.log(output[0]) + 1e-10), weight * torch.sum(output[1])]

        # KL losses (if there are additional items in output to process as losses)
        if len(output) > 3 and output[-1] is not None:
            losses.extend(output[-1])

        loss = torch.sum(torch.stack(losses)) / batch_size
        loss.backward()
        nn.utils.clip_grad_value_(model.parameters(), 100)
        optimizer.step()

        losses_events.append(losses[0].item())
        losses_nonevents.append(losses[1].item())
        losses_sum.append(loss.item())

        # Clamping psi to prevent numerical overflow
        model.psi.data = torch.clamp(model.psi.data, 1e-1, 1e+3)

        time_iter = time.time() - start

        # Detach computational graph to prevent unwanted backprop
        model.z = model.z.detach()
        model.S = model.S.detach()

        if (batch_idx + 1) % log_interval == 0 or batch_idx == len(train_loader) - 1:
            print(f'\nTRAIN epoch={epoch}/{epochs}, batch={batch_idx + 1}/{len(train_loader)}, '
                  f'sec/iter: {time_iter / (batch_idx + 1):.4f}, loss={loss.item():.3f}, '
                  f'loss components: {[l.item() for l in losses]}')

            # Save state before testing
            variables = get_temporal_variables()
            print('time', datetime.datetime.fromtimestamp(np.max(time_bar)))

            # Testing and collecting results
            result = test(model, n_test_batches=None if batch_idx == len(train_loader) - 1 else 10, epoch=epoch)
            test_MAR.append(np.mean(result[0]['Com']))
            test_HITS10.append(np.mean(result[1]['Com']))
            test_loss.append(result[2])

            # Restore state after testing
            time_bar = set_temporal_variables(variables, model, train_loader, test_loader)

    scheduler.step()

print('end time:', datetime.datetime.now())





TRAIN epoch=1/1, batch=20/59, sec/iter: 19.7709, loss=3.533, loss components: [300.87066650390625, 405.7894287109375]
time 2013-04-10 14:13:33


  hits_10.append(np.mean([float(rank1[0] <= 9), float(rank2[0] <= 9)]))


test 0

TEST batch=10/46, loss=717.698, psi=[0.503119170665741, 0.5036543607711792], loss1 min/max=0.0328/1.1677, loss2 min/max=0.0013/0.3706, integral time stamps=5000, sec/iter=37.6026
Slot 0: FollowEvent (2 events): MAR=57.00+-40.00, HITS_10=0.250+-0.250, ForkEvent (2 events): MAR=77.25+-22.25, HITS_10=0.000+-0.000, PushEvent (15 events): MAR=71.37+-39.49, HITS_10=0.000+-0.000, WatchEvent (26 events): MAR=108.37+-66.99, HITS_10=0.058+-0.160, IssuesEvent (2 events): MAR=71.25+-35.25, HITS_10=0.500+-0.000, IssueCommentEvent (4 events): MAR=74.75+-39.47, HITS_10=0.000+-0.000, PullRequestEvent (5 events): MAR=70.20+-30.89, HITS_10=0.000+-0.000, CommitCommentEvent (no events), Com (54 events): MAR=89.54+-56.51, HITS_10=0.046+-0.145
Slot 1: FollowEvent (109 events): MAR=114.84+-52.15, HITS_10=0.124+-0.216, ForkEvent (101 events): MAR=95.35+-42.80, HITS_10=0.099+-0.199, PushEvent (75 events): MAR=86.17+-51.39, HITS_10=0.040+-0.136, WatchEvent (930 events): MAR=85.40+-45.79, HITS_10=0.195+-