In [None]:
import sys
sys.path.append("..")

import math
import logging
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
%matplotlib inline

from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler

from scripts.src.hierarchy import *
from scripts.src.processing import *
from scripts.src.label_utils import *
from scripts.src.data_reading import *

from sklearn.metrics import f1_score, precision_score, recall_score, classification_report

logging.basicConfig(level=logging.INFO )

In [None]:
# Select the compute device once; the dataset classes below read this global.
num_gpus = torch.cuda.device_count()
if torch.cuda.is_available() and num_gpus > 0:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Size passed to LIBSVM_Reader, to HierarchyUtils (as the first weight dimension)
# and to the HRLR model further down.
n_components = 300

In [None]:
class DatasetIterator:
    """Training-side accessor that pairs a reader's frame with hierarchy vectors.

    Args:
        datafile: path handed to the reader.
        catfile: hierarchy file handed to ``HierarchyUtils``.
        subsample: forwarded unchanged to the reader.
        is_directed: forwarded unchanged to ``HierarchyUtils``.
        fmt: ``"libsvm"`` (``LIBSVM_Reader``) or ``"raw"`` (``CSV_Reader``).

    Raises:
        ValueError: if ``fmt`` is neither ``"libsvm"`` nor ``"raw"``.
    """

    def __init__(self, datafile, catfile, subsample, is_directed, fmt):
        self.datafile = datafile

        if fmt == "libsvm":
            self.lib_data = LIBSVM_Reader(self.datafile, True, n_components, subsample)
            self.df = self.lib_data.data_df
            self.rev_df = self.lib_data.rev_df
            # Same alias DevsetIterator exposes; DatasetModule.doc_vector looks
            # documents up through ``iter.doc_df``.
            self.doc_df = self.lib_data.data_df
        elif fmt == "raw":
            self.raw_df = CSV_Reader(self.datafile, subsample)
            self.df = self.raw_df.data_df
            self.rev_df = self.raw_df.rev_df
        else:
            # Fail here instead of with an AttributeError on ``self.df`` below.
            raise ValueError(f"Unsupported fmt {fmt!r}; expected 'libsvm' or 'raw'")

        self.cat = HierarchyUtils(catfile, [n_components, len(self.df)], is_directed, False)
        # ``wn`` is indexed as wn[0][label] and wn[1][label] in read_df.
        self.wn = self.cat.generate_vectors(device = device, neighbours = True)

    def read_df(self, idx):
        """Return ``(label, doc_id_list, idx, W, W_pi)`` for row ``idx``.

        Raises:
            IndexError: if no row has ``doc_id == idx``.
        """
        # NOTE(review): the position found through "doc_id" is not used by the
        # lookups below (they key on ``idx`` directly). The lookup is kept because
        # it raises IndexError for ids missing from the "doc_id" column.
        self.df.index[self.df["doc_id"] == idx][0]
        label = self.df.at[idx, "label_id"]
        W = self.wn[0][label]
        W_pi = self.wn[1][label]
        return label, self.df.at[idx, "doc_id_list"], idx, W, W_pi

    def __getitem__(self, _id):
        return self.read_df(_id)

    def __iter__(self):
        for _id in self.df.index:
            yield self[_id]

