Skip to content
Permalink
Browse files

added support jodie library codes

  • Loading branch information...
srijankr
srijankr committed Jul 14, 2019
1 parent c59ad2b commit 7938561e81fb2d18efcca73db084c99bc1922204
Showing with 332 additions and 0 deletions.
  1. +107 −0 library_data.py
  2. +225 −0 library_models.py
@@ -0,0 +1,107 @@
'''
This is a supporting library for the loading the data.
Paper: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks. S. Kumar, X. Zhang, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2019.
'''

from __future__ import division
import numpy as np
import random
import sys
import operator
import copy
from collections import defaultdict
import os, re
import cPickle
import argparse
from sklearn.preprocessing import scale

# LOAD THE NETWORK
def load_network(args, time_scaling=True):
'''
This function loads the input network.
The network should be in the following format:
One line per interaction/edge.
Each line should be: user, item, timestamp, state label, array of features.
Timestamp should be in cardinal format (not in datetime).
State label should be 1 whenever the user state changes, 0 otherwise. If there are no state labels, use 0 for all interactions.
Feature list can be as long as desired. It should be atleast 1 dimensional. If there are no features, use 0 for all interactions.
'''

network = args.network
datapath = args.datapath

user_sequence = []
item_sequence = []
label_sequence = []
feature_sequence = []
timestamp_sequence = []
start_timestamp = None
y_true_labels = []

print "\n\n**** Loading %s network from file: %s ****" % (network, datapath)
f = open(datapath,"r")
f.readline()
for cnt, l in enumerate(f):
# FORMAT: user, item, timestamp, state label, feature list
ls = l.strip().split(",")
user_sequence.append(ls[0])
item_sequence.append(ls[1])
if start_timestamp is None:
start_timestamp = float(ls[2])
timestamp_sequence.append(float(ls[2]) - start_timestamp)
y_true_labels.append(int(ls[3])) # label = 1 at state change, 0 otherwise
feature_sequence.append(map(float,ls[4:]))
f.close()

user_sequence = np.array(user_sequence)
item_sequence = np.array(item_sequence)
timestamp_sequence = np.array(timestamp_sequence)

print "Formating item sequence"
nodeid = 0
item2id = {}
item_timedifference_sequence = []
item_current_timestamp = defaultdict(float)
for cnt, item in enumerate(item_sequence):
if item not in item2id:
item2id[item] = nodeid
nodeid += 1
timestamp = timestamp_sequence[cnt]
item_timedifference_sequence.append(timestamp - item_current_timestamp[item])
item_current_timestamp[item] = timestamp
num_items = len(item2id)
item_sequence_id = [item2id[item] for item in item_sequence]

print "Formating user sequence"
nodeid = 0
user2id = {}
user_timedifference_sequence = []
user_current_timestamp = defaultdict(float)
user_previous_itemid_sequence = []
user_latest_itemid = defaultdict(lambda: num_items)
for cnt, user in enumerate(user_sequence):
if user not in user2id:
user2id[user] = nodeid
nodeid += 1
timestamp = timestamp_sequence[cnt]
user_timedifference_sequence.append(timestamp - user_current_timestamp[user])
user_current_timestamp[user] = timestamp
user_previous_itemid_sequence.append(user_latest_itemid[user])
user_latest_itemid[user] = item2id[item_sequence[cnt]]
num_users = len(user2id)
user_sequence_id = [user2id[user] for user in user_sequence]

if time_scaling:
print "Scaling timestamps"
user_timedifference_sequence = scale(np.array(user_timedifference_sequence) + 1)
item_timedifference_sequence = scale(np.array(item_timedifference_sequence) + 1)

print "*** Network loading completed ***\n\n"
return [user2id, user_sequence_id, user_timedifference_sequence, user_previous_itemid_sequence, \
item2id, item_sequence_id, item_timedifference_sequence, \
timestamp_sequence, \
feature_sequence, \
y_true_labels]

@@ -0,0 +1,225 @@
'''
This is a supporting library with the code of the model.
Paper: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks. S. Kumar, X. Zhang, J. Leskovec. ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2019.
'''

from __future__ import division
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch import optim
import numpy as np
import math, random
import sys
from collections import defaultdict
import os
import cPickle
import gpustat
from itertools import chain
from tqdm import tqdm, trange, tqdm_notebook, tnrange
import getpass
import csv

