# params.py

In [1]:
import argparse
import torch

def parse_args():
    """Build the command-line parser that declares every hyper-parameter.

    Returns:
        The configured ``argparse.ArgumentParser``. Parsing is left to the
        caller.
    """
    # One row per option: (flag, default, converter, help text).
    # A converter of None means the option is declared without ``type=``.
    option_table = (
        ('--lr', 1.5e-3, float, 'learning rate'),
        ('--batch', 256, int, 'batch size'),
        ('--reg', 3e-2, float, 'weight decay regularizer'),
        ('--epoch', 200, int, 'number of epochs'),
        ('--decay', 0.96, float, 'weight decay rate'),
        ('--save_path', 'tem', None, 'file name to save model and training record'),
        ('--latdim', 32, int, 'embedding size'),
        ('--rank', 4, int, 'embedding size'),
        ('--memosize', 2, int, 'memory size'),
        ('--sampNum', 40, int, 'batch size for sampling'),
        ('--att_head', 2, int, 'number of attention heads'),
        ('--gnn_layer', 2, int, 'number of gnn layers'),
        ('--hyperNum', 128, int, 'number of hyper edges'),
        ('--trnNum', 10000, int, 'number of training instances per epoch'),
        ('--load_model', None, None, 'model name to load'),
        ('--shoot', 20, int, 'K of top k'),
        ('--data', 'yelp', str, 'name of dataset'),
        ('--mult', 100, float, 'multiplier for the result'),
        ('--keepRate', 0.5, float, 'rate for dropout'),
        ('--slot', 5, float, 'length of time slots'),
        ('--divSize', 10000, int, 'div size for smallTestEpoch'),
        ('--tstEpoch', 10, int, 'number of epoch to test while training'),
        ('--leaky', 0.5, float, 'slope for leaky relu'),
        ('--gcn_hops', 2, int, 'number of hops in gcn precessing'),
        ('--ssl_reg', 1e-4, float, 'reg weight for ssl loss'),
        ('--edgeSampRate', 0.5, float, 'Ratio of sampled edges'),
    )

    parser = argparse.ArgumentParser(description='Model Params')
    for flag, default, converter, help_text in option_table:
        extra = {} if converter is None else {'type': converter}
        parser.add_argument(flag, default=default, help=help_text, **extra)
    return parser
