In [None]:
import sys
sys.path.append("..")

# Standard library. `os` and `OrderedDict` are used further down; importing them
# explicitly (before the star imports, so any project-level shadowing is unchanged)
# removes the reliance on them leaking in through `from ... import *`.
import os
import math
import heapq
import logging
import warnings
from collections import OrderedDict

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
%matplotlib inline

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler

from torchvision import transforms

from hierarchy import *
from processing import *
from label_utils import *
from data_reading import *

from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, classification_report
from tqdm import tqdm_notebook as tqdm 

# Show INFO-level records; the training helpers below report progress via logging.info.
logging.basicConfig(level=logging.INFO )
# Ask CUDA to run kernel launches synchronously so device errors surface at the failing call.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
# Print tensors with five decimal places.
torch.set_printoptions(precision=5)
# Explicitly leave flush-to-zero for denormal floats switched off.
torch.set_flush_denormal(False)
num_gpus = torch.cuda.device_count()
# Use the GPU only when CUDA is available and at least one device is visible; otherwise the CPU.
device = torch.device("cuda" if (torch.cuda.is_available() and num_gpus > 0) else "cpu")
# Bare expression: shows the selected device as the cell output.
device

In [None]:
# Length of every per-node weight vector; also passed to HierarchyUtils and LIBSVM_Reader.
n_components = 300
# Largest finite float32 value.
MAX_VALUE = torch.finfo(torch.float).max

In [None]:
def _check_data(y):

    class_labels = np.unique(y)
    num_tasks = len(class_labels)
    num_examples = y.shape[0]
    if num_tasks == 1:
        raise ValueError("The number of classes has to be greater than one.")
    elif num_tasks == 2:
        if 1 in class_labels and -1 in class_labels:
            num_tasks = 1
            class_labels = np.array([-1, 1])
        elif 1 in class_labels and 0 in class_labels:
            num_tasks = 1
            class_labels = np.array([0, 1])
        else:
            raise ValueError("Unable to decide postive label")

    lbin = LabelBinarizer(neg_label=-1, pos_label=1)
    lbin.fit(class_labels)
    y_bin = lbin.transform(y)
    return y_bin, class_labels, num_tasks

In [None]:
class DatasetIterator:
    """Loads a document collection plus its category hierarchy and serves rows by index.

    Attributes set for every format:
      * ``datafile`` - path given by the caller.
      * ``cat``      - ``HierarchyUtils(catfile, n_components, False)``.
      * ``df``       - frame with at least ``doc_id``, ``doc_vector`` and ``doc_labels``.
      * ``r_df``     - the reader's ``rev_df``.

    The ``"libsvm"`` format additionally keeps the reader (``lib_data``), the filtered
    frame with its original index (``d_df``), the matching feature rows (``all_sub_x``),
    the label matrix (``MLmatrix`` / ``ml_sub_matrix``) and the binarizer (``MLbin``).
    """

    _SUPPORTED_FORMATS = ("libsvm", "raw")

    def __init__(self, datafile, catfile, subsample, is_directed, fmt, split):
        """
        Parameters
        ----------
        datafile : path of the document file.
        catfile : path of the hierarchy file.
        subsample : forwarded to the reader.
        is_directed : accepted for interface compatibility; currently unused.
        fmt : ``"libsvm"`` or ``"raw"``.
        split : forwarded to ``LIBSVM_Reader`` (ignored for ``"raw"``).

        Raises
        ------
        ValueError
            If ``fmt`` is not a supported format. Previously the object was left
            without ``df`` and failed later with an unrelated AttributeError.
        """
        # Fail fast, before the hierarchy and data files are read.
        if fmt not in self._SUPPORTED_FORMATS:
            raise ValueError(
                "Unsupported fmt {!r}; expected one of {}".format(fmt, self._SUPPORTED_FORMATS)
            )

        self.datafile = datafile
        self.cat = HierarchyUtils(catfile, n_components, False)

        if fmt == "libsvm":
            self.lib_data = LIBSVM_Reader(self.datafile, True, n_components, subsample, split)
            self.d_df = self.lib_data.data_df
            # Keep only the first label of each document.
            self.d_df["doc_labels"] = self.d_df["doc_labels"].apply(lambda x: x[0])
            # Drop documents whose label is not a node of the hierarchy.
            self.d_df = self.d_df.loc[self.d_df['doc_labels'].isin(self.cat.N_all_nodes)]

            kept_rows = self.d_df.index
            self.all_sub_x = self.lib_data.all_x[kept_rows, :]

            # Map the surviving original row labels onto 0..n-1 so that `df`
            # can be addressed with positional dataset indices.
            orderer = {ind: i for i, ind in enumerate(kept_rows)}
            self.df = self.d_df.rename(index=orderer)

            self.MLmatrix = self.lib_data.label_matrix
            self.ml_sub_matrix = self.MLmatrix[kept_rows, ]
            self.MLbin = self.lib_data.binarizer
            self.r_df = self.lib_data.rev_df

        else:  # fmt == "raw"
            # TODO: add split here too
            self.raw_df = CSV_Reader(self.datafile, subsample)
            self.df = self.raw_df.data_df
            self.r_df = self.raw_df.rev_df

    def read_df(self, idx):
        """Return ``(doc_id, doc_vector, doc_labels, idx)`` for the row labelled ``idx``."""
        row = self.df
        return row.at[idx, "doc_id"], row.at[idx, "doc_vector"], row.at[idx, "doc_labels"], idx

    def __getitem__(self, _id):
        return self.read_df(_id)

    def __iter__(self):
        """Yield the ``read_df`` tuple of every row, in index order of ``df``."""
        for _id in self.df.index:
            yield self[_id]


