In [1]:
import torch
import numpy as np
import math

## Config and Data Generators

In [2]:
class config:
    # --- data -------------------------------------------------------------
    data_path_train = 'dataset/Amazon/amazon_train.txt'
    data_path_test = 'dataset/Amazon/amazon_test.txt'
    feature_dim = 135_909   # input features; W1 gets one extra row for the constant bias input
    n_classes = 670_091     # output labels (columns of W2)
    n_train = 490_449
    n_test = 153_025

    # --- model / optimisation ---------------------------------------------
    hidden_dim = 128
    GPUs = True             # allow CUDA when it is available
    lr = 1e-4
    n_epochs = 20
    batch_size = 256
    test_batch_size = 256

    # --- sparse output layer (HNSW) ---------------------------------------
    max_l2 = 6              # initial value; recomputed as 33 * max|W2| before the index is built
    sparsity = 1e-5         # neighbours retrieved per example = int(sparsity * n_classes)

    # --- logging ----------------------------------------------------------
    log_file = 'log_amz_torch_latest'

In [3]:
def get_default_device(use_gpu = None):
    """Return the CUDA device when it is available and GPU use is enabled, else the CPU.

    Args:
        use_gpu: whether GPU use is allowed. ``None`` (the default) reads
            ``config.GPUs`` so existing zero-argument calls behave as before.

    Returns:
        ``torch.device("cuda")`` or ``torch.device("cpu")``.
    """
    if use_gpu is None:
        use_gpu = config.GPUs
    return torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

In [4]:
import os
# Colab-hosted runtimes set COLAB_GPU: there the dataset is read from Google Drive
# (mounted below) instead of the local dataset/ folder.
if "COLAB_GPU" in os.environ:
    config.data_path_train, config.data_path_test = (
        '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_train.txt',
        '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_test.txt',
    )
    from google.colab import drive
    drive.mount('/content/drive')

In [5]:
# Pick CUDA when it is available and enabled via config.GPUs, otherwise the CPU.
device = get_default_device()

In [6]:
# Show which device was selected.
device

device(type='cuda')

