## PyTorch Poincaré Halfspace embedding for the WordNet Mammals dataset

In [1]:
import timeit
import torch
from tqdm import tqdm
import numpy as np
import logging
from hype.sn import Embedding 
from hype import train
from hype.graph import load_edge_list, eval_reconstruction
from hype.rsgd import RiemannianSGD
from hype.Halfspace import HalfspaceManifold
import sys, os, random
import json
import torch.multiprocessing as mp
from hype.graph_dataset import BatchedDataset
os.environ["NUMEXPR_MAX_THREADS"] = '8'

def seed_everything(seed=1234):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU plus every CUDA device), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for any child process; the hash seed of the interpreter that is
    # already running is fixed at startup and is not affected by this line.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs rather than only the current one; silently ignored when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels between runs, so turn it
    # off alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False

# Short aliases for the floating-point precisions and devices used below.
d16, d32, d64 = torch.float16, torch.float32, torch.float64
cpu = torch.device("cpu")
gpu = torch.device(type='cuda', index=0)
device = cpu
opt_dtype = d64

# Map the selected precision to (tag embedded in FILE_NAME, default tensor
# type). Anything other than d16/d32 falls back to double precision.
dtype, _default_tensor_type = {
    d16: ("d16", 'torch.HalfTensor'),
    d32: ("d32", 'torch.FloatTensor'),
}.get(opt_dtype, ("d64", 'torch.DoubleTensor'))
torch.set_default_tensor_type(_default_tensor_type)

# Initial seeding of the NumPy and PyTorch global generators.
np.random.seed(42)
torch.manual_seed(42)

### Hyperparameters for the PyTorch Poincaré Halfspace model

In [2]:
## parameters; these are global in the notebook!
# --- manifold / model / dataset construction ---------------------------------
opt_maxnorm = 500000
opt_debug = False          # also selects DEBUG vs INFO logging
opt_dim = 2
opt_com_n = 1
opt_negs = 50
opt_eval_each = 20
opt_sparse = True
opt_ndproc = 1
opt_burnin = 20            # epochs below this index use the reduced burn-in lr
opt_dampening = 0.75
opt_neg_multiplier = 1.0
opt_burnin_multiplier = 0.01   # factor applied to opt_lr during burn-in
# --- optimisation -------------------------------------------------------------
opt_epochs = 1000
opt_batchsize = 32
opt_lr = 1.7
opt_dscale = 0.3           # model outputs are multiplied by this before the loss
# --- experiment selection -----------------------------------------------------
# opt_manifold = "Poincare"
opt_manifold = "Halfspace"
opt_task = 'mammals'
# Tag used to name the saved results file for this configuration.
FILE_NAME = (
    f"{opt_task}_lr_{opt_lr}_batch_{opt_batchsize}_{opt_epochs}"
    f"_torch_{dtype}_{opt_dscale}"
)

In [3]:
# Registry of manifold classes, looked up by the name stored in opt_manifold.
MANIFOLDS = dict(Halfspace=HalfspaceManifold)

class RES:
    """Container bundling the outcome of one training run for saving.

    Attributes are set directly from the constructor arguments:
        loss: ``loss`` converted to a float64 tensor on the CPU.
        eval_res: ``eval_res`` converted to a float64 tensor on the CPU.
        weight: ``weight`` stored as given, without conversion.
    """

    def __init__(self, loss, eval_res, weight):
        # Both numeric records share the same target dtype and device.
        as_cpu_f64 = dict(dtype=torch.float64, device=cpu)
        self.loss = torch.tensor(loss, **as_cpu_f64)
        self.eval_res = torch.tensor(eval_res, **as_cpu_f64)
        self.weight = weight

### Initializing logging and data loading

In [4]:
# Send bare log messages to stdout; DEBUG verbosity only when opt_debug is set.
log_level = logging.DEBUG if opt_debug else logging.INFO
logging.basicConfig(stream=sys.stdout, format='%(message)s', level=log_level)
log = logging.getLogger('MCF')

log.info('Using edge list dataloader')
# Path is relative to the notebook's working directory.
# NOTE(review): the meaning of the second positional argument (False) is
# defined by hype.graph.load_edge_list -- confirm against that function.
idx, objects, weights = load_edge_list("wordnet/mammal_closure.csv", False)

Using edge list dataloader


### Initializing model