class DatasetModule(Dataset):
    """``torch.utils.data.Dataset`` view over a :class:`DatasetIterator`.

    Each item is the triple ``(doc_id, doc_vector, doc_labels)``.
    """

    def __init__(self, root_location, cat_file, subsample, is_directed, fmt, split):
        """Build the underlying ``DatasetIterator``; all arguments are forwarded to it."""
        self.iter = DatasetIterator(root_location, cat_file, subsample, is_directed, fmt, split)

        self.small_mapper = self.iter.cat.node2id
        self.y_bin = generate_binary_yin(self.iter.cat.N_all_nodes)
        self.transform = transforms.Compose([transforms.ToTensor()])

    def encode_labels(self, labels):
        """Return ``(y_in, y)`` for one label.

        ``y_in`` is the ``y_bin`` row of the label's node id; ``y`` is the float32
        tensor of ``y_in > 0`` on the notebook device.
        """
        y_in = self.y_bin[self.small_mapper[int(labels)]]
        y = torch.as_tensor(y_in > 0, dtype=torch.float32, device=device)

        return y_in, y

    def encode_doc(self, doc_id_list):
        """Stack document vectors column-wise.

        Returns a float32 tensor of shape ``(n_components, len(doc_id_list))`` whose
        column ``i`` is the ``doc_vector`` of row ``doc_id_list[i]``.
        """
        doc_V = torch.empty((n_components, len(doc_id_list)), dtype=torch.float32, device=device)
        # The iterator exposes its frame as `df`; the former `data_df` attribute
        # does not exist on DatasetIterator and raised AttributeError.
        frame = self.iter.df
        for col, doc in enumerate(doc_id_list):
            vec = torch.as_tensor(frame.at[doc, "doc_vector"], dtype=torch.float32, device=device)
            # Flatten so that row vectors and column vectors are both accepted.
            doc_V[:, col] = vec.reshape(-1)
        return doc_V

    def __len__(self):
        return len(self.iter.df)

    def __load(self, idx):
        """Fetch row ``idx`` from the iterator and drop the trailing index element."""
        doc_id, doc_vec, dec_labels, _ = self.iter[idx]

        return doc_id, doc_vec, dec_labels

    def __getitem__(self, idx):
        return self.__load(idx)