# Parse once at import time. parse_known_args() ignores flags this parser does
# not declare instead of aborting.
args, _ = parse_args().parse_known_args()
# Number of training steps between learning-rate decays. Clamped to 1 because a
# batch size larger than trnNum would floor to 0 and make the training loop's
# `i % args.decay_step` raise ZeroDivisionError.
args.decay_step = max(1, args.trnNum // args.batch)
# Prefer the GPU when one is visible to torch.
args.device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataloader.py

In [2]:
import scipy.sparse as sp
from scipy.sparse import coo_matrix, csr_matrix
import torch
import numpy as np
import pickle
#from Params import args

def transpose(mat):
  """Return the transpose of ``mat`` as a CSR sparse matrix."""
  as_coo = sp.coo_matrix(mat)
  flipped = as_coo.transpose()
  return csr_matrix(flipped)

def negSamp(temLabel, sampSize, nodeNum):
  """Draw ``sampSize`` negative node ids (with replacement) by rejection sampling.

  Args:
    temLabel: indexable label vector; an entry equal to 0 marks a negative.
    sampSize: number of negatives to draw.
    nodeNum: candidate ids are drawn uniformly from ``range(nodeNum)``.

  Returns:
    A list of ``sampSize`` ids ``j`` such that ``temLabel[j] == 0``.

  Raises:
    ValueError: if negatives are requested but none of the first ``nodeNum``
      labels is 0, in which case the rejection loop could never finish.
  """
  if sampSize > 0 and not np.any(np.asarray(temLabel)[:nodeNum] == 0):
    raise ValueError('negSamp: no zero-label node available to sample from')

  negset = []
  while len(negset) < sampSize:
    rdmItm = np.random.choice(nodeNum)
    # Keep the draw only when it is a true negative; otherwise redraw.
    if temLabel[rdmItm] == 0:
      negset.append(rdmItm)
  return negset

def transToLsts(mat, mask=False, norm=False):
  """Convert a sparse matrix into the pieces needed to build a torch sparse tensor.

  Args:
    mat: anything ``scipy.sparse.coo_matrix`` accepts. It is never modified.
    mask: if True, each stored value is zeroed independently with
      probability 0.5 (drawn from ``np.random``).
    norm: if True, each stored value at (r, c) is multiplied by
      ``1 / (sqrt(row_sum[r] + 1e-8) + 1e-8)`` and by
      ``1 / (sqrt(col_sum[c] + 1e-8) + 1e-8)``.

  Returns:
    ``(indices, data, shape)``: a ``(2, nnz)`` int64 tensor of (row, col)
    coordinates, a tensor with the ``nnz`` values, and a ``torch.Size``.
    A non-degenerate matrix that stores no entry yields one explicit zero at
    (0, 0) instead of empty tensors.
  """
  shape = torch.Size(mat.shape)
  mat = sp.coo_matrix(mat)
  rows, cols = mat.row, mat.col
  indices = torch.from_numpy(np.vstack((rows, cols)).astype(np.int64))
  data = mat.data

  if norm:
    # ravel() (not squeeze) keeps these 1-D even for a single row or column.
    rowD = 1 / (np.sqrt(np.asarray(mat.sum(axis=1)).ravel() + 1e-8) + 1e-8)
    colD = 1 / (np.sqrt(np.asarray(mat.sum(axis=0)).ravel() + 1e-8) + 1e-8)
    # Vectorised scaling into a NEW array: mat.data may alias the caller's
    # matrix (COO input), so it must not be written in place.
    scaled = data * rowD[rows] * colD[cols]
    if np.issubdtype(data.dtype, np.floating):
      scaled = scaled.astype(data.dtype, copy=False)
    data = scaled
  # half mask
  if mask:
    spMask = (np.random.uniform(size=data.shape) > 0.5) * 1.0
    data = data * spMask

  # indices is (2, nnz): the number of stored entries is its SECOND dimension.
  if indices.shape[1] == 0 and min(shape) > 0:
    indices = torch.zeros((2, 1), dtype=torch.int64)
    data = np.array([0.0], np.float32)

  data = torch.from_numpy(data)
  return indices, data, shape

class DataHandler:
  """Locate and load the pickled train/test interaction matrices of one dataset."""

  # Dataset names for which a data directory is defined.
  DATASETS = ('yelp', 'tmall', 'gowalla')

  def __init__(self):
    """Resolve the file paths for the dataset named by ``args.data``.

    Raises:
      ValueError: if ``args.data`` is not one of ``DATASETS``.
    """
    if args.data not in self.DATASETS:
      raise ValueError('unknown dataset %r, expected one of: %s'
                       % (args.data, ', '.join(self.DATASETS)))
    predir = '/content/drive/MyDrive/SHT/' + args.data + '/'
    self.predir = predir
    self.trnfile = predir + 'trnMat.pkl'
    self.tstfile = predir + 'tstMat.pkl'

  def LoadData(self):
    """Read both matrices and publish the dataset statistics on ``args``.

    Sets:
      self.trnMat: training matrix binarised to float32 (non-zero -> 1).
      self.tstLocs: one slot per test-matrix row holding the list of its
        stored columns, or None for a row without stored entries.
      self.tstUsrs: array of the rows that have at least one stored entry.
      args.edgeNum: number of stored training entries.
      args.user, args.item: the two dimensions of the training matrix.
    """
    # NOTE: pickle.load runs arbitrary code embedded in the file; only point
    # this at trusted data.
    with open(self.trnfile, 'rb') as fs:
      trnMat = (pickle.load(fs) != 0).astype(np.float32)
    # test set
    with open(self.tstfile, 'rb') as fs:
      # Uses .row/.col below, so the pickle is expected to hold a COO matrix.
      tstMat = pickle.load(fs)

    tstLocs = [None] * tstMat.shape[0]
    tstUsrs = set()
    for row, col in zip(tstMat.row, tstMat.col):
      if tstLocs[row] is None:
        tstLocs[row] = []
      tstLocs[row].append(col)
      tstUsrs.add(row)

    self.trnMat = trnMat
    self.tstLocs = tstLocs
    self.tstUsrs = np.array(list(tstUsrs))
    args.edgeNum = len(trnMat.data)
    args.user, args.item = self.trnMat.shape

# TimeLogger.py

In [3]:
import datetime

# Text of every message that log() was asked to keep, one message per line.
logmsg = str()
# Marker name -> datetime stored by marktime().
timemark = {}
# Whether log() keeps a message when its `save` argument is left as None.
saveDefault = False
def log(msg, save=None, oneline=False):
    """Print ``msg`` prefixed with the current timestamp and optionally keep it.

    Args:
        msg: the message to print.
        save: True appends the formatted line to the module-level ``logmsg``,
            False does not; None defers to the module-level ``saveDefault``.
        oneline: if True the line is terminated with a carriage return
            instead of a newline.
    """
    global logmsg
    now = datetime.datetime.now()
    tem = '%s: %s' % (now, msg)
    # An explicit True/False wins; only None falls back to the default.
    keep = saveDefault if save is None else save
    if keep:
        logmsg += tem + '\n'
    print(tem, end='\r' if oneline else '\n')

def marktime(marker):
    """Store the current time under ``marker`` in the module-level ``timemark``."""
    now = datetime.datetime.now()
    timemark[marker] = now

def SpentTime(marker):
    """Return the time elapsed since ``marker`` was stored in ``timemark``.

    Returns:
        A ``datetime.timedelta``, or False (after printing an error line)
        when ``marker`` is not present in ``timemark``.
    """
    if marker not in timemark:
        # Build the timestamp here: no `time` variable exists in this scope.
        print('%s: LOGGER ERROR, marker %s not found' % (datetime.datetime.now(), marker))
        return False
    return datetime.datetime.now() - timemark[marker]

def SpentTooLong(marker, day=0, hour=0, minute=0, second=0):
    """Tell whether at least the given duration has passed since ``marker`` was stored.

    Args:
        marker: key previously stored in ``timemark``.
        day, hour, minute, second: components of the duration to compare with.

    Returns:
        True when the elapsed time is greater than or equal to the duration;
        False otherwise, and also False (after printing an error line) when
        ``marker`` is not present in ``timemark``.
    """
    if marker not in timemark:
        # Build the timestamp here: no `time` variable exists in this scope.
        print('%s: LOGGER ERROR, marker %s not found' % (datetime.datetime.now(), marker))
        return False
    limit = datetime.timedelta(days=day, hours=hour, minutes=minute, seconds=second)
    return datetime.datetime.now() - timemark[marker] >= limit

# When this file is run directly, emit one log line with an empty message
# (only the timestamp prefix is printed).
if __name__ == '__main__':
	log('')

2022-09-08 03:21:15.323094: 


# Model.py

In [4]:
import torch
import torch.nn as nn
#from Params import args

# Fix the seeds of torch's and numpy's global random generators.
torch.manual_seed(666)
np.random.seed(666)

class FC(nn.Module):
  """Dense layer ``x @ W (+ b)`` followed by an optional activation.

  Args:
    inputdim: size of the last dimension of the input.
    outputdim: size of the last dimension of the output.
    bias: a zero-initialised bias vector is created only when this is
      exactly ``True``.
  """
  def __init__(self, inputdim, outputdim, bias = False):
    super(FC, self).__init__()
    # Allocate on the configured device rather than assuming CUDA is present;
    # falls back to 'cuda' when args carries no device attribute.
    device = getattr(args, 'device', 'cuda')
    self.W_fc = nn.Parameter(nn.init.xavier_normal_(torch.empty(inputdim, outputdim, device=device)))
    self.bias = bias is True
    if self.bias:
      self.bias_fc = nn.Parameter(torch.zeros(outputdim, device=device))

  def forward(self, ret, act = None):
    """Apply the layer.

    Args:
      ret: input whose last dimension is ``inputdim``.
      act: ``'leakyrelu'`` applies ``max(args.leaky * x, x)``, ``'sigmoid'``
        applies the logistic function; any other value applies nothing.
    """
    ret = ret @ self.W_fc
    if self.bias is True:
      ret = ret + self.bias_fc
    if act == 'leakyrelu':
      ret = torch.maximum(args.leaky * ret, ret)
    elif act == 'sigmoid':
      ret = torch.sigmoid(ret)
    return ret

class propagate(nn.Module):
  """One propagation step that routes node embeddings through the hyperedges and back.

  Owns two ``hyperNum x hyperNum`` layers, each used with a residual
  connection.
  """
  def __init__(self):
    super().__init__()
    self.fc1 = FC(args.hyperNum, args.hyperNum)
    self.fc2 = FC(args.hyperNum, args.hyperNum)

  def forward(self, V, lats, key, hyper):
    """Compute the next embedding from ``lats[-1]`` and append it to ``lats``.

    Nothing is returned; the result is delivered through ``lats``.

    Args:
      V: value projection, applied as ``lats[-1] @ V``.
      lats: list of embeddings; the last one is read and one is appended.
      key: per-head keys — presumably head_num * node_num * (latdim/head_num),
        as produced by SHT.prepareKey.
      hyper: hyperedge embeddings — presumably hyperNum * latdim.
    """
    heads = args.att_head
    head_dim = args.latdim // args.att_head

    # nodes -> per-head summary
    value = torch.reshape(lats[-1] @ V, [-1, heads, head_dim])
    value = torch.permute(value, (1, 2, 0))  # head_num * head_dim * node_num
    summary = value @ key                    # head_num * head_dim * head_dim

    # summary -> hyperedges, then two residual layers over the hyperedge axis
    hyper = torch.reshape(hyper, [-1, heads, head_dim])
    hyper = torch.permute(hyper, (1, 2, 0))  # head_num * head_dim * hyperNum
    edgeLat = torch.reshape(summary @ hyper, [args.latdim, -1])  # latdim * hyperNum
    edgeLat = self.fc1(edgeLat, act = 'leakyrelu') + edgeLat
    edgeLat = self.fc2(edgeLat, act = 'leakyrelu') + edgeLat

    # hyperedges -> nodes
    back = torch.reshape(torch.transpose(edgeLat, 0, 1) @ V, [-1, heads, head_dim])
    back = torch.permute(back, (1, 0, 2))    # head_num * hyperNum * head_dim
    back = hyper @ back                      # head_num * head_dim * head_dim
    newLat = key @ back                      # head_num * node_num * head_dim
    newLat = torch.reshape(torch.permute(newLat, (1, 0, 2)), [-1, args.latdim])
    lats.append(newLat)

class meta(nn.Module):
  """Transforms keys with a weight matrix and offset generated from the mean hyperedge embedding."""
  def __init__(self):
    super().__init__()
    # fc1 emits the latdim*latdim entries of the generated weight matrix,
    # fc2 emits the generated offset.
    self.fc1 = FC(args.latdim, args.latdim * args.latdim, bias = True)
    self.fc2 = FC(args.latdim, args.latdim, bias = True)

  def forward(self, hyper, key):
    """Return ``max(args.leaky * y, y)`` for ``y = key @ W + b``.

    ``W`` (latdim * latdim) and ``b`` (1 * latdim) are both produced from the
    mean of ``hyper`` over its first dimension.
    """
    pooled = hyper.mean(dim=0, keepdim=True)
    weight = self.fc1(pooled, act = None).reshape(args.latdim, args.latdim)
    offset = self.fc2(pooled, act = None)
    out = key @ weight + offset
    return torch.maximum(args.leaky * out, out)

class SHT(nn.Module):
  """Recommender combining GCN propagation, hypergraph propagation and a
  pairwise self-supervised loss over sampled edges.

  Args:
    adj: sparse adjacency used as ``adj @ item_embeddings``
      (presumably user * item).
    tpadj: sparse adjacency used as ``tpadj @ user_embeddings``
      (presumably item * user).
  """
  def __init__(self, adj, tpadj):
    super(SHT, self).__init__()

    initializer = nn.init.xavier_normal_
    # Create the parameters where the adjacency lives instead of hard-coding CUDA.
    device = adj.device
    self.adj = adj
    self.tpadj = tpadj

    self.uEmbed_ini = nn.Parameter(initializer(torch.empty(args.user, args.latdim, device=device))) #shape user * latdim
    self.iEmbed_ini = nn.Parameter(initializer(torch.empty(args.item, args.latdim, device=device))) #shape item * latdim
    self.uHyper = nn.Parameter(initializer(torch.empty(args.hyperNum, args.latdim, device=device))) #shape hyper num * latdim
    self.iHyper = nn.Parameter(initializer(torch.empty(args.hyperNum, args.latdim, device=device))) #shape hyper num * latdim

    # A single key projection and a single value projection are shared by
    # the user side and the item side.
    self.K = nn.Parameter(initializer(torch.empty(args.latdim, args.latdim, device=device))) # shape latdim * latdim
    self.V = nn.Parameter(initializer(torch.empty(args.latdim, args.latdim, device=device))) # shape latdim * latdim

    self.user_propagate = nn.ModuleList()
    self.item_propagate = nn.ModuleList()

    # Not read anywhere in this class; kept so the attribute still exists.
    self.reg = []

    for _ in range(args.gnn_layer):
      self.user_propagate.append(propagate())
      self.item_propagate.append(propagate())

    self.fc1_label = FC(2 * args.latdim, args.latdim, bias = True)
    self.fc2_label = FC(args.latdim, 1, bias = True)

    # A single meta network is shared by the user side and the item side.
    self.meta = meta()

  def prepareKey(self, nodeEmbed):
    """Project embeddings with ``K`` and split them into heads.

    Returns a tensor of shape head_num * node_num * (latdim/head_num).
    """
    key = torch.reshape(nodeEmbed @ self.K, [-1, args.att_head, args.latdim // args.att_head])
    key = torch.permute(key, (1,0,2))
    return key

  def label(self, usrKey, itmKey, uHyper, iHyper):
    """Score (user key, item key) pairs with a sigmoid output; returns a 1-D tensor."""
    ulat = self.meta(uHyper, usrKey) # batchsize * latdim
    ilat = self.meta(iHyper, itmKey) # batchsize * latdim
    lat = torch.cat([ulat, ilat], dim=-1) # batchsize * 2latdim
    lat = self.fc1_label(lat, act = 'leakyrelu') #batchsize * latdim
    lat = lat + ulat + ilat
    ret = self.fc2_label(lat, act = 'sigmoid')
    ret = torch.reshape(ret, [-1]) # flatten to one score per pair
    return ret

  def GCN(self, ulat, ilat, adj, tpadj):
    """Run ``args.gcn_hops`` hops of sparse propagation.

    Returns the per-side sums of the hop outputs; the input embeddings
    themselves are not included in the sums.
    """
    ulats = [ulat]
    ilats = [ilat]
    for _ in range(args.gcn_hops):
      # Both products read the previous hop, so compute them before appending.
      temulat = torch.sparse.mm(adj, ilats[-1])
      temilat = torch.sparse.mm(tpadj, ulats[-1])
      ulats.append(temulat)
      ilats.append(temilat)
    return sum(ulats[1:]), sum(ilats[1:])

  def Regularize(self, reg, method = 'L2'):
    """Return the sum of squared entries over all tensors in ``reg``.

    ``method`` is accepted for compatibility but not consulted. An empty
    ``reg`` yields the float ``0.0``.
    """
    return sum((torch.sum(torch.square(w)) for w in reg), 0.0)

  def _encode(self):
    """Shared encoder used by ``forward`` and ``forward_test``.

    Returns ``(uEmbed0, iEmbed0, uKey, iKey, ulat, ilat)`` where ``*Embed0``
    are the initial embeddings plus the GCN output, ``*Key`` their per-head
    keys, and ``*lat`` the sum of ``*Embed0`` and every hypergraph layer output.
    """
    uEmbed_gcn, iEmbed_gcn = self.GCN(self.uEmbed_ini, self.iEmbed_ini, self.adj, self.tpadj)
    uEmbed0 = self.uEmbed_ini + uEmbed_gcn
    iEmbed0 = self.iEmbed_ini + iEmbed_gcn

    uKey = self.prepareKey(uEmbed0)
    iKey = self.prepareKey(iEmbed0)

    ulats = [uEmbed0]
    ilats = [iEmbed0]
    for i in range(args.gnn_layer):
      # Each call appends its output to the list it is given.
      self.user_propagate[i](self.V, ulats, uKey, self.uHyper)
      self.item_propagate[i](self.V, ilats, iKey, self.iHyper)

    return uEmbed0, iEmbed0, uKey, iKey, sum(ulats), sum(ilats)

  def forward_test(self):
    """Return the final user and item embeddings (user * latdim, item * latdim)."""
    _, _, _, _, ulat, ilat = self._encode()
    return ulat, ilat

  def forward(self, uid, iid, edgeids):
    """Compute predictions, the self-supervised loss and the L2 penalty.

    Args:
      uid, iid: user / item row indices of the pairs to score.
      edgeids: positions into the stored entries of ``self.adj``. The sampled
        edges are split into a first and a second half that are compared
        pairwise, so an even count is expected.

    Returns:
      ``(preds, sslLoss, regularizer)``: the dot-product score of every
      (uid, iid) pair, the summed hinge loss over the edge pairs, and the sum
      of squared weights of the regularised parameters.
    """
    uEmbed0, iEmbed0, uKey, iKey, ulat, ilat = self._encode()

    pckUlat = torch.index_select(ulat, 0, uid.int())
    pckIlat = torch.index_select(ilat, 0, iid.int())
    preds = torch.sum(pckUlat * pckIlat, dim = -1)

    # End points of the sampled edges.
    idx = self.adj._indices()
    edges = edgeids.int()
    users = torch.index_select(idx[0], 0, edges).int()
    items = torch.index_select(idx[1], 0, edges).int()

    # Merge the heads back: node_num * latdim.
    uKey = torch.reshape(torch.permute(uKey, (1,0,2)), [-1, args.latdim])
    iKey = torch.reshape(torch.permute(iKey, (1,0,2)), [-1, args.latdim])
    userKey = torch.index_select(uKey, 0, users)
    itemKey = torch.index_select(iKey, 0, items)

    scores = self.label(userKey, itemKey, self.uHyper, self.iHyper)
    _preds = torch.sum(torch.index_select(uEmbed0, 0, users) * torch.index_select(iEmbed0, 0, items), dim = -1)

    halfNum = scores.shape[0] // 2
    scoreDiff = scores[:halfNum] - scores[halfNum:]
    predDiff = _preds[:halfNum] - _preds[halfNum:]
    margin = 1.0 - predDiff * args.mult * scoreDiff
    # Hinge at zero; the zero tensor inherits dtype and device from `margin`.
    sslLoss = torch.sum(torch.maximum(margin.new_zeros(1), margin))

    reg = [self.uEmbed_ini, self.iEmbed_ini, self.uHyper,
           self.iHyper, self.K, self.V, self.fc1_label.W_fc,
           self.fc2_label.W_fc, self.meta.fc1.W_fc,
           self.meta.fc2.W_fc]

    return preds, sslLoss, self.Regularize(reg, method = 'L2')

# SHT.py

In [5]:
import torch
import numpy as np

# from Model import HCCF
# from DataHandler import DataHandler, negSamp, transToLsts, transpose
# from Params import args
# from TimeLogger import log

class sht():
  """Training and evaluation driver: builds the adjacency tensors and the
  model, runs the epochs, and reads/writes checkpoints.

  Args:
    handler: object providing ``LoadData()`` and, afterwards, ``trnMat``,
      ``tstLocs`` and ``tstUsrs``. ``LoadData()`` is called here.
  """
  def __init__(self,handler):
    self.handler = handler
    self.handler.LoadData()
    # Falls back to 'cuda' when args carries no device attribute.
    self.device = getattr(args, 'device', 'cuda')

    adj = handler.trnMat
    self.adj_py = self._toSparseTensor(adj)
    self.tpAdj_py = self._toSparseTensor(transpose(adj))

    self.curepoch = 0
    self.metrics = dict()
    mets = ['Loss', 'preLoss', 'Recall', 'NDCG']
    for met in mets:
      self.metrics['Train' + met] = list()
      self.metrics['Test' + met] = list()

  def _toSparseTensor(self, mat):
    """Normalise ``mat`` and return it as a float32 sparse COO tensor on the device."""
    idx, data, shape = transToLsts(mat, norm=True)
    return torch.sparse_coo_tensor(idx, data, shape).to(torch.float32).to(self.device)

  def _buildOptimizer(self):
    """(Re)create the optimizer and scheduler for the CURRENT ``self.model``."""
    self.opt = torch.optim.Adam(params = self.model.parameters(), lr=args.lr)
    self.scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer = self.opt, gamma=args.decay)

  def preparemodel(self):
    """Create the model with its optimizer and print the trainable parameters."""
    self.model = SHT(self.adj_py, self.tpAdj_py).to(self.device)
    self._buildOptimizer()
    print('our training parameters:')
    for name, param in self.model.named_parameters():
      if param.requires_grad:
        print(name,param.shape,param.dtype)

  def sampleTrainBatch(self, batIds, labelMat):
    """Sample positive/negative items for a batch of users, plus random edges.

    Returns three int64 tensors on the device ``(users, items, edgeids)``.
    The first half of ``users``/``items`` pairs each user with sampled
    positives, the second half repeats the same users with sampled negatives.
    ``edgeids`` holds an even number of edge positions drawn with replacement.
    """
    temLabel = labelMat[batIds].toarray()
    users, posItems, negItems = [], [], []
    for i in range(len(batIds)):
      posset = np.reshape(np.argwhere(temLabel[i]!=0), [-1])
      sampNum = min(args.sampNum, len(posset))
      if sampNum == 0:
        # A user without any positive contributes no pair.
        continue
      poslocs = np.random.choice(posset, sampNum)
      neglocs = negSamp(temLabel[i], sampNum, args.item)
      users.extend([batIds[i]] * sampNum)
      posItems.extend(poslocs)
      negItems.extend(neglocs)
    uLocsa = users + users
    iLocsa = posItems + negItems

    edgeSampNum = int(args.edgeSampRate * args.edgeNum)
    if edgeSampNum % 2 == 1:
      edgeSampNum += 1
    edgeids = np.random.choice(args.edgeNum, edgeSampNum)

    # Integer tensors: routing ids through float32 corrupts values above 2**24.
    def asIds(seq):
      return torch.as_tensor(np.asarray(seq, dtype=np.int64), device=self.device)
    return asIds(uLocsa), asIds(iLocsa), asIds(edgeids)

  def trainEpoch(self):
    """Train on up to ``args.trnNum`` randomly chosen users.

    Returns a dict with the per-step means of 'Loss', 'preLoss', 'sslLoss'
    and 'regLoss' as Python floats.
    """
    sfIds = np.random.permutation(args.user)[:args.trnNum]
    num = len(sfIds)
    steps = int(np.ceil(num / args.batch))
    # Guard against a zero step count that would break the modulo below.
    decayStep = max(1, args.decay_step)
    epochLoss, epochPreLoss, epochsslLoss, epochregLoss = [0] * 4

    self.model.train()
    for i in range(steps):
      st = i * args.batch
      ed = min((i+1) * args.batch, num)
      batIds = sfIds[st: ed]

      uLocs, iLocs, edgeids = self.sampleTrainBatch(batIds, self.handler.trnMat)
      preds, sslLoss, regularize = self.model(uLocs, iLocs, edgeids)

      sampNum = uLocs.shape[0] // 2
      posPred = preds[:sampNum]
      negPred = preds[sampNum:sampNum * 2]
      preLoss = torch.sum(torch.maximum(preds.new_zeros(1), 1.0 - (posPred - negPred))) / args.batch
      regLoss = args.reg * regularize
      regsslLoss = args.ssl_reg * sslLoss
      loss = preLoss + regLoss + regsslLoss

      self.opt.zero_grad()
      loss.backward()
      self.opt.step()
      if i % decayStep == 0:
        self.scheduler.step()

      # Detach so the running totals do not keep every step's graph alive.
      epochLoss += loss.detach()
      epochPreLoss += preLoss.detach()
      epochsslLoss += regsslLoss.detach()
      epochregLoss += regLoss.detach()

    denom = max(steps, 1)
    ret = dict()
    ret['Loss'] = float(epochLoss) / denom
    ret['preLoss'] = float(epochPreLoss) / denom
    ret['sslLoss'] = float(epochsslLoss) / denom
    ret['regLoss'] = float(epochregLoss) / denom
    return ret

  def testEpoch(self):
    """Evaluate on every test user; returns the mean 'Recall' and 'NDCG'."""
    self.model.eval()
    with torch.no_grad():
      epochRecall, epochNdcg = [0] * 2
      ids = self.handler.tstUsrs
      num = len(ids)
      tstBat = args.batch
      steps = int(np.ceil(num / tstBat))
      ulat, ilat = self.model.forward_test()
      for i in range(steps):
        st = i * tstBat
        ed = min((i+1) * tstBat, num)
        batIds = ids[st: ed]
        trnPosMask = self.handler.trnMat[batIds].toarray()
        toplocs = self.tstPred(batIds, trnPosMask, ulat, ilat)
        recall, ndcg = self.calcRes(toplocs, self.handler.tstLocs, batIds)
        epochRecall += recall
        epochNdcg += ndcg
      ret = dict()
      ret['Recall'] = epochRecall / num
      ret['NDCG'] = epochNdcg / num
    return ret

  def tstPred(self, batIds, trnPosMask, ulat, ilat):
    """Return the ``args.shoot`` highest-scoring item ids per user in ``batIds``.

    Items flagged in ``trnPosMask`` are pushed to -1e8 so that training
    interactions are never recommended.
    """
    rows = torch.as_tensor(np.asarray(batIds, dtype=np.int64), device=ulat.device)
    pckUlat = torch.index_select(ulat, 0, rows)
    allPreds = pckUlat @ torch.transpose(ilat, 0, 1)
    allPreds = allPreds.cpu().detach().numpy() * (1 - trnPosMask) - trnPosMask * 1e8
    _, locs = torch.topk(torch.tensor(allPreds), args.shoot)
    return locs

  def calcRes(self, topLocs, tstLocs, batIds):
    """Sum recall and NDCG over the users of one batch.

    Args:
      topLocs: per-user ranked item ids, row i belonging to ``batIds[i]``.
      tstLocs: per-user list of held-out item ids, indexed by user id.
      batIds: the user ids of the batch.

    Returns:
      ``(recall_sum, ndcg_sum)``; the caller divides by the user count.
    """
    assert topLocs.shape[0] == len(batIds)
    allRecall = allNdcg = 0
    for i, usr in enumerate(batIds):
      # item id -> rank, so each held-out item is located in O(1).
      rankOf = {loc: rank for rank, loc in enumerate(topLocs[i].tolist())}
      temTstLocs = tstLocs[usr]
      tstNum = len(temTstLocs)
      maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.shoot))])
      recall = dcg = 0
      for val in temTstLocs:
        rank = rankOf.get(int(val))
        if rank is not None:
          recall += 1
          dcg += np.reciprocal(np.log2(rank + 2))
      allRecall += recall / tstNum
      allNdcg += dcg / maxDcg
    return allRecall, allNdcg

  def loadModel(self, loadPath):
    """Restore the model, epoch counter and metric history from a checkpoint.

    The optimizer and scheduler are rebuilt so they track the loaded model's
    parameters; their own state is not part of the checkpoint and starts fresh.
    """
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(loadPath)
    self.model = checkpoint['model']
    self.curepoch = checkpoint['epoch']+1
    self.metrics = checkpoint['metrics']
    # Without this the optimizer would keep updating the discarded model.
    self._buildOptimizer()

  def saveHistory(self):
    """Write the epoch counter, model and metric history to ./Model/<data>.pth."""
    import os
    savePath = r'./Model/' + args.data  + r'.pth'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(savePath), exist_ok=True)
    params = {
        'epoch' : self.curepoch,
        'model' : self.model,
        'metrics' : self.metrics,
    }
    torch.save(params, savePath)

  def run(self):
    """Train for ``args.epoch`` epochs, testing and checkpointing every
    ``args.tstEpoch`` epochs, then run a final test and save."""
    self.preparemodel()
    log('Model Prepared')
    if args.load_model is not None:
      self.loadModel(args.load_model)
      stloc = self.curepoch
    else:
      stloc = 0

    for ep in range(stloc, args.epoch):
      test = (ep % args.tstEpoch == 0)
      reses = self.trainEpoch()
      log(self.makePrint('Train', ep, reses, test))
      # Record the finished epoch BEFORE saving: loadModel resumes at
      # checkpoint['epoch'] + 1, so a stale value would repeat this epoch.
      self.curepoch = ep
      if test:
        reses = self.testEpoch()
        log(self.makePrint('Test', ep, reses, test))
        self.saveHistory()
      print()
    reses = self.testEpoch()
    log(self.makePrint('Test', args.epoch, reses, True))
    self.saveHistory()

  def makePrint(self, name, ep, reses, save):
    """Format the results of one epoch and optionally record them.

    When ``save`` is true, each value whose key ``name + metric`` exists in
    ``self.metrics`` is appended to that history.
    """
    ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
    for metric, val in reses.items():
      ret += '%s = %.4f, ' % (metric, val)
      tem = name + metric
      if save and tem in self.metrics:
        self.metrics[tem].append(val)
    ret = ret[:-2] + '  '
    return ret