In [5]:
def init_model(manifold, idx, objects, weights, sparse=True):
    """Build the embedding model and the batched dataset it is trained on.

    Reads the notebook-level hyperparameters ``opt_manifold``, ``opt_dim``,
    ``opt_com_n``, ``opt_negs``, ``opt_batchsize``, ``opt_ndproc``,
    ``opt_burnin`` and ``opt_dampening``.

    Args:
        manifold: Manifold instance handed to ``Embedding``.
        idx, objects, weights: Values returned by ``load_edge_list``; passed
            straight through to ``BatchedDataset``.
        sparse: Forwarded to ``Embedding`` as its ``sparse`` argument.

    Returns:
        Tuple ``(model, data, mname, conf)`` where ``mname`` is a descriptive
        model name and ``conf`` is an empty list.
    """
    conf = []
    # The previous format string '%s_dim%d%com_n' contained an accidental
    # '%c' conversion, which rendered opt_com_n as a control character.
    mname = '%s_dim%d_com_n%d' % (opt_manifold, opt_dim, opt_com_n)
    data = BatchedDataset(idx, objects, weights, opt_negs, opt_batchsize,
        opt_ndproc, opt_burnin > 0, opt_dampening)
    # One embedding row per object known to the dataset.
    model = Embedding(len(data.objects), opt_dim, manifold, sparse=sparse, com_n=opt_com_n)
    data.objects = objects
    return model, data, mname, conf

def adj_matrix(data):
    """Collect the adjacency sets of the edges yielded by ``data``.

    Args:
        data: Iterable of ``(inputs, targets)`` batches. Each row of
            ``inputs`` is indexed as ``row[0]`` (source) and ``row[1]``
            (neighbour), and both entries must support ``.item()``.

    Returns:
        Dict mapping each source id to the set of neighbour ids seen with it.
    """
    neighbours = {}
    for batch, _ in data:
        for pair in batch:
            src, dst = pair[0].item(), pair[1].item()
            neighbours.setdefault(src, set()).add(dst)
    return neighbours

### Training

In [6]:
def data_loader_lr(data, epoch, progress = False):
    """Pick the batch iterator and learning rate for one epoch.

    While ``epoch < opt_burnin`` the dataset's ``burnin`` flag is switched on
    and the learning rate is scaled down; afterwards the flag is off and the
    full ``opt_lr`` is used.

    Args:
        data: Dataset object; its ``burnin`` attribute is overwritten.
        epoch: Zero-based epoch index.
        progress: When true, wrap ``data`` in a ``tqdm`` progress bar.

    Returns:
        Tuple ``(loader_iter, lr)``.
    """
    in_burnin = epoch < opt_burnin
    data.burnin = in_burnin
    lr = opt_lr
    if in_burnin:
        # `_lr_multiplier` is an attribute attached to `train` elsewhere in
        # the notebook; it disappears whenever `train` is redefined, so fall
        # back to the configured burn-in multiplier instead of raising.
        multiplier = getattr(train, '_lr_multiplier', None)
        if multiplier is None:
            multiplier = opt_burnin_multiplier
        lr = opt_lr * multiplier
    loader_iter = tqdm(data) if progress else data
    return loader_iter, lr

In [7]:
def train(device, model, data, optimizer, progress=False):
    """Run ``opt_epochs`` epochs of training and return the per-epoch loss.

    NOTE: this function reuses the name of the ``train`` module imported from
    ``hype`` at the top of the notebook, so it shadows that import once the
    cell has run.

    Args:
        device: Device the batches are moved to before the forward pass.
        model: Callable model exposing ``lt.weight`` and ``loss(...)``.
        data: Sized iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer whose ``step`` accepts an ``lr`` keyword.
        progress: Forwarded to ``data_loader_lr`` to enable a progress bar.

    Returns:
        NumPy array of length ``opt_epochs`` holding the mean batch loss of
        each epoch.
    """
    epoch_loss = torch.Tensor(len(data))
    loss_history = np.zeros(opt_epochs)

    for epoch in range(opt_epochs):
        largest_weight_emb = round(
            torch.abs(model.lt.weight.data).max().item(), 6)
        print(largest_weight_emb, "is the largest absolute weight in the embedding")

        epoch_loss.fill_(0)
        t_start = timeit.default_timer()
        # Burn-in handling: yields the batch iterator and this epoch's lr.
        loader_iter, lr = data_loader_lr(data, epoch, progress=progress)

        for i_batch, (inputs, targets) in enumerate(loader_iter):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            preds = model(inputs) * opt_dscale
            loss = model.loss(preds, targets, size_average=True)
            loss.backward()
            optimizer.step(lr=lr)
            # Record the loss re-evaluated in float64 so the logged value does
            # not depend on the training precision.
            with torch.no_grad():
                loss = model.loss(preds.to(d64), targets, size_average=True)
                epoch_loss[i_batch] = loss.cpu().item()

        # Measured after the last batch so it covers the whole epoch and is
        # defined even when the loader yields nothing.
        elapsed = timeit.default_timer() - t_start
        loss_history[epoch] = torch.mean(epoch_loss).to(d64).item()
        # since only one thread is used:
        log.info('json_stats: {' f'"epoch": {epoch}, '
                 f'"elapsed": {elapsed}, ' f'"loss": {loss_history[epoch]}, ' '}')
    return loss_history