In [None]:
class TestsetModule(Dataset):
    """Unlabelled dataset: each item is ``(doc_id, doc_vector)``."""

    def __init__(self, location, subsample, fmt, split):
        """
        Parameters
        ----------
        location : path of the document file.
        subsample : forwarded to the reader.
        fmt : ``"libsvm"`` or ``"raw"``.
        split : forwarded to ``LIBSVM_Reader`` (ignored for ``"raw"``).

        Raises
        ------
        ValueError
            If ``fmt`` is not a supported format (previously ``df`` was silently left unset).
        """
        self.location = location

        if fmt == "libsvm":
            self.lib_data = LIBSVM_Reader(self.location, True, n_components, subsample, split)
            self.df = self.lib_data.data_df

        elif fmt == "raw":
            # TODO: add split here too
            # The path is stored as `self.location`; the former `self.datafile`
            # was never assigned and raised AttributeError.
            self.raw_df = CSV_Reader(self.location, subsample)
            self.df = self.raw_df.data_df

        else:
            raise ValueError("Unsupported fmt {!r}; expected 'libsvm' or 'raw'".format(fmt))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.df.at[idx, "doc_id"], self.df.at[idx, "doc_vector"]

In [None]:
# Training documents in libsvm format (subsample=False, is_directed=True).
# NOTE(review): `split` is "test" here even though this is train.txt - confirm against LIBSVM_Reader.
train_data = DatasetModule("../lshtc-small/train.txt", "../lshtc-small/sub_cat_hier.txt", False, True, "libsvm", "test")

In [None]:
# Validation documents, loaded with the same hierarchy file and reader settings as the training set.
valid_data = DatasetModule("../lshtc-small/validation.txt", "../lshtc-small/sub_cat_hier.txt", False, True, "libsvm", "test")

In [None]:
# Test documents, loaded through DatasetModule, so every item is a 3-tuple (doc_id, doc_vector, doc_labels).
# NOTE(review): y_test_predict unpacks 2-tuples from its loader - confirm which dataset class is intended here.
test_data = DatasetModule("../lshtc-small/test.txt", "../lshtc-small/sub_cat_hier.txt", False, True, "libsvm", "test")

In [None]:
'''
batch size affects performance. higher batch size(99) vs. lower(40) for rcv1 didn't converge properly while training
'''

def my_collate(batch):
    """Merge a list of 4-element samples into one batch.

    Every sample is indexed as ``(label_id, doc_vec, y_in, y01)``.

    Returns
    -------
    list
        ``[label_ids, doc_vecs, y_in, y01]`` where ``label_ids`` is a plain list,
        ``doc_vecs`` is the transpose of the row-wise concatenation of the
        document vectors, and ``y_in`` / ``y01`` are concatenated along dim 0.
    """
    ids, vec_parts, signed_parts, binary_parts = [], [], [], []
    for sample in batch:
        ids.append(sample[0])
        vec_parts.append(sample[1])
        signed_parts.append(sample[2])
        binary_parts.append(sample[3])

    # Concatenate along the batch axis, then flip so that documents become columns.
    stacked_vecs = torch.cat(vec_parts, 0).t()
    signed_targets = torch.cat(signed_parts, 0)
    binary_targets = torch.cat(binary_parts, 0)

    return [ids, stacked_vecs, signed_targets, binary_targets]

# NOTE(review): `batch_size`, `split`, `train_indices`/`val_indices` and both samplers are
# computed here but none of the loaders below uses them (the loaders hard-code batch_size=1).
batch_size = 1
validation_split = 0.01
shuffle_dataset = True
random_seed= 1273

# Creating data indices for training and validation splits:
dataset_size = len(train_data)
indices = list(range(dataset_size))
# Number of examples that would be held out: floor(validation_split * dataset_size).
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    # Seed NumPy so the shuffled index order is reproducible.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# train_loader = DataLoader(train_data, batch_size= batch_size, shuffle=True)
# One document per batch; training and validation are reshuffled, the test order is kept.
train_loader = DataLoader(train_data, shuffle=True, batch_size=1)
validation_loader = DataLoader(valid_data, shuffle=True, batch_size=1)
# validation_loader = DataLoader(train_data, sampler=valid_sampler, batch_size=1)
test_loader = DataLoader(test_data, batch_size=1, shuffle = False)