if __name__ == '__main__':
  handler = DataHandler()
  # sht.__init__ calls handler.LoadData() itself; loading here as well would
  # read both pickles twice.
  model=sht(handler)
  model.run()

our training parameters:
uEmbed_ini torch.Size([29601, 32]) torch.float32
iEmbed_ini torch.Size([24734, 32]) torch.float32
uHyper torch.Size([128, 32]) torch.float32
iHyper torch.Size([128, 32]) torch.float32
K torch.Size([32, 32]) torch.float32
V torch.Size([32, 32]) torch.float32
user_propagate.0.fc1.W_fc torch.Size([128, 128]) torch.float32
user_propagate.0.fc2.W_fc torch.Size([128, 128]) torch.float32
user_propagate.1.fc1.W_fc torch.Size([128, 128]) torch.float32
user_propagate.1.fc2.W_fc torch.Size([128, 128]) torch.float32
item_propagate.0.fc1.W_fc torch.Size([128, 128]) torch.float32
item_propagate.0.fc2.W_fc torch.Size([128, 128]) torch.float32
item_propagate.1.fc1.W_fc torch.Size([128, 128]) torch.float32
item_propagate.1.fc2.W_fc torch.Size([128, 128]) torch.float32
fc1_label.W_fc torch.Size([64, 32]) torch.float32
fc1_label.bias_fc torch.Size([32]) torch.float32
fc2_label.W_fc torch.Size([32, 1]) torch.float32
fc2_label.bias_fc torch.Size([1]) torch.float32
meta.fc1.W_fc tor

