In [1]:
import os
from os.path import join
import numpy as np
import pandas as pd
from scipy import sparse
import pickle
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from utils import *
from tqdm import tqdm

In [2]:
class TextAndNearestNeighborsDataset(Dataset):
    """Document dataset yielding ``(doc_id, doc_bow, label, neighbors)`` tuples.

    Rows come from a pickled pandas DataFrame whose index holds document ids and whose
    columns include ``bow`` and ``label`` (objects exposing ``.toarray()`` and a 2-D
    ``.shape``; presumably 1 x n scipy sparse rows — TODO confirm) and ``neighbors``
    (a sequence of neighbor document ids).
    """

    # Datasets treated as multi-label (documents may carry several labels).
    MULTI_LABEL_DATASETS = ('reuters', 'rcv1', 'tmc')

    def __init__(self, dataset_name, data_dir, subset='train'):
        """
        Args:
            dataset_name (string): Dataset name; used both as the sub-directory of
                ``data_dir`` and as the file prefix of the pickle to load.
            data_dir (string): Root directory holding one sub-directory per dataset.
            subset (string): Split to load (train, test or cv). The file read is
                ``<data_dir>/<dataset_name>/<dataset_name>.<subset>.pkl``.
        """
        self.dataset_name = dataset_name
        self.data_dir = os.path.join(data_dir, dataset_name)
        self.subset = subset
        fn = '{}.{}.pkl'.format(dataset_name, subset)
        self.df = self.load_df(self.data_dir, fn)
        # Document id -> positional row index in self.df.
        self.docid2index = {docid: index for index, docid in enumerate(self.df.index)}
        self.single_label = dataset_name not in self.MULTI_LABEL_DATASETS

    def load_df(self, data_dir, df_file):
        """Load and return the pickled DataFrame ``df_file`` located in ``data_dir``.

        NOTE: unpickling can execute arbitrary code; only load trusted files.
        """
        return pd.read_pickle(os.path.join(data_dir, df_file))

    def __len__(self):
        """Number of documents (rows) in the loaded split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(doc_id, doc_bow, label, neighbors)`` for the row at position ``idx``.

        ``doc_bow`` and ``label`` are dense, squeezed float32 tensors; ``neighbors`` is an
        int64 tensor of neighbor document ids.
        """
        # Fetch the row once instead of re-indexing the frame for every column.
        row = self.df.iloc[idx]
        doc_id = row.name
        doc_bow = torch.from_numpy(row.bow.toarray().squeeze().astype(np.float32))
        label = torch.from_numpy(row.label.toarray().squeeze().astype(np.float32))
        neighbors = torch.LongTensor(row.neighbors)
        return (doc_id, doc_bow, label, neighbors)

    def num_classes(self):
        """Label dimensionality, read from the second axis of the first row's label."""
        return self.df.iloc[0].label.shape[1]

    def num_features(self):
        """Vocabulary size, read from the second axis of the first row's bag-of-words."""
        return self.df.iloc[0].bow.shape[1]

In [3]:
# Dataset selection. `dataset_name` drives both the file that is loaded
# (<DATA_DIR>/<dataset_name>/<dataset_name>.<subset>.pkl) and the model built below,
# so it must be the single source of truth rather than repeated as a literal.
dataset_name = 'reuters'
DATA_DIR = 'dataset/clean'
BATCH_SIZE = 100

train_set = TextAndNearestNeighborsDataset(dataset_name, DATA_DIR, 'train')
# NOTE(review): training batches are not shuffled — confirm this is intentional.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=False)
test_set = TextAndNearestNeighborsDataset(dataset_name, DATA_DIR, 'test')
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
def BFS_walk(df, start_node_id, num_steps, max_branch_factor=50):
    """Collect documents reachable from ``start_node_id`` by breadth-first search.

    Args:
        df: DataFrame indexed by document id with a ``neighbors`` column listing each
            document's neighbor ids.
        start_node_id: A single document id, or a list of ids that seed the queue.
        num_steps: Visit budget. The walk stops as soon as more than ``num_steps``
            distinct documents have been visited, so for non-negative ``num_steps`` at
            most ``num_steps + 1`` ids are returned (start node(s) included).
        max_branch_factor: Only the first ``max_branch_factor`` neighbors of each
            visited document are enqueued.

    Returns:
        List of visited document ids (start node(s) included), in arbitrary order.
    """
    from collections import deque  # O(1) popleft; list.pop(0) is O(len(queue))

    seeds = start_node_id if isinstance(start_node_id, list) else [start_node_id]
    queue = deque(seeds)
    visited_nodes = set()
    while queue:
        curr_node_id = queue.popleft()
        if curr_node_id in visited_nodes:
            continue
        # Look neighbors up in the `df` argument (not a global frame) so the walk
        # operates on whichever graph the caller passes in.
        queue.extend(df.loc[curr_node_id].neighbors[:max_branch_factor])
        visited_nodes.add(curr_node_id)
        if len(visited_nodes) > num_steps:
            break
    return list(visited_nodes)