In [None]:
# Number of training batches (train_loader uses batch_size=1).
len(train_loader)

In [None]:
# The two version queries below are evaluated but not displayed: only a cell's last
# expression is shown, and this cell ends with an assignment.
torch.version.cuda
torch.backends.cudnn.version()
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

In [None]:
class RRLoss(nn.Module):
    """Recursive-regularisation helper for one hierarchy node.

    ``forward`` gives the penalty that pulls a node's weights towards its parent's;
    ``non_leaf_update`` gives the averaged weight used for internal nodes.
    """

    def __init__(self, n_node):
        super().__init__()
        self.n_node = n_node
        # Hierarchy object of the global training set.
        self.H = train_data.iter.cat

    def forward(self, n_node_vec, pi_node_vec):
        """Return ``1e-6 * 0.5 * ||n_node_vec - pi_node_vec||_2 ** 2``.

        The parent vector is detached first, so no gradient reaches it.
        """
        frozen_parent = pi_node_vec.detach()
        distance = torch.norm(n_node_vec - frozen_parent, 2)
        return (1e-6) * 0.5 * distance ** 2

    def non_leaf_update(self, Cn, w_pi_n, w_c):
        """Return ``(w_pi_n + w_c) / (len(Cn) + 1)``.

        ``Cn`` is only used for its length; ``w_c`` is expected to already be the
        sum over the children.
        """
        scale = 1 / (len(Cn) + 1)
        return scale * (w_pi_n + w_c)
    
    
class Node(nn.Module):
    """Linear scorer of a single hierarchy node: an ``(n_components, 1)`` weight, no bias."""

    def __init__(self, node_n):
        super().__init__()
        self.node_n = node_n

        # Uninitialised float32 buffer, filled in place by Kaiming-normal (fan_out) init.
        raw_weight = torch.FloatTensor(n_components, 1)
        self.weight = nn.Parameter(nn.init.kaiming_normal_(raw_weight, mode="fan_out"))

    def forward(self, x_i):
        """Return the raw (un-squashed) score ``x_i @ weight``."""
        return x_i.mm(self.weight)
    
class HRLR(nn.Module):
    """Per-node logistic-regression classifier; its weights live in ``self.linear`` (a ``Node``)."""

    def __init__(self, node_n):
        super().__init__()
        self.linear = Node(node_n)

    def forward(self, x_i):
        """Return the raw score of ``x_i`` under this node's weight vector."""
        return self.linear(x_i)

    def compute_loss_leaf(self, yin, output):
        """Summed logistic loss ``sum(log(1 + exp(-yin * output)))``.

        Parameters
        ----------
        yin : target sign(s), +1 or -1 (a Python number or a tensor broadcastable to ``output``).
        output : raw scores returned by ``forward``.

        Notes
        -----
        ``F.softplus`` evaluates ``log(1 + exp(x))`` without overflowing for large ``x``.
        The earlier ``log1p(exp(x))`` overflowed to ``inf``; the fallback then swapped
        in a fresh constant tensor, which disconnected the loss from ``output`` (no
        gradient reached the weights) and replaced the whole batch by one scalar.
        """
        margin = -yin * output
        return F.softplus(margin).sum()