# Training embedding

In [8]:
# setup model
# Re-seed every generator so the run is repeatable from this cell onward; the
# same seed value appears in the "_seed1" suffix of the saved results file.
seed_everything(1)
# Instantiate the manifold class registered under opt_manifold.
manifold = MANIFOLDS[opt_manifold](
    debug=opt_debug, max_norm=opt_maxnorm, com_n=opt_com_n)
model, data, model_name, conf = init_model(
    manifold, idx, objects, weights, sparse=opt_sparse)
data.neg_multiplier = opt_neg_multiplier
# `train` here is the function defined in this notebook (it shadows the
# `hype.train` import). The attribute set on it is read by data_loader_lr
# during burn-in epochs, and is lost if the cell defining `train` is re-run.
train._lr_multiplier = opt_burnin_multiplier
model = model.to(device)
print('the total dimension', model.lt.weight.data.size(-1), 'com_n', opt_com_n)
print(">>>>>> # Tensor# | dtype is:", model.lt.weight.dtype,
      "| device is:", model.lt.weight.device)
# setup optimizer
optimizer = RiemannianSGD(model.optim_params(manifold), lr=opt_lr)
# get adjacency matrix
# adj_matrix makes one full pass over the dataset before training starts; the
# result is used later by the reconstruction evaluation.
adj = adj_matrix(data)
# begin training
# Wall-clock timing of the whole training run; `loss` receives the per-epoch
# loss array returned by train().
start_time = timeit.default_timer()
loss = train(device, model, data, optimizer, progress=False)
train_time = timeit.default_timer() - start_time
print("Total training time is:", train_time)


>>>>>> The size of embedding: Embedding(1180, 2, sparse=True)
the total dimension 2 com_n 1
>>>>>> # Tensor# | dtype is: torch.float64 | device is: cpu
1.0001 is the largest absolute weight in the embedding
json_stats: {"epoch": 0, "elapsed": 0.2786331550005343, "loss": 3.9317073897809878, }
1.003987 is the largest absolute weight in the embedding
json_stats: {"epoch": 1, "elapsed": 0.26102172900027654, "loss": 3.931372632110913, }
1.007866 is the largest absolute weight in the embedding
json_stats: {"epoch": 2, "elapsed": 0.3030020739997781, "loss": 3.9310328654829845, }
1.011775 is the largest absolute weight in the embedding
json_stats: {"epoch": 3, "elapsed": 0.2371376169994619, "loss": 3.9306918132824, }
1.015683 is the largest absolute weight in the embedding
json_stats: {"epoch": 4, "elapsed": 0.2785311260004164, "loss": 3.9303460968795503, }
1.019605 is the largest absolute weight in the embedding
json_stats: {"epoch": 5, "elapsed": 0.23592588200062892, "loss": 3.93000691107582

# Evaluate embedding

In [9]:
# Evaluate a copy of the trained embedding table against the adjacency sets.
model_weight = model.lt.weight.clone()
meanrank, maprank = eval_reconstruction(adj, model_weight,
                                        manifold.distance, workers=opt_ndproc)
# Summary statistics of manifold.pnorm over all embedding rows.
sqnorms = manifold.pnorm(model_weight)
sqnorm_min = sqnorms.min().item()
sqnorm_avg = sqnorms.mean().item()
sqnorm_max = sqnorms.max().item()
eval_res = [meanrank, maprank, sqnorm_min, sqnorm_avg, sqnorm_max, train_time]
RESULTS = RES(loss, eval_res, model_weight)
# Create the output directory first so a missing folder cannot discard the
# results of the full training run at save time.
results_dir = "./results_weights/"
os.makedirs(results_dir, exist_ok=True)
# "_seed1" matches the seed passed to seed_everything before training.
torch.save(RESULTS, results_dir + FILE_NAME + "_seed1" + ".pt")
log.info(
    'json_stats final test: \n{'
    f'"sqnorm_min": {round(sqnorm_min,6)}, '
    f'"sqnorm_avg": {round(sqnorm_avg,6)}, '
    f'"sqnorm_max": {round(sqnorm_max,6)}, \n'
    f'"mean_rank": {round(meanrank,6)}, '
    f'"map": {round(maprank,6)}, '
    '}'
)
print(model.lt.weight.data[0])

json_stats final test: 
{"sqnorm_min": 0.384906, "sqnorm_avg": 0.998911, "sqnorm_max": 1.0, 
"mean_rank": 1.548318, "map": 0.910576, }
tensor([1.9570e+00, 2.5072e-11])