class DatasetModule(Dataset):
    """Torch ``Dataset`` over the training frame served by ``DatasetIterator``.

    Each item is ``(doc_vectors, label_id, row_id, (one_hot, signed), W, W_pi)``.
    ``one_hot`` is 0/1 and ``signed`` is -1/+1 over the label slots defined by the
    reader's ``binarizer.classes_``.
    """

    def __init__(self, root_location, cat_file, subsample, is_directed, fmt):
        self.iter = DatasetIterator(root_location, cat_file, subsample, is_directed, fmt)

        # Position in binarizer.classes_ == slot in the encoded vectors
        # (the first occurrence wins if a label were ever repeated).
        class_labels = self.iter.lib_data.binarizer.classes_
        mapper = {}
        for position, label in enumerate(list(class_labels)):
            mapper.setdefault(label, position)
        self.small_mapper = mapper

        self.lmbda = self.lambda_param()

    def lambda_param(self):
        """Return ``0.5 * ||w_n - w_pi||_2 ** 2`` over the hierarchy vectors."""
        # list2tensor is a project helper; presumably it stacks the values — TODO confirm.
        w_n = list2tensor(self.iter.wn[0].values())
        w_pi = list2tensor(self.iter.wn[1].values())

        norm2 = torch.norm(w_n-w_pi, 2)
        return 0.5*norm2**2

    def encode_labels(self, labels):
        """Encode one label id as ``(one_hot, signed)`` float32 vectors on ``device``.

        Raises:
            KeyError: if ``labels`` is not a key of ``small_mapper``.
        """
        slot = self.small_mapper[labels]
        n_labels = len(self.small_mapper)

        label_vector = torch.zeros((n_labels,), device = device, dtype = torch.float32)
        label_vector[slot] = 1.0

        # -1 everywhere except +1 at the true label.
        y_in = torch.full((n_labels,), -1.0, dtype=torch.float32, device = device)
        y_in[slot] = 1.0

        return label_vector, y_in

    def _document_frame(self):
        """Frame with "doc_id"/"doc_vector" columns used by ``doc_vector``."""
        # The training iterator may not expose ``doc_df``; the dev iterator sets
        # it to ``lib_data.data_df``, so use the same frame as the fallback.
        doc_df = getattr(self.iter, "doc_df", None)
        if doc_df is None:
            doc_df = self.iter.lib_data.data_df
        return doc_df

    def doc_vector(self, doc_ids):
        """Return ``list2tensor`` of the "doc_vector" cells for ``doc_ids``, in order.

        Raises:
            IndexError: if an id is not present in the "doc_id" column.
        """
        doc_df = self._document_frame()
        doc_id_column = doc_df["doc_id"]  # hoisted: identical for every id

        vectors = []
        for doc_id in doc_ids:
            matches = doc_df.index[doc_id_column == doc_id]
            vectors.append(doc_df.at[matches[0], "doc_vector"])

        return list2tensor(vectors)

    def __len__(self):
        return len(self.iter.df)

    def __load(self, idx):
        label_ids, doc_id_list, _id, W, W_pi = self.iter[idx]
        return self.doc_vector(doc_id_list), label_ids, _id, self.encode_labels(label_ids), W, W_pi

    def __getitem__(self, idx):
        return self.__load(idx)

In [None]:
class DevsetIterator:
    """Evaluation-side accessor returning a document's vector and label list.

    Args:
        datafile: path handed to the reader.
        catfile: hierarchy file handed to ``HierarchyUtils``.
        subsample: forwarded unchanged to the reader.
        is_directed: forwarded unchanged to ``HierarchyUtils``.
        fmt: ``"libsvm"`` (``LIBSVM_Reader``) or ``"raw"`` (``CSV_Reader``).

    Raises:
        ValueError: if ``fmt`` is neither ``"libsvm"`` nor ``"raw"``.
    """

    def __init__(self, datafile, catfile, subsample, is_directed, fmt):
        self.datafile = datafile

        if fmt == "libsvm":
            self.lib_data = LIBSVM_Reader(self.datafile, True, n_components, subsample)
            self.doc_df = self.lib_data.data_df
            self.df = self.lib_data.data_df
            self.MLmatrix = self.lib_data.label_matrix
            self.MLbin = self.lib_data.binarizer
        elif fmt == "raw":
            # Only ``df`` is set here; MLbin/MLmatrix/doc_df exist for "libsvm" only.
            self.csv = CSV_Reader(self.datafile, subsample)
            self.df = self.csv.data_df
        else:
            # Fail here instead of with an AttributeError on ``self.df`` below.
            raise ValueError(f"Unsupported fmt {fmt!r}; expected 'libsvm' or 'raw'")

        self.cat = HierarchyUtils(catfile, [n_components, len(self.df)], is_directed, False)
        self.wn = self.cat.generate_vectors(device = device, neighbours = True)

    def read_df(self, idx):
        """Return ``(doc_vector, doc_labels, row_label)`` for the row with ``doc_id == idx``.

        Raises:
            IndexError: if no row has ``doc_id == idx``.
        """
        row = self.df.index[self.df["doc_id"] == idx][0]
        return self.df.at[row, "doc_vector"], self.df.at[row, "doc_labels"], row

    def __getitem__(self, _id):
        return self.read_df(_id)

    def __iter__(self):
        for _id in self.df.index:
            yield self[_id]