PATH = "./"

try:
get_ipython
trange = tnrange
tqdm = tqdm_notebook
except NameError:
pass

total_reinitialization_count = 0

# A NORMALIZATION LAYER
class NormalLinear(nn.Linear):
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.normal_(0, stdv)
if self.bias is not None:
self.bias.data.normal_(0, stdv)


# THE JODIE MODULE
class JODIE(nn.Module):
def __init__(self, args, num_features, num_users, num_items):
super(JODIE,self).__init__()

print "*** Initializing the JODIE model ***"
self.modelname = args.model
self.embedding_dim = args.embedding_dim
self.num_users = num_users
self.num_items = num_items
self.user_static_embedding_size = num_users
self.item_static_embedding_size = num_items

print "Initializing user and item embeddings"
self.initial_user_embedding = nn.Parameter(torch.Tensor(args.embedding_dim))
self.initial_item_embedding = nn.Parameter(torch.Tensor(args.embedding_dim))

rnn_input_size_items = rnn_input_size_users = self.embedding_dim + 1 + num_features

print "Initializing user and item RNNs"
self.item_rnn = nn.RNNCell(rnn_input_size_users, self.embedding_dim)
self.user_rnn = nn.RNNCell(rnn_input_size_items, self.embedding_dim)

print "Initializing linear layers"
self.linear_layer1 = nn.Linear(self.embedding_dim, 50)
self.linear_layer2 = nn.Linear(50, 2)
self.prediction_layer = nn.Linear(self.user_static_embedding_size + self.item_static_embedding_size + self.embedding_dim * 2, self.item_static_embedding_size + self.embedding_dim)
self.embedding_layer = NormalLinear(1, self.embedding_dim)
print "*** JODIE initialization complete ***\n\n"

def forward(self, user_embeddings, item_embeddings, timediffs=None, features=None, select=None):
if select == 'item_update':
input1 = torch.cat([user_embeddings, timediffs, features], dim=1)
item_embedding_output = self.item_rnn(input1, item_embeddings)
return F.normalize(item_embedding_output)

elif select == 'user_update':
input2 = torch.cat([item_embeddings, timediffs, features], dim=1)
user_embedding_output = self.user_rnn(input2, user_embeddings)
return F.normalize(user_embedding_output)

elif select == 'project':
user_projected_embedding = self.context_convert(user_embeddings, timediffs, features)
#user_projected_embedding = torch.cat([input3, item_embeddings], dim=1)
return user_projected_embedding

def context_convert(self, embeddings, timediffs, features):
new_embeddings = embeddings * (1 + self.embedding_layer(timediffs))
return new_embeddings

def predict_label(self, user_embeddings):
X_out = nn.ReLU()(self.linear_layer1(user_embeddings))
X_out = self.linear_layer2(X_out)
return X_out

def predict_item_embedding(self, user_embeddings):
X_out = self.prediction_layer(user_embeddings)
return X_out


# INITIALIZE T-BATCH VARIABLES
def reinitialize_tbatches():
global current_tbatches_interactionids, current_tbatches_user, current_tbatches_item, current_tbatches_timestamp, current_tbatches_feature, current_tbatches_label, current_tbatches_previous_item
global tbatchid_user, tbatchid_item, current_tbatches_user_timediffs, current_tbatches_item_timediffs, current_tbatches_user_timediffs_next

# list of users of each tbatch up to now
current_tbatches_interactionids = defaultdict(list)
current_tbatches_user = defaultdict(list)
current_tbatches_item = defaultdict(list)
current_tbatches_timestamp = defaultdict(list)
current_tbatches_feature = defaultdict(list)
current_tbatches_label = defaultdict(list)
current_tbatches_previous_item = defaultdict(list)
current_tbatches_user_timediffs = defaultdict(list)
current_tbatches_item_timediffs = defaultdict(list)
current_tbatches_user_timediffs_next = defaultdict(list)

# the latest tbatch a user is in
tbatchid_user = defaultdict(lambda: -1)

# the latest tbatch a item is in
tbatchid_item = defaultdict(lambda: -1)

global total_reinitialization_count
total_reinitialization_count +=1