In [None]:
def non_leaf_update(non_leaf_, main_path):
    """Recompute and persist the weights of the internal (non-leaf) nodes.

    Every internal node that has a parent gets the weight returned by
    ``RRLoss.non_leaf_update``: ``(w_parent + sum(w_children)) / (len(children) + 1)``.
    Nodes are grouped by the length of the path from a root; all even-length
    groups are processed first, then all odd-length groups.

    Parameters
    ----------
    non_leaf_ : iterable of node labels for the internal nodes. Entries missing
        from ``child2parent_table`` are treated as roots and are not updated.
    main_path : format string with one ``{}`` placeholder for the node id, giving
        the checkpoint path of each node.

    Notes
    -----
    The previous implementation assigned the new weight into the dict returned by
    ``state_dict()``. That dict is a temporary, so the model - and therefore the
    saved checkpoint - was never changed. The weight is now loaded into the model
    before saving.
    """
    H = train_data.iter.cat
    node2id = H.node2id
    id2node = H.id2node

    def _refresh_node(k):
        """Recompute the weight of internal node id ``k`` and save its checkpoint."""
        rr = RRLoss(k)
        kth_path = main_path.format(k)
        _, kth_model, kth_opt, kth_loss = load_nth_model(kth_path)

        node_label = id2node[k]
        Cn = H.parent2child_table[node_label]
        pi = node2id[H.child2parent_table[node_label][0]]
        _, pi_model, _, _ = load_nth_model(main_path.format(pi))
        w_pi_n = pi_model.state_dict()['linear.weight']

        # Sum of the children's weight vectors.
        w_c = 0
        for x in Cn:
            _, cn_model, _, _ = load_nth_model(main_path.format(node2id[x]))
            w_c = w_c + cn_model.state_dict()['linear.weight']

        with torch.no_grad():
            new_weight = rr.non_leaf_update(Cn, w_pi_n, w_c)
            # load_state_dict copies into the live parameter, so the change is persisted.
            kth_model.load_state_dict({'linear.weight': new_weight})

        save_nth_models(k, kth_model, kth_opt, kth_loss, kth_path)

    root = [nl for nl in non_leaf_ if nl not in H.child2parent_table]
    non_leaf = set(non_leaf_) - set(root)

    # Group the internal node ids by the length of the root -> node path.
    odd_even = OrderedDict()
    for r in root:
        r_ = node2id[r]
        for n in non_leaf:
            n_ = node2id[n]
            oe = len(H.get_shortest_path(r_, n_)[0])
            if oe == 0:
                # An empty path is skipped.
                continue
            odd_even.setdefault(oe, []).append(n_)

    for parity, message in ((0, "at even level.."), (1, "at odd level..")):
        logging.info(message)
        for level, nodes in odd_even.items():
            if level % 2 != parity:
                continue
            for k in tqdm(nodes):
                _refresh_node(k)

In [None]:
def reset_model(node_n):
    '''
    Create a freshly initialised classifier and optimizer for one node.

    Returns ``(model, optimizer)``: an ``HRLR(node_n)`` moved to the notebook
    device and an L-BFGS optimizer (``max_iter=20``, ``history_size=20``) over
    its parameters.

    Original author's note: training performance is affected by large `n_tasks`
    size which basically increases the number of parameters to tune. as parameter
    dimension increases, weight decay also needs to be increased. what is the
    relation between param dim and weight decay?
    '''
    # The unused `d_dim` / `num_classes` locals were removed: computing `num_classes`
    # read `train_data.iter.MLbin`, which only the libsvm reader path defines.
    model = HRLR(node_n).to(device)

#     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    optimizer = torch.optim.LBFGS(params=model.parameters(), max_iter=20, history_size=20)

    return model, optimizer

In [None]:
def save_nth_models(n_node, model, opt, loss, path):
    """Write one node's checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'n_node'``, ``'model'`` (the model
    state dict), ``'optim'`` (the optimizer state dict) and ``'n_loss'``.
    """
    checkpoint = dict(
        n_node=n_node,
        model=model.state_dict(),
        optim=opt.state_dict(),
        n_loss=loss,
    )
    torch.save(checkpoint, path)

In [None]:
def load_nth_model(path):
    """Restore one node's checkpoint written by ``save_nth_models``.

    Returns
    -------
    tuple
        ``(n_node, model, optimizer, loss)`` where ``model`` and ``optimizer`` are
        created by ``reset_model(n_node)`` and then populated from the checkpoint.
    """
    # map_location makes a checkpoint that was saved on a GPU loadable on a
    # CPU-only machine (and vice versa) by remapping tensors onto `device`.
    checkpoint = torch.load(path, map_location=device)
    n_node = checkpoint['n_node']
    loss = checkpoint['n_loss']

    model, optimizer = reset_model(n_node)

    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optim'])

    return n_node, model, optimizer, loss