# draft

In [6]:

# handler = DataHandler()
# handler.LoadData()
# adj = handler.trnMat
# idx, data, shape = transToLsts(adj, norm=True)
# adjp = torch.sparse.FloatTensor(idx, data, shape).to(torch.float32).to(args.device)
# idx, data, shape = transToLsts(transpose(adj), norm=True)
# tpadjp = torch.sparse.FloatTensor(idx, data, shape).to(torch.float32).to(args.device)
# model = SHT(adjp, tpadjp).to(args.device)
# opt = torch.optim.Adam(params = model.parameters(), lr=args.lr)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer = opt, gamma=args.decay)
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name,param.shape,param.dtype)
        
# def sampleTrainBatch(batIds, labelMat):
#   temLabel = labelMat[batIds].toarray()
#   batch = len(batIds)
#   temlen = batch * 2 * args.sampNum
#   uLocs = [None] * temlen
#   iLocs = [None] * temlen
#   cur = 0
#   for i in range(batch):
#       posset = np.reshape(np.argwhere(temLabel[i]!=0), [-1])
#       sampNum = min(args.sampNum, len(posset))
#       if sampNum == 0:
#           poslocs = [np.random.choice(args.item)]
#           neglocs = [poslocs[0]]
#       else:
#           poslocs = np.random.choice(posset, sampNum)
#           neglocs = negSamp(temLabel[i], sampNum, args.item)
#       for j in range(sampNum):
#           posloc = poslocs[j]
#           negloc = neglocs[j]
#           uLocs[cur] = uLocs[cur+temlen//2] = batIds[i]
#           iLocs[cur] = posloc
#           iLocs[cur+temlen//2] = negloc
#           cur += 1
#   uLocsa = uLocs[:cur] + uLocs[temlen//2: temlen//2 + cur]
#   iLocsa = iLocs[:cur] + iLocs[temlen//2: temlen//2 + cur]

