# tf-ProtoNN 
### Implementation of ProtoNN in TensorFlow (single-GPU version) for large-scale multi-label learning

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np
import scipy.io as sio

from model import *
from preprocess import *

from trainer.single_gpu_train import Trainer
import cfgs.config_eurlex_with_preprocess as config

# Set the GPU count in the imported config to 1 (this notebook is the single-GPU version).
config.cfg.num_gpus = 1

### Display dataset and parameter stats

In [2]:
# Load the training split; its location is <dir_name>/train_data as given by the config.
path = os.path.join(config.cfg.dir_name, "train_data")
X, Y = data_loader.data_loader(path, config.cfg)

# Dataset summary. The two spaces after "=" are intentional: they reproduce the
# trailing space of the label plus print()'s default separator.
print("Data-Stats:")
print("Num instances =  %s" % X.shape[0])
print("Feature dimensionality =  %s" % X.shape[-1])
print("Label dimensionality =  %s" % Y.shape[-1])
# Y.nnz is the count of stored non-zeros (presumably Y is a sparse label matrix — confirm),
# so dividing by the column / row count gives the mean density per label / per point.
print("Mean pts per label =  %s" % (Y.nnz / Y.shape[1]))
print("Mean labels per pt =  %s" % (Y.nnz / Y.shape[0]))

# Model hyper-parameters taken from the config.
print("Param-stats:")
print("Projection-Dim: %d" % config.cfg.d)
print("Num Projections: %d" % config.cfg.m)

Data-Stats:
Num instances =  15539
Feature dimensionality =  5000
Label dimensionality =  3993
Mean pts per label =  20.6686701728
Mean labels per pt =  5.31115258382
Param-stats:
Projection-Dim: 250
Num Projections: 1000


### Preprocessing (PCA and clustering) for parameter initialization

In [3]:
# Start the wall-clock timer for the whole initialization stage.
tic = time.time()

# D: input feature dimensionality, d: projection dimension, m: number of prototypes
# (the latter two are read from the config).
D = X.shape[-1]
d = config.cfg.d
m = config.cfg.m

# PCA on the features; returns two values, kept as W0 and Wx
# (presumably the projection matrix and the projected data — confirm in preprocess.pca).
W0, Wx = pca.train_pca(X, d)

# k-means with m clusters on Wx using one GPU; the result is transposed before use.
B0 = clustering.train_kmeans(Wx, m, ngpu=1).T

# Initial prototype parameters derived from the labels, the PCA output and the k-means result.
Z0 = prototypes.get_prototypes(
    Y,
    Wx,
    B0,
    num_pts_per_cluster=config.cfg.num_pts_per_cluster,
)

t_elapsed = time.time() - tic
print("Time-taken for pre-training: %.4f" % t_elapsed)

# Save the initial parameters as a MATLAB .mat file inside the dataset directory.
path = os.path.join(config.cfg.dir_name, "init_params_faiss.mat")
sio.savemat(path, dict(W=W0, B=B0, Z=Z0))

Time-taken for pre-training: 72.6670


### Run training

In [4]:
# Build a Trainer from the shared config and run training; the value returned by
# train() is stored in `m`.
# NOTE(review): this rebinds `m`, which the preprocessing cell assigned from
# config.cfg.m — a distinct name would avoid the shadowing.
m = Trainer(config.cfg).train()

W variable created on gpu
B variable created on gpu
Z variable created on gpu
Instructions for updating:
Use the retry module or similar alternatives.


Instructions for updating:
Use the retry module or similar alternatives.
2019-01-10 02:57:21,945:INFO:TRAIN-BATCH Iter = 100, t = 6.36, Loss = 555.09, Prec@1: 0.7812, Prec@3: 0.6133, Prec@5: 0.5031
2019-01-10 02:57:22,041:INFO:VAL-ALL Iter = 100, t = 0.09, Loss = 563.51, Prec@1: 0.7445, Prec@3: 0.5794, Prec@5: 0.4712
2019-01-10 02:57:29,914:INFO:TRAIN-BATCH Iter = 200, t = 12.58, Loss = 541.68, Prec@1: 0.8242, Prec@3: 0.6732, Prec@5: 0.5531
2019-01-10 02:57:30,020:INFO:VAL-ALL Iter = 200, t = 0.10, Loss = 548.84, Prec@1: 0.7625, Prec@3: 0.5987, Prec@5: 0.4893
2019-01-10 02:57:37,892:INFO:TRAIN-BATCH Iter = 300, t = 18.78, Loss = 518.53, Prec@1: 0.8555, Prec@3: 0.6875, Prec@5: 0.5328
2019-01-10 02:57:38,004:INFO:VAL-ALL Iter = 300, t = 0.11, Loss = 540.76, Prec@1: 0.7761, Prec@3: 0.6098, Prec@5: 0.4964
2019-01-10 02:57:45,937:INFO:TRAIN-BATCH Iter = 400, t = 25.02, Loss = 496.10, Prec@1: 0.8906, Prec@3: 0.7135, Prec@5: 0.5719
2019-01-10 02:57:46,046:INFO:VAL-ALL Iter = 400, t = 0.10, Lo