In [None]:
# Number of batches in one pass over the training loader; displayed as the cell output.
total_step = len(train_loader)
total_step

In [None]:
# Notebook-level aliases into the training hierarchy, used by the training and evaluation cells.
leaves = train_data.iter.cat.T_leaves  # iterated by model_train as the leaf nodes
internal = train_data.iter.cat.pi_parents  # passed to non_leaf_update as the non-leaf nodes
node2id = train_data.iter.cat.node2id
id2node = train_data.iter.cat.id2node
# Every node id known to the hierarchy.
classes = list(node2id.values())

In [None]:
# Total number of node ids; one checkpoint is created per id.
len(classes)

In [None]:
def create_node_models():
    """Make sure every node id in ``classes`` has a checkpoint file on disk.

    For each id, ``../persist_models_lshtc/_all_classes_<id>.pkl`` is created with a
    freshly initialised model/optimizer and an empty loss history, unless the file
    already exists. Existing checkpoints are left untouched.
    """
    pre_ = "../persist_models_lshtc/"
    # Hoisted out of the loop; exist_ok avoids the check-then-create race.
    os.makedirs(pre_, exist_ok=True)

    for c_ids in tqdm(classes):

        torch.cuda.empty_cache()

        path = "{}_all_classes_{}.pkl".format(pre_, c_ids)

        if not os.path.isfile(path):
            model, optimizer = reset_model(c_ids)
            # A new checkpoint starts with an empty loss history.
            save_nth_models(c_ids, model, optimizer, [], path)


# Create the initial checkpoint for every node that does not have one yet.
create_node_models()

In [None]:
# Sanity check: True when the leaf and internal node counts add up to the number of node ids.
len(internal) + len(leaves) == len(classes)

In [None]:
def model_train(data, main_path, trained_size_threshold=843537):
    """Train every leaf classifier, then refresh the internal nodes.

    For each leaf the objective per document is
    ``C * logistic_loss + RRLoss(leaf_weight, parent_weight)`` with ``C = 1e-3``,
    minimised with the leaf's own optimizer (one ``step`` per document).

    Parameters
    ----------
    data : iterable yielding ``(docid, doc_vec, doc_labels)``; one document per batch.
    main_path : format string with one ``{}`` placeholder for the node id, giving
        the checkpoint path of each node.
    trained_size_threshold : int, optional
        A leaf whose checkpoint file is larger than this many bytes is treated as
        already trained and skipped. Defaults to the previously hard-coded value.

    Returns
    -------
    dict
        Node id -> loss history of that leaf (loaded history plus, for leaves
        trained in this call, one detached loss per closure evaluation).

    Notes
    -----
    * The regulariser is computed from the live weight parameter. The previous code
      took it from ``state_dict()``, which returns detached tensors, so the
      regularisation term contributed nothing to the gradient.
    * Losses are stored detached so the history does not keep autograd graphs alive.
    * NOTE(review): assumes ``doc_vec`` is already a float tensor on the model's device.
    """
    per_node_loss = {}
    hierarchy = train_data.iter.cat
    C = 1e-3

    for orig in tqdm(leaves):

        n_id = node2id[orig]
        pi_n = node2id[hierarchy.child2parent_table[orig][0]]

        leaf_path = main_path.format(n_id)
        pi_path = main_path.format(pi_n)

        already_trained = os.path.getsize(leaf_path) > trained_size_threshold

        if already_trained:
            logging.info("Done: Class: {}, node_id: {}".format(orig, n_id))
        else:
            logging.info("Leaf: Class: {}, node_id: {}".format(orig, n_id))

        _, current_leaf_model, current_optimizer, per_node_loss[n_id] = load_nth_model(leaf_path)
        if already_trained:
            continue

        _, parent_model, _, _ = load_nth_model(pi_path)

        torch.cuda.empty_cache()

        # Loop-invariant pieces: the parent is fixed while this leaf trains.
        parent_weight = parent_model.state_dict()['linear.weight']
        regulariser = RRLoss(n_id)
        history = per_node_loss[n_id]

        for docid, doc_vec, doc_labels in data:

            # +1 when the document belongs to this leaf, -1 otherwise.
            y_in = 1 if orig == doc_labels else -1

            def closure():
                current_optimizer.zero_grad()

                # Live parameter, so the penalty takes part in backpropagation.
                L2 = regulariser.forward(current_leaf_model.linear.weight, parent_weight)

                output = current_leaf_model.forward(doc_vec)
                data_loss = current_leaf_model.compute_loss_leaf(y_in, output)

                loss = C * data_loss + L2
                loss.backward()
                history.append(loss.detach())
                return loss

            current_optimizer.step(closure)

        save_nth_models(n_id, current_leaf_model, current_optimizer, per_node_loss[n_id], leaf_path)

    non_leaf_update(internal, main_path)

    return per_node_loss