#   edgeSampNum = int(args.edgeSampRate * args.edgeNum)
#   if edgeSampNum % 2 == 1:
#     edgeSampNum += 1
#   edgeids = np.random.choice(args.edgeNum, edgeSampNum)
  
#   return torch.Tensor(uLocsa).cuda(), torch.Tensor(iLocsa).cuda(), torch.Tensor(edgeids).cuda()
# def trainEpoch():
#   num = args.user
#   sfIds = np.random.permutation(num)[:args.trnNum]
#   epochLoss, epochPreLoss, epochsslLoss, epochregLoss = [0] * 4
#   num = len(sfIds)
#   steps = int(np.ceil(num / args.batch))

#   model.train()
#   for i in range(steps):
#     st = i * args.batch
#     ed = min((i+1) * args.batch, num)
#     batIds = sfIds[st: ed]

#     uLocs, iLocs, edgeids = sampleTrainBatch(batIds, handler.trnMat)
#     preds, sslLoss, regularize = model(uLocs, iLocs, edgeids)

#     sampNum = uLocs.shape[0] // 2
#     posPred = preds[:sampNum]
#     negPred = preds[sampNum:sampNum * 2]
#     preLoss = torch.sum(torch.maximum(torch.Tensor([0.0]).cuda(), 1.0 - (posPred - negPred))) / args.batch
#     regLoss = args.reg * regularize
#     regsslLoss = args.ssl_reg * sslLoss
#     loss = preLoss + regLoss + regsslLoss