class DevsetModule(Dataset):
    """Torch ``Dataset`` over evaluation documents with multi-hot label vectors.

    Each item is ``(doc_vector, doc_labels, row_label, multi_hot)``.
    """

    def __init__(self, root_location, cat_file, subsample, is_directed, fmt):
        self.iter = DevsetIterator(root_location, cat_file, subsample, is_directed, fmt)

        # Slot = rank of the label in sorted order. The membership test uses the
        # same key that is stored; the previous ``int(label)`` check looked up a
        # different key and raised on non-numeric labels.
        mapper = {}
        for position, label in enumerate(sorted(self.iter.MLbin.classes_)):
            if label not in mapper:
                mapper[label] = position

        self.small_mapper = mapper
        self.rev_mapper = {position: label for label, position in mapper.items()}

    def __len__(self):
        return len(self.iter.df)

    def encode_labels(self, doc_labels):
        """Return a float32 multi-hot vector for ``doc_labels`` on ``device``.

        Labels missing from ``small_mapper``, or whose slot falls outside the
        vector, are skipped.
        """
        # NOTE(review): the length comes from the notebook-global
        # ``train_data.small_mapper`` while the slots come from
        # ``self.small_mapper`` — confirm both label spaces are aligned.
        vec = torch.zeros((len(train_data.small_mapper),), device = device, dtype = torch.float32)
        for each_label in doc_labels:
            try:
                vec[self.small_mapper[each_label]] = 1
            except (KeyError, IndexError):
                # Unknown label or slot beyond the training label space:
                # best-effort skip, but no longer hides unrelated errors.
                continue
        return vec

    def __load(self, idx):
        doc_vec, doc_labels, doc_id = self.iter[idx]
        return doc_vec, doc_labels, doc_id, self.encode_labels(doc_labels)

    def __getitem__(self, idx):
        return self.__load(idx)

In [None]:
# Training dataset: LIBSVM-format file plus the topic hierarchy
# (positional args: subsample=True, is_directed=False, fmt="libsvm").
train_data = DatasetModule("../rcv1.tar/RCV1_1/rcv1.train.ltc.svm", "../rcv1.tar/RCV1_1/rcv1.topic.hierarchy", True, False, "libsvm")

In [None]:
# Evaluation dataset (subsample=0.1, is_directed=False, fmt="libsvm").
# NOTE(review): this reads the same "rcv1.train.ltc.svm" file as train_data, so the
# evaluation further down appears to run on training documents — confirm whether a
# held-out file was intended.
test_data = DevsetModule("../rcv1.tar/RCV1_1/rcv1.train.ltc.svm", "../rcv1.tar/RCV1_1/rcv1.topic.hierarchy", 0.1, False, "libsvm")

In [None]:
# "label_id" column of the training frame; not referenced again in the cells shown here.
train_node_list = train_data.iter.df["label_id"]

In [None]:
# Hold-out configuration for the sampler-based loaders built in this cell.
batch_size = 1
validation_split = .2
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(train_data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    # Seed the global NumPy RNG so the shuffled split is reproducible.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
# Both samplers hold indices drawn from range(len(train_data)), so the validation
# loader must wrap train_data too; applying them to test_data would index a
# dataset of a different length with train-sized indices.
validation_loader = DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler)

In [None]:
# Rebinds train_loader (replacing the sampler-based one from the previous cell) to a
# shuffled loader over all of train_data, and builds an unshuffled loader over test_data.
batch_size = 1
train_loader = DataLoader(train_data, shuffle = True, batch_size=batch_size)
valid_loader = DataLoader(test_data, batch_size=batch_size, shuffle = False)