In [None]:
# Checkpoint path template; "{}" is replaced by the node id (same naming as create_node_models).
main_path = "../persist_models_lshtc/_all_classes_{}.pkl"
# Train all leaves on the training loader and keep the per-node loss histories.
nth_loss = model_train(train_loader, main_path)

In [None]:
# 5x5 grid of loss curves, one panel per node. zip() stops at the shorter input,
# so at most the first 25 entries of nth_loss are drawn.
f, ax = plt.subplots(nrows=5, ncols=5, figsize=(15,15))

for a, (key, value) in zip(ax.flatten(), nth_loss.items()):
    a.plot(value)
#     a.yaxis.set_ticklabels([])
    a.set_title("n_id: {}".format(key))

In [None]:
# Collect every node's weight vector from its checkpoint into one matrix:
# row `n_node` of m_W holds that node's squeezed 'linear.weight'.
with torch.no_grad():
    m_W = torch.zeros((len(classes), n_components))
    for n_node in tqdm(classes):

#         n_node = node2id[orig]
        node_path = main_path.format(n_node)

        _, current_node_model, _, _ = load_nth_model(node_path)
        w_n = current_node_model.state_dict()['linear.weight'].data
        m_W[n_node, :] = w_n.squeeze()
        torch.cuda.empty_cache()        
# Transpose to (n_components, len(classes)) so that doc_vec.mm(m_W) scores all nodes at once.
m_W = m_W.t()

In [None]:
# For every training and validation document, record the raw score that the
# document's own (true) node assigns to it.
# NOTE(review): the two loops are identical apart from the loader they iterate.
theta_th = []
for it, (docid, doc_vec, doclabels) in enumerate(tqdm(train_loader)):
        nid = node2id[int(doclabels)]
        m_W = m_W.to(device)
        score = (doc_vec.mm(m_W))
        score = (score.cpu().numpy())
        # Row 0 because each batch holds a single document.
        theta_th.append(score[0, nid])
        
for it, (docid, doc_vec, doclabels) in enumerate(tqdm(validation_loader)):
        nid = node2id[int(doclabels)]
        m_W = m_W.to(device)
        score = (doc_vec.mm(m_W))
        score = (score.cpu().numpy())
        theta_th.append(score[0, nid])

In [None]:
# One entry per training plus validation document.
len(theta_th)

In [None]:
def y_validate(data, m_W, theta_th):
    """Build the one-hot ground truth for ``data`` and print each document's true-node score.

    Parameters
    ----------
    data : sized iterable yielding ``(docid, doc_vec, doclabels)``, one document per batch.
    m_W : weight matrix with one column per node.
    theta_th : accepted but not used.

    Returns
    -------
    (y_pred, y_true_num)
        Two ``(len(data), len(classes))`` arrays. ``y_true_num`` is one-hot at the
        true node of every document. ``y_pred`` is returned all zeros because the
        thresholding step is currently disabled.
    """
    n_docs = len(data)
    y_true_num = np.zeros((n_docs, len(classes)))
    y_pred = np.zeros((n_docs, len(classes)))

    for row, (docid, doc_vec, doclabels) in enumerate(tqdm(data)):

        true_col = node2id[int(doclabels)]
        y_true_num[row, true_col] = 1

        m_W = m_W.to(device)

        # Scores of all nodes for this document, normalised by F.normalize.
        normalised = F.normalize(doc_vec.mm(m_W))
        per_node = normalised.cpu().numpy().squeeze().tolist()
        print(true_col, per_node[true_col])
        # Disabled: mark y_pred[row, j] = 1 for every node j whose score is below 0.03.

    return y_pred, y_true_num