#     opt.zero_grad()
#     loss.backward()
#     opt.step()
#     if i % args.decay_step == 0:
#       scheduler.step()

#     epochLoss += loss
#     epochPreLoss += preLoss
#     epochsslLoss += regsslLoss
#     epochregLoss += regLoss
#     #log('Step %d/%d: loss = %.2f, regLoss = %.2f, sslLoss = %.2f         ' % (i, steps, loss, regLoss, sslLoss), save=False, oneline=True)
#   ret = dict()
#   ret['Loss'] = epochLoss / steps
#   ret['preLoss'] = epochPreLoss / steps
#   ret['sslLoss'] = epochsslLoss / steps
#   ret['regLoss'] = epochregLoss / steps
#   return ret
# def calcRes(topLocs, tstLocs, batIds):
#   assert topLocs.shape[0] == len(batIds)
#   allRecall = allNdcg = 0
#   recallBig = 0
#   ndcgBig =0
#   for i in range(len(batIds)):
#     temTopLocs = list(topLocs[i])
#     temTstLocs = tstLocs[batIds[i]]
#     tstNum = len(temTstLocs)
#     maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.shoot))])
#     recall = dcg = 0
#     for val in temTstLocs:
#       if val in temTopLocs:
#         recall += 1
#         dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
#     recall = recall / tstNum
#     ndcg = dcg / maxDcg
#     allRecall += recall
#     allNdcg += ndcg
#   return allRecall, allNdcg