In [None]:
# Peek at one training batch. Use the built-in next(): the Python-2 style
# ``.next()`` method is not available on current DataLoader iterators.
train_iter = iter(train_loader)
doc_vec, doc_labels, _id, yy, W, W_pi = next(train_iter)

In [None]:
# Pull the following batch; the second field is bound to label_id here
# (the training loop further down names the same field doc_labels).
doc_vec, label_id, _id, yy, W, W_pi =  next(train_iter)

In [None]:
# Dump the pieces of the batch fetched above, each followed by a 50-character rule.
print(
    doc_vec.squeeze().shape,
    "*" * 50,
    label_id,
    "~" * 50,
    yy[0][0],
    "-" * 50,
    yy[1][0],
    "_" * 50,
    f"{W.shape} {W_pi.shape}",
    "^" * 50,
    sep="\n",
)

In [None]:
# The first two expressions are evaluated but not displayed: a cell only shows its
# last expression, and here the last statement is an assignment.
torch.version.cuda
torch.backends.cudnn.version()
# Turn on cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True

In [None]:
class Node(nn.Module):
    """Linear scoring node with a weight matrix ``w`` and a one-element ``w_pi``.

    Args:
        weight_dims: sizes of ``w`` (at least two, as required by Xavier init).
    """

    def __init__(self, weight_dims):
        super().__init__()
        # torch.empty replaces the legacy ``torch.FloatTensor(*dims)`` constructor;
        # both tensors are fully overwritten by the initialisers below.
        w_n = torch.empty(*weight_dims, dtype=torch.float32)
        # The former ``random.uniform`` seed value was immediately overwritten by
        # ``normal_`` and needed a ``random`` import no cell provides, so it is gone.
        w_pi = torch.empty(1, dtype=torch.float32)

        nn.init.xavier_normal_(w_n, gain = nn.init.calculate_gain('relu'))
        nn.init.normal_(w_pi)

        self.w = nn.Parameter(w_n)
        self.w_pi = nn.Parameter(w_pi)

    def forward(self, x_i):
        """Return ``x_i @ w``."""
        return x_i.matmul(self.w)

    def L2_reg(self):
        """Return ``0.5 * ||w - w_pi||_2 ** 2`` (``w_pi`` broadcasts over ``w``)."""
        norm = torch.norm((self.w-self.w_pi), 2)
        return 0.5*norm**2
    
class HRLR(nn.Module):
    """Logistic-loss model built on a single ``Node``.

    ``forward`` stores ``-yin * score`` on the instance; ``compute_loss_leaf``
    turns that stored value into a scalar loss, so call ``forward`` first.
    """

    def __init__(self, n_components, num_tasks):
        super().__init__()
        self.linear = Node([n_components, num_tasks])

    def forward(self, yin, x_i):
        """Return (and cache as ``self.fwd_pass``) ``-yin * (x_i @ w)``."""
        score = self.linear(x_i)
        self.fwd_pass = - yin * score

        return self.fwd_pass

    def compute_loss_leaf(self):
        """Return the mean of ``log2(1 + exp(fwd_pass))``.

        Computed as ``softplus(z) / ln 2``, which is the same quantity but stays
        finite for large ``z`` where ``exp(z)`` overflows to inf.
        The ``Node.L2_reg`` term is not part of the returned value.
        """
        loss = F.softplus(self.fwd_pass) / math.log(2)
        return loss.mean()
    

In [None]:
def reset_model():
    """Build a fresh HRLR model together with its Adam optimiser.

    Reads the notebook globals ``train_data`` and ``n_components``.

    Returns:
        tuple: ``(model, optimizer, num_epochs)`` with ``num_epochs == 5``.
    """
    epochs = 5
    # One output column per row of the training frame.
    fresh_model = HRLR(n_components, len(train_data.iter.df))
    # (An LBFGS optimiser over the same parameters was left commented out here.)
    adam = torch.optim.Adam(fresh_model.parameters(), weight_decay=0.6, lr=0.001)
    return fresh_model, adam, epochs