# CALCULATE LOSS FOR THE PREDICTED USER STATE
def calculate_state_prediction_loss(model, tbatch_interactionids, user_embeddings_time_series, y_true, loss_function):
# PREDCIT THE LABEL FROM THE USER DYNAMIC EMBEDDINGS
prob = model.predict_label(user_embeddings_time_series[tbatch_interactionids,:])
y = Variable(torch.LongTensor(y_true).cuda()[tbatch_interactionids])

loss = loss_function(prob, y)

return loss


# SAVE TRAINED MODEL TO DISK
def save_model(model, optimizer, args, epoch, user_embeddings, item_embeddings, train_end_idx, user_embeddings_time_series=None, item_embeddings_time_series=None, path=PATH):
print "*** Saving embeddings and model ***"
state = {
'user_embeddings': user_embeddings.data.cpu().numpy(),
'item_embeddings': item_embeddings.data.cpu().numpy(),
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer' : optimizer.state_dict(),
'train_end_idx': train_end_idx
}

if user_embeddings_time_series is not None:
state['user_embeddings_time_series'] = user_embeddings_time_series.data.cpu().numpy()
state['item_embeddings_time_series'] = item_embeddings_time_series.data.cpu().numpy()

directory = os.path.join(path, 'saved_models/%s' % args.network)
if not os.path.exists(directory):
os.makedirs(directory)

filename = os.path.join(directory, "checkpoint.%s.ep%d.tp%.1f.pth.tar" % (args.model, epoch, args.train_proportion))
torch.save(state, filename)
print "*** Saved embeddings and model to file: %s ***\n\n" % filename


# LOAD PREVIOUSLY TRAINED AND SAVED MODEL
def load_model(model, optimizer, args, epoch):
modelname = args.model
filename = PATH + "saved_models/%s/checkpoint.%s.ep%d.tp%.1f.pth.tar" % (args.network, modelname, epoch, args.train_proportion)
checkpoint = torch.load(filename)
print "Loading saved embeddings and model: %s" % filename
args.start_epoch = checkpoint['epoch']
user_embeddings = Variable(torch.from_numpy(checkpoint['user_embeddings']).cuda())
item_embeddings = Variable(torch.from_numpy(checkpoint['item_embeddings']).cuda())
try:
train_end_idx = checkpoint['train_end_idx']
except KeyError:
train_end_idx = None

try:
user_embeddings_time_series = Variable(torch.from_numpy(checkpoint['user_embeddings_time_series']).cuda())
item_embeddings_time_series = Variable(torch.from_numpy(checkpoint['item_embeddings_time_series']).cuda())
except:
user_embeddings_time_series = None
item_embeddings_time_series = None

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

return [model, optimizer, user_embeddings, item_embeddings, user_embeddings_time_series, item_embeddings_time_series, train_end_idx]


# SET USER AND ITEM EMBEDDINGS TO THE END OF THE TRAINING PERIOD
def set_embeddings_training_end(user_embeddings, item_embeddings, user_embeddings_time_series, item_embeddings_time_series, user_data_id, item_data_id, train_end_idx):
userid2lastidx = {}
for cnt, userid in enumerate(user_data_id[:train_end_idx]):
userid2lastidx[userid] = cnt
itemid2lastidx = {}
for cnt, itemid in enumerate(item_data_id[:train_end_idx]):
itemid2lastidx[itemid] = cnt

try:
embedding_dim = user_embeddings_time_series.size(1)
except:
embedding_dim = user_embeddings_time_series.shape[1]
for userid in userid2lastidx:
user_embeddings[userid, :embedding_dim] = user_embeddings_time_series[userid2lastidx[userid]]
for itemid in itemid2lastidx:
item_embeddings[itemid, :embedding_dim] = item_embeddings_time_series[itemid2lastidx[itemid]]

user_embeddings.detach_()
item_embeddings.detach_()


# SELECT THE GPU WITH MOST FREE MEMORY TO SCHEDULE JOB
def select_free_gpu():
mem = []
gpus = list(set([0,1,2,3]))
for i in gpus:
gpu_stats = gpustat.GPUStatCollection.new_query()
mem.append(gpu_stats.jsonify()["gpus"][i]["memory.used"])
return str(gpus[np.argmin(mem)])

0 comments on commit 7938561

Please sign in to comment.
You can’t perform that action at this time.