In [7]:
# Pin the first GPU explicitly when CUDA is usable and enabled in config;
# otherwise keep the CPU fallback instead of crashing on machines without CUDA.
if torch.cuda.is_available() and config.GPUs:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
def to_device(data, device):
    """Move a tensor, or an arbitrarily nested list/tuple of tensors, onto ``device``.

    Containers (lists and tuples alike) are rebuilt as lists; every leaf is
    moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [9]:
# Short aliases used throughout the training code.
batch_size, n_classes = config.batch_size, config.n_classes

In [10]:
from itertools import islice

In [11]:
def data_generator(file_name, batch_size, n_classes, feature_dim = None, out_device = None):
    """Yield training batches from a sparse multi-label text file, cycling forever.

    File format: a header line (skipped), then one example per line::

        label[,label...] feat:val [feat:val ...]

    Every example also gets a constant feature at column ``feature_dim`` with
    value 1.0 (the bias input). Lines are consumed ``batch_size`` at a time; a
    trailing partial batch is dropped and reading restarts from the top.

    Args:
        file_name: path of the UTF-8 data file.
        batch_size: examples per batch (must be >= 1).
        n_classes: width of the dense label matrix ``y_batch``.
        feature_dim: column used for the bias feature; defaults to ``config.feature_dim``.
        out_device: device for ``y_batch``; defaults to the notebook-level ``device``.

    Yields:
        ``(idxs_x, idxs_y, vals, y_batch, labels_batch)``.
        - ``idxs_x`` / ``idxs_y`` are the row / column COO indices of the sparse input.
        - ``vals`` holds the matching values.
        - ``y_batch`` is a ``(batch_size, n_classes)`` float32 tensor. Each row
          puts mass ``1/len(labels)`` on every label of that example.
        - ``labels_batch`` holds the raw per-example label lists.

    Raises:
        ValueError: if ``batch_size < 1`` or the file holds fewer than
            ``batch_size`` examples after its header.
    """
    if feature_dim is None:
        feature_dim = config.feature_dim
    if out_device is None:
        out_device = device
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    while True:
        yielded_any = False
        with open(file_name, 'r', encoding = 'utf-8') as f:
            f.readline()  # skip the header line
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    break  # drop the trailing partial batch and restart from the top
                idxs_x, idxs_y, vals = [], [], []
                labels_batch = []
                y_batch = torch.zeros([batch_size, n_classes], dtype = torch.float32, device = out_device)
                for row, line in enumerate(lines):
                    label_field, *feature_fields = line.strip().split(' ')
                    labels = [int(tok) for tok in label_field.split(',')]
                    labels_batch.append(labels)
                    y_batch[row, labels] = 1.0 / len(labels)
                    pairs = [tok.split(':') for tok in feature_fields]
                    cols = [int(p[0]) for p in pairs] + [feature_dim]  # + bias column
                    idxs_y += cols
                    idxs_x += [row] * len(cols)
                    vals += [float(p[1]) for p in pairs] + [1.0]
                yielded_any = True
                yield (idxs_x, idxs_y, vals, y_batch, labels_batch)
        if not yielded_any:
            # Without this check a file shorter than one batch makes the outer loop spin forever.
            raise ValueError("{} has fewer than {} examples after its header".format(file_name, batch_size))

In [12]:
def data_generator_tst(file_name, batch_size, n_classes, feature_dim = None):
    """Yield test batches from a sparse multi-label text file, cycling forever.

    Same file format and batching as the training generator: a skipped header,
    then ``label[,label...] feat:val ...`` per line. Each example gets a bias
    feature at column ``feature_dim`` with value 1.0, and a trailing partial
    batch is dropped before restarting from the top. No dense label matrix is
    built.

    Args:
        file_name: path of the UTF-8 data file.
        batch_size: examples per batch (must be >= 1).
        n_classes: not used here; kept for signature parity with the training generator.
        feature_dim: column used for the bias feature; defaults to ``config.feature_dim``.

    Yields:
        ``(idxs_x, idxs_y, vals, labels_batch)``: COO row/column indices and
        values of the sparse input, plus the per-example label lists.

    Raises:
        ValueError: if ``batch_size < 1`` or the file holds fewer than
            ``batch_size`` examples after its header.
    """
    if feature_dim is None:
        feature_dim = config.feature_dim
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    while True:
        yielded_any = False
        with open(file_name, 'r', encoding = 'utf-8') as f:
            f.readline()  # skip the header line
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    break  # drop the trailing partial batch and restart from the top
                idxs_x, idxs_y, vals = [], [], []
                labels_batch = []
                for row, line in enumerate(lines):
                    label_field, *feature_fields = line.strip().split(' ')
                    labels_batch.append([int(tok) for tok in label_field.split(',')])
                    pairs = [tok.split(':') for tok in feature_fields]
                    cols = [int(p[0]) for p in pairs] + [feature_dim]  # + bias column
                    idxs_y += cols
                    idxs_x += [row] * len(cols)
                    vals += [float(p[1]) for p in pairs] + [1.0]
                yielded_any = True
                yield (idxs_x, idxs_y, vals, labels_batch)
        if not yielded_any:
            # Without this check a file shorter than one batch makes the outer loop spin forever.
            raise ValueError("{} has fewer than {} examples after its header".format(file_name, batch_size))

## Torch Training

In [13]:
!pip install hnswlib -q

In [14]:
import torch
import time
import numpy as np
import hnswlib
import math
import torch.nn as nn

In [15]:
# Run-length constants derived from config.
n_epochs, n_train, n_test = config.n_epochs, config.n_train, config.n_test
n_check = 50  # NOTE(review): not referenced in the visible notebook; the loop below evaluates every 25 steps
steps_per_epoch = n_train // batch_size  # the generator drops each pass's trailing partial batch
n_steps = n_epochs * steps_per_epoch

In [16]:
# Endless stream of training batches (restarts at the top of the file after each pass).
train_data_generator = data_generator(config.data_path_train, config.batch_size, config.n_classes)

In [17]:
# Allocate both weight matrices; the extra +1 row of each multiplies the constant-1 bias input.
# requires_grad stays at its default (False): gradients are computed by hand in fit().
W1 = torch.randn((config.feature_dim + 1, config.hidden_dim))
W2 = torch.randn((config.hidden_dim + 1, config.n_classes))

In [18]:
# Truncated-normal init with std = 2 / sqrt(rows + cols) of each matrix, cut at +/- 2 std.
# trunc_normal_ defaults to std=1, so passing only the tiny +/- 2*val cut-offs
# would turn the draw into an almost uniform distribution on [a, b].
val = 2.0/math.sqrt(config.hidden_dim + 1 + config.n_classes)  # W2 scale
a = - 2 * val
b = 2 * val
val_w1 = 2.0/math.sqrt(config.feature_dim + 1 + config.hidden_dim)  # W1 scale from W1's own shape
W1 = nn.init.trunc_normal_(W1, mean = 0.0, std = val_w1, a = -2 * val_w1, b = 2 * val_w1)
W2 = nn.init.trunc_normal_(W2, mean = 0.0, std = val, a = a, b = b)

In [19]:
# Divisor that squashes W2's columns inside the unit ball before they go into the HNSW index.
config.max_l2 = 33 * W2.abs().max().item()

In [20]:
# Show the HNSW scaling divisor computed above.
config.max_l2

0.161237183958292

In [21]:
# Move both weight matrices to the training device (rebinds the global names to the moved tensors).
with torch.no_grad():
    W1, W2 = to_device([W1, W2], device)

In [22]:
# Pre-allocate zeroed gradient buffers: fit() writes gradients into them by hand.
W1.grad = torch.zeros_like(W1)
W2.grad = torch.zeros_like(W2)

In [23]:
# Sanity check: the freshly allocated W1 gradient buffer (all zeros, same device as W1).
W1.grad

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [24]:
# Both weights are plain tensors outside autograd; fit() computes their gradients manually.
print(W1.requires_grad, W2.requires_grad)

False False


In [25]:
# Adam applies the hand-computed gradients stored in W1.grad / W2.grad.
adam_optim = torch.optim.Adam([W1, W2], lr = config.lr)

log_softmax = nn.LogSoftmax(dim = 1)
add_unity_col = nn.ConstantPad1d((0, 1), value = 1.0)  # append a constant-1 column (input for W2's bias row)
add_zero_col = nn.ConstantPad1d((0, 1), value = 0.0)   # append a 0 column (query-side HNSW augmentation)

In [26]:
# HNSW hyper-parameters. Each stored vector is a W2 column (hidden_dim + 1 values)
# plus one augmentation coordinate, hence hidden_dim + 2.
dim = config.hidden_dim + 2
max_elements = 2 * config.n_classes
M = 24
ef_construction = 100
ef_search = int(2 * config.sparsity * config.n_classes)

In [27]:
def add_items_to_hnsw_index(index, data, ids, max_l2 = None):
    """Insert the columns of ``data`` into ``index`` as unit-norm vectors.

    Each column is divided by ``max_l2``, then augmented with one extra
    coordinate ``sqrt(1 - ||x||^2)`` so every stored vector has norm 1. A query
    padded with a trailing 0 then has an inner product with the stored vector
    equal to its inner product with the original column divided by ``max_l2``.

    Args:
        index: hnswlib-style index exposing ``dim`` (== ``data.shape[0] + 1``)
            and ``add_items(data, ids)``.
        data: 2-D tensor; each column is one item (may live on any device).
        ids: one integer id per column.
        max_l2: scale divisor, expected to be >= every column's L2 norm;
            defaults to ``config.max_l2``.
    """
    if max_l2 is None:
        max_l2 = config.max_l2
    with torch.no_grad():
        items = torch.zeros(size = (data.shape[1], index.dim))
        items[:, :-1] = data.T / max_l2
        # Clamp so a column longer than max_l2 (or rounding) yields 0 instead of a NaN in the index.
        residual = torch.clamp(1.00 - torch.sum(items[:, :-1].pow(2), dim = 1), min = 0.0)
        items[:, -1] = torch.sqrt(residual)
        # hnswlib expects a float32 NumPy array; pass one explicitly rather than a torch tensor.
        index.add_items(data = items.numpy(), ids = ids)
  
def query_items_from_hnsw_index(index, query, k):
    """Return the ``k`` approximate maximum-inner-product items for each row of ``query``.

    Each row gets a trailing 0, matching the extra coordinate stored by
    ``add_items_to_hnsw_index``. Rows are then L2-normalised, which scales all
    of a row's inner products by the same positive factor and so leaves its
    ranking unchanged.

    Args:
        index: hnswlib-style index exposing ``knn_query(data, k)``.
        query: 2-D tensor, one query per row; may be on any device.
        k: neighbours per query.

    Returns:
        Whatever ``index.knn_query`` returns; ``fit`` unpacks it as
        ``(labels, distances)``.
    """
    with torch.no_grad():
        padded = torch.nn.functional.pad(query, (0, 1), value = 0.0)
        unit_rows = torch.nn.functional.normalize(padded, p = 2, dim = 1)
        # hnswlib takes a host-side float32 NumPy array.
        return index.knn_query(data = unit_rows.cpu().numpy(), k = k)

In [28]:
def rebuild_index(data, ids):
    """Build a fresh HNSW inner-product index over the columns of ``data``.

    Side effects: sets ``config.max_l2`` from ``data`` (the divisor used when
    inserting the columns) and resets the global ``last_index_saved``.

    Args:
        data: 2-D tensor whose columns are indexed (``W2`` in this notebook).
        ids: integer id for each column of ``data``.

    Returns:
        The populated ``hnswlib.Index``.
    """
    global last_index_saved
    index = hnswlib.Index(space = "ip", dim = dim)
    # Capacity is 2 * n_classes, but never smaller than the number of items inserted now.
    capacity = max(int(2 * config.n_classes), len(ids))
    index.init_index(max_elements = capacity, M = M, ef_construction = ef_construction)
    index.set_ef(ef = ef_search)
    # Derive the scale from the matrix being indexed (not always W2). 33 * max|entry| is
    # >= every column norm while data has at most 33**2 rows, so the augmentation
    # coordinate sqrt(1 - ||x||^2) stays real.
    config.max_l2 = 3 * 11 * torch.max(torch.abs(data)).item()
    add_items_to_hnsw_index(index, data, ids)
    last_index_saved = np.arange(config.n_classes)
    print("Index Rebuilt", config.max_l2)
    return index

In [29]:
# Build the initial index over W2 and report how long it took (seconds).
begin_time = time.time()
index = rebuild_index(W2, np.arange(config.n_classes))
elapsed = time.time() - begin_time
print(elapsed)

Index Rebuilt 0.161237183958292
43.12212252616882


In [30]:
# Index-rebuild schedule: start every 50 steps; fit() grows the interval as exp(decay_rate * step).
init_rebuild_freq, decay_rate = 50, 0.003

In [31]:
# Current number of steps between HNSW rebuilds; fit() recomputes it after each scheduled rebuild.
rebuild_freq = init_rebuild_freq

In [None]:
# rebuild_frequencies = [50, 70, 100, 150, 180, 200, 220, 250, 3]

In [32]:
# Step at which the index was last rebuilt by the schedule in fit().
last_rebuild_step = 0

In [33]:
# Global bookkeeping: fit() counts failed steps; rebuild_index() resets last_index_saved.
error_instances = 0
last_index_saved = np.arange(config.n_classes)

In [34]:
def fit(train_dg, step):
    """Run one sparse training step and update ``W1`` / ``W2`` with Adam.

    The hidden layer is computed densely from the sparse input. The output
    layer is evaluated only on the W2 columns returned by the HNSW index plus
    every true label in the batch. Gradients are written by hand into
    ``W1.grad`` / ``W2.grad`` and applied with ``adam_optim``. The index is
    rebuilt whenever ``rebuild_freq`` steps have passed since the last rebuild.

    Args:
        train_dg: generator yielding ``(idxs_x, idxs_y, vals, Y, labels)``, where
            ``Y`` is the dense ``(batch_size, n_classes)`` label distribution and
            ``labels`` the per-example label lists.
        step: global step counter, used to schedule index rebuilds.

    Returns:
        The batch loss as a Python float, or ``float('nan')`` when the step
        raised an ``Exception`` and was skipped. Non-``Exception`` errors such
        as ``KeyboardInterrupt`` propagate.
    """
    global index, error_instances, last_rebuild_step, rebuild_freq
    with torch.no_grad():
        try:
            # Zero the persistent grad buffers in place: W2.grad is written column-wise
            # below, which fails if zero_grad() resets it to None (its default since PyTorch 2.0).
            adam_optim.zero_grad(set_to_none = False)
            idxs_x, idxs_y, vals, Y, labels = next(train_dg)

            # Dense hidden layer; the last input column is the constant bias feature.
            x_sparse = to_device(torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (batch_size, config.feature_dim + 1)), device)
            A1 = add_unity_col(torch.sparse.mm(x_sparse, W1))
            Z1 = torch.nn.functional.relu(A1)

            # Active output classes: HNSW neighbours of each hidden vector plus all true labels.
            neighbours, _ = query_items_from_hnsw_index(index, query = Z1.cpu(), k = int(config.sparsity * config.n_classes))
            true_labels = np.array([lab for row in labels for lab in row])
            active = np.union1d(neighbours.flatten() % config.n_classes, true_labels).astype(np.int32)

            # Sparse output layer and cross-entropy against the label distribution.
            W2_active = W2[:, active]
            Y_active = Y[:, active]
            P = log_softmax(Z1 @ W2_active)
            loss = torch.mean(torch.sum(-P * Y_active, dim = 1))

            # Manual back-propagation restricted to the active columns.
            dA2 = torch.exp(P) - Y_active
            W2.grad[:, active] = Z1.T @ dA2
            dA1 = dA2 @ W2_active.T
            dA1[A1 < 0] = 0  # ReLU derivative
            W1.grad = torch.sparse.mm(x_sparse.t(), dA1[:, :-1])  # drop the constant unity column

            adam_optim.step()
            loss_value = loss.item()

            if (step - last_rebuild_step) >= rebuild_freq:
                index = rebuild_index(W2, np.arange(config.n_classes))
                last_rebuild_step = step
                rebuild_freq = int(init_rebuild_freq * math.exp(decay_rate * step))
                print("New rebuild frequency: ", rebuild_freq)
            return loss_value

        except Exception as e:
            # Best effort: log and skip this batch. After every 3 failures, rebuild the
            # index (a stale index is a likely culprit) and restart the count.
            print(e)
            error_instances += 1
            if error_instances >= 3:
                index = rebuild_index(W2, np.arange(config.n_classes))
                error_instances = 0
            return float('nan')

In [35]:
def evaluate(n_steps, test_dg):
    """Estimate precision@1 of the full (dense) model on ``n_steps`` test batches.

    A row counts as correct when its arg-max class is one of that example's
    true labels.

    Args:
        n_steps: number of batches to draw from ``test_dg``.
        test_dg: generator yielding ``(idxs_x, idxs_y, vals, labels_batch)``.

    Returns:
        Mean of the per-batch accuracies. When ``n_steps`` is 0 this is ``np.mean``
        of an empty list, i.e. ``nan``.
    """
    accuracies = []
    for _ in range(n_steps):
        idxs_x, idxs_y, vals, labels = next(test_dg)
        n_rows = len(labels)
        with torch.no_grad():
            x_sparse = to_device(torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (n_rows, config.feature_dim + 1)), device)
            Z1 = torch.nn.functional.relu(add_unity_col(torch.sparse.mm(x_sparse, W1)))
            _, preds = torch.max(Z1 @ W2, dim = 1)
        # One device-to-host transfer per batch rather than one per example.
        preds = preds.tolist()
        num_correct = sum(pred in row_labels for pred, row_labels in zip(preds, labels))
        accuracies.append(num_correct / n_rows)
    return np.mean(accuracies)

In [None]:
EVAL_EVERY = 25    # training steps between evaluations
EVAL_BATCHES = 20  # test batches scored per evaluation (a sample, not the whole test set)

total_time = 0  # accumulated training time only; evaluation time is excluded
begin_time = time.time()
with open(config.log_file, 'a') as out:
    losses = []
    for step in range(n_steps):
        if step % EVAL_EVERY == 0:
            total_time += time.time() - begin_time
            test_data_generator = data_generator_tst(config.data_path_test, config.test_batch_size, config.n_classes)
            accuracy = evaluate(EVAL_BATCHES, test_data_generator)
            test_data_generator.close()  # release the test file handle right away
            message = 'Step:{}  Total_Time:{}  Test_acc:{}'.format(step, total_time, accuracy)
            print(message, file = out, flush = True)  # keep the log current during long runs
            print(message)
            begin_time = time.time()
        loss = fit(train_data_generator, step)
        losses.append(loss)