# def tstPred(batIds, trnPosMask, ulat, ilat):
#   pckUlat = torch.index_select(ulat, 0, torch.Tensor(batIds).int().cuda())
#   allPreds = pckUlat @ torch.transpose(ilat, 0, 1)
#   allPreds = allPreds.cpu().detach().numpy() * (1 - trnPosMask) - trnPosMask * 1e8
#   vals, locs = torch.topk(torch.tensor(allPreds), args.shoot)
#   return locs

# def testEpoch():
#   model.eval()
#   with torch.no_grad():
#     epochRecall, epochNdcg = [0] * 2
#     ids = handler.tstUsrs
#     num = len(ids)
#     tstBat = args.batch
#     steps = int(np.ceil(num / tstBat))
#     tstNum = 0
#     ulat, ilat = model.forward_test()
#     for i in range(steps):
#         st = i * tstBat
#         ed = min((i+1) * tstBat, num)
#         batIds = ids[st: ed]
#         trnPosMask = handler.trnMat[batIds].toarray()
#         toplocs = tstPred(batIds, trnPosMask, ulat, ilat)
#         recall, ndcg = calcRes(toplocs, handler.tstLocs, batIds)
#         epochRecall += recall
#         epochNdcg += ndcg
#         #log('Steps %d/%d: recall = %.2f, ndcg = %.2f          ' % (i, steps, recall, ndcg), save=False, oneline=False)
#     ret = dict()
#     ret['Recall'] = epochRecall / num
#     ret['NDCG'] = epochNdcg / num
#   return ret
# def loadModel(loadPath):
#     loadPath = loadPath
#     checkpoint = torch.load(loadPath)
#     model = checkpoint['model']
#     #curepoch = checkpoint['epoch']+1
#     # self.ulat = checkpoint['ulat']
#     # self.ilat = checkpoint['ilat']
#     #metrics = checkpoint['metrics']