In [None]:
# Fresh model, optimiser and epoch count for the training loop below.
model, optimizer, num_epochs = reset_model()

In [None]:
# Number of batches the train loader yields per epoch; shown as the cell output.
total_step = len(train_loader)
total_step

In [None]:
# Ask PyTorch to release its unused cached GPU memory before training starts.
torch.cuda.empty_cache()

In [None]:
# Metric history: "loss" is appended to by the training loop; "test_f1" is not
# written to in the cells shown here.
monitor = {metric: [] for metric in ("test_f1", "loss")}

In [None]:
# Train one document at a time, recording the loss of every optimisation step.
for epoch in range(num_epochs):
    for step, (doc_vec, doc_labels, _id, labelz, _, _) in enumerate(tqdm(train_loader)):

        # If the first parameter tensor has gone NaN, start over with a fresh
        # model/optimiser and skip this batch.
        first_param = list(model.parameters())[0]
        if torch.isnan(first_param.data.sum()):
            model, optimizer, num_epochs = reset_model()
            continue

        doc_vec = doc_vec.squeeze()
        # labelz is (one_hot, signed); take the signed (-1/+1) vector of the single item.
        yin = labelz[1][0]

        optimizer.zero_grad()
        output = model.forward(yin, doc_vec)
        loss = model.compute_loss_leaf()
        loss.backward()
        optimizer.step()
        monitor["loss"].append(loss.item())
            

In [None]:
# Per-step training loss recorded by the loop above; the trailing ';' hides the plot object's repr.
plt.plot(monitor["loss"]);

In [None]:
def gather_outputs(data, model):
    """Score every document in ``data`` and return true/predicted label matrices.

    Args:
        data: iterable of ``(doc_vec, label_ids, doc_id, y_true)`` items, where
            ``doc_vec`` and ``y_true`` are tensors.
        model: module whose first parameter is the weight matrix used for scoring.

    Returns:
        tuple: ``(y_true, y_pred)`` as integer NumPy arrays, one row per document.
    """
    logging.info("Evaluating ...")
    yy_t = []
    yy_p = []
    with torch.no_grad():
        # The weights do not change during evaluation, so fetch them once.
        W_params = list(model.parameters())[0].data.squeeze()

        for doc_vec, _label_ids, _doc_id, y_true in tqdm(data):
            doc_vec = doc_vec.squeeze().to(W_params.device)

            # .cpu() makes the NumPy conversion work when tensors live on the GPU.
            score = torch.matmul(doc_vec, W_params).detach().cpu().numpy()
            y_true = y_true.detach().cpu().numpy().astype(int)

            # Candidate thresholds: midpoints between consecutive sorted scores.
            ordered = np.sort(score)
            midpoints = (ordered[1:] + ordered[:-1]) / 2
            best_thresh, best_f1 = ordered[0], 0

            # NOTE(review): the threshold is picked per document using that
            # document's true labels, so the resulting scores are an oracle
            # upper bound rather than what a fixed threshold would achieve.
            for threshold in midpoints:
                y_pred = (score > threshold).astype(int)
                f1 = f1_score(y_true, y_pred, average="micro")

                if f1 > best_f1:
                    best_thresh = threshold
                    best_f1 = f1

            yy_t.append(y_true)
            yy_p.append((score > best_thresh).astype(int))

    return np.array(yy_t), np.array(yy_p)

# y1: true label rows, y2: predicted label rows, one row per dev-set document.
y1, y2 = gather_outputs(test_data, model)
# f1_score(y1, y2, average="micro")

In [None]:
# Per-label precision/recall/F1 of the predictions (y2) against the true labels (y1).
print(classification_report(y1, y2))

In [None]:
# Decode the first ten true label rows back to label names.
# NOTE(review): y1 comes from test_data but is decoded with the training binarizer,
# while the next cell decodes y2 with the dev-set binarizer — confirm both use the
# same class ordering.
train_data.iter.lib_data.binarizer.inverse_transform(y1)[:10]

In [None]:
# Decode the first ten predicted label rows with the dev-set binarizer.
test_data.iter.MLbin.inverse_transform(y2)[:10]