# Neighborhood sampling: which graph traversal feeds get_neighbors, and its node budget.
walk_type = 'BFS'
max_nodes = 50
print("Walk type: {} with maximum nodes of: {}".format(walk_type, max_nodes))

# Each entry resolves its walker lazily, so only the selected walker's name has to exist.
# DFS_walk / Random_walk are not defined in this notebook; presumably they come from
# `from utils import *` (TODO confirm).
_walk_resolvers = {
    'BFS': lambda: BFS_walk,
    'DFS': lambda: DFS_walk,
    'Random': lambda: Random_walk,
}
if walk_type in _walk_resolvers:
    neighbor_sample_func = _walk_resolvers[walk_type]()
else:
    # No traversal: fall back to each document's stored neighbor list.
    neighbor_sample_func = None
    print("The model will only takes the immediate neighbors.")

def get_neighbors(ids, df, max_nodes, batch_size, traversal_func):
    """Build a dense 0/1 neighborhood matrix for a mini-batch of documents.

    Args:
        ids: Iterable of document ids for the batch; elements exposing ``.item()``
            (e.g. 0-d tensors from the DataLoader) are unwrapped to Python scalars.
        df: DataFrame indexed by (unique) document id with a ``neighbors`` column; its
            row order defines the matrix columns.
        max_nodes: Budget passed to ``traversal_func`` as its third argument.
        batch_size: Number of output rows (``>= len(ids)``; extra rows stay zero).
        traversal_func: Callable ``(df, node_id, max_nodes) -> iterable of doc ids``,
            or ``None`` to use each document's stored ``neighbors`` list directly.

    Returns:
        float32 tensor of shape ``(batch_size, len(df))``; entry ``[i, j]`` is 1 when the
        j-th document of ``df`` was returned for the i-th id in ``ids``.

    Raises:
        KeyError: if a returned document id is not present in ``df.index``.
    """
    connections = torch.zeros(batch_size, len(df), dtype=torch.float32)
    for row, node_id in enumerate(ids):
        node_id = node_id.item() if hasattr(node_id, 'item') else node_id
        if traversal_func is None:
            nn_ids = list(df.loc[node_id].neighbors)
        else:
            nn_ids = list(traversal_func(df, node_id, max_nodes))
        if not nn_ids:
            continue
        # Column positions come from `df` itself (not a global id->index map), so the
        # matrix always lines up with the frame that was passed in.
        cols = df.index.get_indexer(nn_ids)
        if (cols < 0).any():
            missing = [v for v, c in zip(nn_ids, cols) if c < 0]
            raise KeyError('document ids not found in df.index: {}'.format(missing))
        connections[row, torch.as_tensor(cols, dtype=torch.long)] = 1.0
    return connections

Walk type: BFS with maximum nodes of: 50


In [5]:
# Expose only physical GPU 1 to this process (it is then addressed as cuda:0). This only
# takes effect if CUDA has not yet been initialised in the running kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Model dimensions derived from the training split.
y_dim = train_set.num_classes()
num_bits = 8  # width of the latent code (h_to_mu output size)
# Use the dataset's own accessor rather than materialising item 0 as a dense tensor just
# to read its length. num_features() reads bow.shape[1], which equals the squeezed item
# length when each `bow` entry is a 1 x vocabulary row (assumed by both forms).
num_features = train_set.num_features()
num_nodes = len(train_set)


In [7]:
from model.EdgeReg import *

# Build the EdgeReg model for the selected dataset and move it to the chosen device;
# nn.Module.to returns the module itself, so `model` is displayed as the cell output.
model = EdgeReg(dataset_name, num_features, num_nodes, num_bits,
                dropoutProb=0.1, device=device).to(device)
model