# def saveHistory():

#     savePath = r'./Model/' + args.data  + r'.pth'
#     params = {
#         #'epoch' : curepoch,
#         'model' : model,
#         # 'ulat' : self.ulat,
#         # 'ilat' : self.ilat,
#         #'metrics' : metrics,
#     }
#     torch.save(params, savePath)

# def makePrint(name, ep, reses, save):
#   ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
#   for metric in reses:
#     val = reses[metric]
#     ret += '%s = %.4f, ' % (metric, val)
#     tem = name + metric
#     # if save and tem in self.metrics:
#     #   self.metrics[tem].append(val)
#   ret = ret[:-2] + '  '
#   return ret
# def run():
#   #preparemodel()
#   log('Model Prepared')
#   if args.load_model != None:
#     loadModel(args.load_model)
#     #stloc = self.curepoch
#   else:
#     stloc = 0

#   for ep in range(stloc, args.epoch):
#     test = (ep % args.tstEpoch == 0)
#     reses = trainEpoch()
#     #print(self.model.hyperULat_layers[0].fc1.W_fc.weight)
#     log(makePrint('Train', ep, reses, test))
#     if test:
#       reses = testEpoch()
#       log(makePrint('Test', ep, reses, test))
#     if ep % args.tstEpoch == 0:
#       saveHistory()
#     print()
#     curepoch = ep
#   reses = testEpoch()
#   log(makePrint('Test', args.epoch, reses, True))
#   saveHistory()
  
# run()