In [None]:
# Run validation; prints one (node id, normalised score) line per validation document.
yp_, yt_ = y_validate(validation_loader, m_W, theta_th)

In [None]:
# Number of predicted positives (y_validate never fills its prediction array).
yp_.sum()

In [None]:
# Micro-averaged precision of the validation predictions against the one-hot ground truth.
precision_score(yt_, yp_, average="micro")

In [None]:
def y_test_predict(data, nth_model, threshold=500):
    """Predict a multi-hot label matrix for unlabelled documents.

    Parameters
    ----------
    data : iterable yielding ``(doc_id, doc_vec)`` pairs.
    nth_model : dict mapping node id -> model whose ``state_dict()`` has a
        ``'linear.weight'`` entry. Columns of the result follow the dict's
        iteration order.
    threshold : float, optional
        A node is predicted when its raw score is strictly greater than this
        value. Defaults to the previously hard-coded 500.

    Returns
    -------
    numpy.ndarray
        Integer 0/1 array with one row per document and one column per node.
    """
    # Flatten each node's weight and stack them as columns -> (n_features, n_nodes).
    # Unlike stack(...).squeeze(), this keeps the matrix 2-D when there is only
    # one node or one feature.
    columns = [
        n_model.state_dict()['linear.weight'].reshape(-1)
        for n_model in tqdm(nth_model.values())
    ]
    model_W = torch.stack(columns, dim=1)

    yp = []
    for _, doc_vec in tqdm(data):

        doc_vec = torch.as_tensor(doc_vec, dtype=torch.float32, device=device)

        score = doc_vec.mm(model_W)
        yp.append((score.cpu().numpy() > threshold).astype(int))

    return np.vstack(yp)

In [None]:
# NOTE(review): no notebook-level `nth_model` is defined in the cells shown (only `nth_loss`),
# and test_loader yields 3-tuples while y_test_predict unpacks 2 - verify before running.
y_test_p = y_test_predict(test_loader, nth_model)

In [None]:
# (number of test documents, number of nodes scored).
y_test_p.shape

In [None]:
# Read the gold labels and turn them into a multi-hot matrix shaped like y_test_p.
# NOTE(review): the file lives under ../DMOZ while the datasets above come from
# ../lshtc-small - confirm this is the matching gold-standard file.
y_test_true_dict = {}
# Only as many gold lines as there are predicted rows are used.
lim = y_test_p.shape[0]

with open("../DMOZ/DMOZGS", "r") as test_true:
    ans = test_true.readlines()
    
# Each line is a space-separated list of integer node labels for one document.
for i, line in enumerate(ans[:lim]):
    each_line = line.strip().split(' ')
    int_el = list(map(int ,each_line))
    y_test_true_dict[i] = int_el
    
y_test_true_np = np.zeros_like(y_test_p)

for i, true_node in tqdm(y_test_true_dict.items()):
    for t in true_node:
        # Translate the node label into its column index.
        class_id = node2id[t]
        y_test_true_np[i, class_id] = 1

In [None]:
# Total number of true (document, node) assignments in the test ground truth.
y_test_true_np.sum()

In [None]:
# Total number of predicted (document, node) assignments.
y_test_p.sum()

In [None]:
# NOTE(review): repeats the ground-truth total shown two cells above.
y_test_true_np.sum()

In [None]:
# Per-node precision / recall / F1 of the test predictions against the ground truth.
print(classification_report(y_test_true_np, y_test_p))