EdgeReg(
  (encoder): Sequential(
    (0): Linear(in_features=10000, out_features=1000, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU(inplace)
    (4): Dropout(p=0.1)
  )
  (h_to_mu): Linear(in_features=1000, out_features=8, bias=True)
  (h_to_logvar): Sequential(
    (0): Linear(in_features=1000, out_features=8, bias=True)
    (1): Sigmoid()
  )
  (decoder): Sequential(
    (0): Linear(in_features=8, out_features=10000, bias=True)
    (1): LogSoftmax()
  )
  (nn_decoder): Sequential(
    (0): Linear(in_features=8, out_features=7763, bias=True)
    (1): LogSoftmax()
  )
)

In [8]:
optimizer = optim.Adam(model.parameters(), lr=0.01)
# The KL term is annealed linearly from 0 up to 1 over 1 / kl_step mini-batches.
kl_weight = 0.
kl_step = 1 / 5000.

best_precision = 0
best_precision_epoch = 0

# NOTE(review): edge_weight starts at 10, but the per-batch update clamps it with
# min(..., 1.), so 10 applies to the first mini-batch only and the weight is 1.0 from then
# on (the logged results were produced this way). Confirm whether a constant weight of 10
# or an annealed schedule towards 1 was intended.
edge_weight = 10.
edge_step = 1 / 1000.

for epoch in range(20):
    # get_binary_code (from model.EdgeReg, not visible here) may switch the model to eval
    # mode; re-enter training mode each epoch so dropout stays active during updates.
    model.train()
    epoch_losses = []
    # Labels and the loader's stored neighbor lists are not used by the training objective.
    for ids, xb, _, _ in tqdm(train_loader, ncols=50):
        xb = xb.to(device)

        # Sampled neighborhood (walk over the training graph) used as the edge-reconstruction target.
        nb = get_neighbors(ids, train_set.df, max_nodes, xb.size(0), neighbor_sample_func)
        nb = nb.to(device)

        logprob_w, logprob_nn, mu, logvar = model(xb)
        kl_loss = EdgeReg.calculate_KL_loss(mu, logvar)
        reconstr_loss = EdgeReg.compute_reconstr_loss(logprob_w, xb)
        nn_reconstr_loss = EdgeReg.compute_edge_reconstr_loss(logprob_nn, nb)

        loss = reconstr_loss + edge_weight * nn_reconstr_loss + kl_weight * kl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        kl_weight = min(kl_weight + kl_step, 1.)
        edge_weight = min(edge_weight + edge_step, 1.)
        epoch_losses.append(loss.item())

    with torch.no_grad():
        train_b, test_b, train_y, test_y = model.get_binary_code(train_loader, test_loader)
        retrieved_indices = retrieve_topk(test_b.to(device), train_b.to(device), topK=100)
        prec = compute_precision_at_k(retrieved_indices, test_y.to(device), train_y.to(device), topK=100).item()
        print("precision at 100: {:.4f}".format(prec))

        if prec > best_precision:
            best_precision = prec
            best_precision_epoch = epoch + 1

        print('{} epoch:{} loss:{:.4f} Best Precision:({}){:.3f}'.format(model.get_name(), epoch+1, np.mean(epoch_losses), best_precision_epoch, best_precision))


100%|█████████████| 78/78 [01:29<00:00,  1.14s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6618
node2hash epoch:1 loss:542.3105 Best Precision:(1)0.662


100%|█████████████| 78/78 [01:21<00:00,  1.05s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.7203
node2hash epoch:2 loss:405.1187 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:19<00:00,  1.02s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.7145
node2hash epoch:3 loss:382.8038 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:16<00:00,  1.02it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6988
node2hash epoch:4 loss:375.1534 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:15<00:00,  1.03it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6872
node2hash epoch:5 loss:371.6333 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:16<00:00,  1.02it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6904
node2hash epoch:6 loss:369.6670 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:11<00:00,  1.09it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.7112
node2hash epoch:7 loss:368.0412 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:15<00:00,  1.03it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6926
node2hash epoch:8 loss:367.9449 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:14<00:00,  1.04it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6991
node2hash epoch:9 loss:367.3656 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:15<00:00,  1.03it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6955
node2hash epoch:10 loss:367.2956 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:19<00:00,  1.02s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6745
node2hash epoch:11 loss:367.5660 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:19<00:00,  1.02s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6831
node2hash epoch:12 loss:367.8835 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:13<00:00,  1.06it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6809
node2hash epoch:13 loss:368.5842 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:13<00:00,  1.06it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6893
node2hash epoch:14 loss:368.9242 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:11<00:00,  1.09it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6784
node2hash epoch:15 loss:370.0969 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:18<00:00,  1.01s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6844
node2hash epoch:16 loss:369.5918 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:14<00:00,  1.04it/s]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6820
node2hash epoch:17 loss:369.3615 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:23<00:00,  1.07s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6721
node2hash epoch:18 loss:369.8176 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:25<00:00,  1.10s/it]
  0%|                      | 0/78 [00:00<?, ?it/s]

precision at 100: 0.6868
node2hash epoch:19 loss:370.2511 Best Precision:(2)0.720


100%|█████████████| 78/78 [01:27<00:00,  1.12s/it]
                          

precision at 100: 0.6871
node2hash epoch:20 loss:370.2210 Best Precision:(2)0.720


