In [1]:
import torch
import os

In [2]:
import numpy as np

## Config and Data Generators

In [3]:
class config:
    """Static settings (paths, sizes, hyper-parameters) read throughout the notebook."""

    # --- Dataset locations and sizes ---
    data_path_train = 'dataset/Amazon/amazon_train.txt'
    data_path_test = 'dataset/Amazon/amazon_test.txt'
    feature_dim = 135_909   # number of input feature columns (a bias column is appended later)
    n_classes = 670_091     # number of output labels
    n_train = 490_449       # training samples, used to derive steps per epoch
    n_test = 153_025        # test samples, used to size the full evaluation

    # --- Hardware ---
    GPUs = True             # checked by get_default_device() together with CUDA availability
    num_threads = 44        # original note: only used when GPUs is empty string

    # --- Model and optimisation ---
    hidden_dim = 126
    lr = 1e-4
    n_epochs = 20
    batch_size = 256
    test_batch_size = 256

    # --- Approximate nearest-neighbour retrieval ---
    max_l2 = 6              # placeholder; recomputed from W2 before the index is built
    sparsity = 2e-5         # k for the index query is int(sparsity * n_classes)

    # --- Logging ---
    log_file = 'log_amz_faiss_ivfpq_gpu'

In [4]:
def get_default_device():
    """Return the CUDA device when CUDA is available and ``config.GPUs`` is truthy, else the CPU."""
    use_cuda = torch.cuda.is_available() and config.GPUs
    return torch.device("cuda" if use_cuda else "cpu")

In [5]:
import os
# If the runtime is a Colab hosted runtime (detected through the COLAB_GPU
# environment variable), point the dataset paths at Google Drive instead of the
# local defaults declared in `config`.
if "COLAB_GPU" in os.environ:
    config.data_path_train = '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_train.txt'
    config.data_path_test = '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_test.txt'
    from google.colab import drive
    # Mount Drive at /content/drive, the root of the two paths set above.
    drive.mount('/content/drive')

In [6]:
# Choose the compute device from config.GPUs and CUDA availability.
device = get_default_device()

In [7]:
# Pin to the first CUDA device only when a GPU was actually selected above.
# Forcing "cuda:0" unconditionally would discard the CPU fallback chosen by
# get_default_device() and make every later `.to(device)` fail on CPU-only machines.
if device.type == "cuda":
    device = torch.device("cuda:0")
device

device(type='cuda', index=0)

In [8]:
def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are processed recursively and always come back as lists;
    any other object is moved through its own ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)
    moved = []
    for element in data:
        moved.append(to_device(element, device))
    return moved

In [9]:
# Notebook-level shortcuts for two frequently used config values.
batch_size = config.batch_size
n_classes = config.n_classes

In [10]:
from itertools import islice

In [11]:
def data_generator(file_name, batch_size, n_classes):
    """Yield training mini-batches from a sparse text file, cycling over the file forever.

    The first line of the file is skipped as a header. Every following line is parsed as
    ``label[,label...] feat:val feat:val ...``. A trailing group of fewer than
    ``batch_size`` lines is dropped, after which the file is reopened from the start.

    Args:
        file_name: path of the data file.
        batch_size: number of lines per yielded batch.
        n_classes: number of columns of the dense target matrix.

    Yields:
        ``(idxs_x, idxs_y, vals, y_batch, labels_batch)``:
        ``idxs_x`` / ``idxs_y`` / ``vals`` are parallel lists of row index, column index and
        value; each row additionally gets one entry at column ``config.feature_dim`` with
        value 1.0. ``y_batch`` is a ``(batch_size, n_classes)`` float32 tensor created on the
        notebook-level ``device`` holding ``1 / len(labels)`` at every label of a row.
        ``labels_batch`` is the list of label-id lists, one per row.

    Raises:
        ValueError: when a complete pass over the file cannot fill a single batch. Without
            this check the generator would reopen the file in an endless loop and never yield.
    """
    while True:
        produced_batch = False
        with open(file_name, 'r', encoding='utf-8') as f:
            f.readline()  # ignore the header
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    break  # incomplete tail (or end of file): start the next pass
                idxs_x, idxs_y = [], []
                vals = []
                labels_batch = []
                y_batch = torch.zeros([batch_size, n_classes], dtype = torch.float32, device = device)
                for count, line in enumerate(lines):
                    itms = line.strip().split(' ')
                    y_idxs = [int(itm) for itm in itms[0].split(',')]
                    labels_batch.append(y_idxs)
                    # Spread the target mass uniformly over the labels of this sample.
                    y_batch[count, y_idxs] = 1.0 / len(y_idxs)
                    for itm in itms[1:]:
                        parts = itm.split(':')
                        idxs_x.append(count)
                        idxs_y.append(int(parts[0]))
                        vals.append(float(parts[1]))
                    # Constant-1 entry in the extra column, acting as the bias input.
                    idxs_x.append(count)
                    idxs_y.append(config.feature_dim)
                    vals.append(1.0)
                produced_batch = True
                yield (idxs_x, idxs_y, vals, y_batch, labels_batch)
        if not produced_batch:
            raise ValueError(
                "{!r} holds fewer than batch_size={} samples; cannot form a batch".format(file_name, batch_size)
            )

In [12]:
def data_generator_tst(file_name, batch_size, n_classes):
    """Yield evaluation mini-batches from a sparse text file, cycling over the file forever.

    The first line of the file is skipped as a header. Every following line is parsed as
    ``label[,label...] feat:val feat:val ...``. A trailing group of fewer than
    ``batch_size`` lines is dropped, after which the file is reopened from the start.

    Args:
        file_name: path of the data file.
        batch_size: number of lines per yielded batch.
        n_classes: accepted for signature parity with ``data_generator``; not used here.

    Yields:
        ``(idxs_x, idxs_y, vals, labels_batch)``:
        ``idxs_x`` / ``idxs_y`` / ``vals`` are parallel lists of row index, column index and
        value; each row additionally gets one entry at column ``config.feature_dim`` with
        value 1.0. ``labels_batch`` is the list of label-id lists, one per row.

    Raises:
        ValueError: when a complete pass over the file cannot fill a single batch. Without
            this check the generator would reopen the file in an endless loop and never yield.
    """
    while True:
        produced_batch = False
        with open(file_name, 'r', encoding='utf-8') as f:
            f.readline()  # ignore the header
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    break  # incomplete tail (or end of file): start the next pass
                idxs_x, idxs_y = [], []
                vals = []
                labels_batch = []
                for count, line in enumerate(lines):
                    itms = line.strip().split(' ')
                    labels_batch.append([int(itm) for itm in itms[0].split(',')])
                    for itm in itms[1:]:
                        parts = itm.split(':')
                        idxs_x.append(count)
                        idxs_y.append(int(parts[0]))
                        vals.append(float(parts[1]))
                    # Constant-1 entry in the extra column, acting as the bias input.
                    idxs_x.append(count)
                    idxs_y.append(config.feature_dim)
                    vals.append(1.0)
                produced_batch = True
                yield (idxs_x, idxs_y, vals, labels_batch)
        if not produced_batch:
            raise ValueError(
                "{!r} holds fewer than batch_size={} samples; cannot form a batch".format(file_name, batch_size)
            )

## Torch Training

In [13]:
import faiss

In [14]:
import torch
import time
import math
import torch.nn as nn

In [15]:
# Mirror the schedule and dataset sizes from config into notebook-level names.
n_epochs, n_train, n_test = config.n_epochs, config.n_train, config.n_test

# A short test-set check runs every `n_check` training steps
# (the original cell also carried a commented-out alternative value of 1).
n_check = 50

# Only whole batches are counted; the remainder of n_train / batch_size is ignored.
steps_per_epoch = n_train // batch_size
n_steps = n_epochs * steps_per_epoch

In [16]:
# One endless training-batch generator, consumed step by step by fit().
train_data_generator = data_generator(config.data_path_train, batch_size = config.batch_size, n_classes = config.n_classes)

In [17]:
# Weight matrices of the two-layer network. Each has one extra input row (+1) that pairs
# with the constant-1 column appended to the inputs (data generators) and to the hidden
# activations (add_unity_col). Autograd is disabled: gradients are computed by hand in fit().
W1 = torch.randn(config.feature_dim + 1, config.hidden_dim, requires_grad = False)
W2 = torch.randn(config.hidden_dim + 1, config.n_classes, requires_grad = False)

In [18]:
# Truncation bounds for the initialisation: +/- 2 * val, where val is derived from the
# second layer's dimensions (hidden_dim + 1 + n_classes). The same bounds are used for W1.
val = 2.0/math.sqrt(config.hidden_dim + 1 + config.n_classes)
a = - 2 * val
b = 2 * val
# Overwrites the randn values in place with samples restricted to [a, b].
# NOTE(review): no `std` is passed, so the library default standard deviation applies while
# samples are clipped to the much narrower [a, b]; if `val` was meant to be the standard
# deviation, pass std=val — TODO confirm intent.
W1 = nn.init.trunc_normal_(W1, a = a, b = b)
W2 = nn.init.trunc_normal_(W2, a = a, b = b)

In [19]:
# Replace the placeholder in config with 33x the largest absolute entry of W2; this value
# is the divisor applied to W2 columns in add_items_to_index before they are indexed.
config.max_l2 = 3 * 11 * torch.max(torch.abs(W2)).item()

In [20]:
# Show the scaling bound that was just computed.
config.max_l2

0.1612374451942742

In [21]:
# Move both weight matrices to the selected device (to_device returns a list, unpacked here).
with torch.no_grad():
    (W1, W2) = to_device((W1, W2), device)
# Autograd is intentionally left off for the weights; the two lines below would enable it.
# W1.requires_grad = True
# W2.requires_grad = True

In [22]:
# Pre-allocate dense, zero-filled .grad buffers with the same shape/device as the weights;
# fit() writes its hand-computed gradients into them before calling the optimizer.
W1.grad = W1.new_zeros(W1.shape)
W2.grad = W2.new_zeros(W2.shape)

In [23]:
# Inspect the freshly allocated gradient buffer (expected to be all zeros at this point).
W1.grad

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [24]:
# Confirm that autograd tracking is disabled on both weight matrices.
print(W1.requires_grad, W2.requires_grad)

False False


In [25]:
# Adam consumes the manually filled .grad buffers of W1 and W2.
adam_optim = torch.optim.Adam(params = (W1, W2), lr = config.lr)

# Reusable modules: row-wise log-softmax, and right-padding of the last dimension with one
# extra column of ones (bias input for layer 2) or of zeros (query augmentation for the index).
log_softmax = torch.nn.LogSoftmax(dim = 1)
add_unity_col = torch.nn.ConstantPad1d((0, 1), value = 1.0)
add_zero_col = torch.nn.ConstantPad1d((0, 1), value = 0.0)

In [26]:
# FAISS GPU resource handle.
# NOTE(review): not referenced again in the visible cells (the index built below is used
# directly, without a GPU transfer) — confirm whether GPU indexing was intended.
gpu_resources = faiss.StandardGpuResources()

In [27]:
# Size of the indexed vectors: a W2 column (hidden_dim + 1 entries) plus the one
# augmentation entry appended in add_items_to_index / query_items_from_index.
dim = (config.hidden_dim + 1) + 1
nlist = 512  # number of inverted lists (coarse cells)
nbits = 8    # bits per product-quantizer code
m = 32       # number of product-quantizer sub-vectors
metric = faiss.METRIC_INNER_PRODUCT  # NOTE(review): assigned but the constant is passed directly below
quantizer = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)
# NOTE(review): despite its name, gpu_index is an alias of the index created above;
# no transfer to a GPU is visible in this notebook section.
gpu_index = index

In [28]:
# Number of inverted lists visited per query.
gpu_index.nprobe = 16

In [29]:
# Sanity check of the metric configured on the index.
gpu_index.metric_arg, gpu_index.metric_type # 0 means inner product

(0.0, 0)

In [30]:
def add_items_to_index(index, data, ids, return_data_only = False):
    """Scale and augment the columns of ``data`` and add them to ``index``.

    Every column of ``data`` becomes one row: it is divided by ``config.max_l2`` and
    extended with a final entry ``sqrt(1 - squared_norm)`` computed from the scaled row.

    Args:
        index: object exposing ``add_with_ids``; untouched when ``return_data_only`` is set.
        data: 2-D tensor whose columns are the items (``config.hidden_dim + 1`` rows).
        ids: identifiers for the items, either a NumPy array or a tensor.
        return_data_only: when true, return the augmented NumPy array instead of adding it.

    Returns:
        The augmented array when ``return_data_only`` is true, otherwise ``None``.
    """
    with torch.no_grad():
        n_items = data.shape[1]
        augmented = torch.zeros(size = (n_items, (config.hidden_dim + 1) + 1))
        augmented[:, :-1] = data.T / config.max_l2
        # The last column is still zero here, so the sum covers only the scaled weights.
        squared_norms = torch.sum(augmented.pow(2), dim = 1)
        augmented[:, -1] = torch.sqrt(1.00 - squared_norms)
        augmented = np.array(augmented)

        if return_data_only:
            return augmented

        if isinstance(ids, np.ndarray):
            id_array = np.array(ids)
        else:
            id_array = np.array(ids.cpu())
        index.add_with_ids(augmented, ids = id_array.astype(np.int64))
  
def query_items_from_index(index, query, k):
    """Search ``index`` for the ``k`` best matches of every row of ``query``.

    Each query row is extended with a trailing zero and L2-normalised before the
    search. Returns whatever ``index.search`` returns for the prepared queries.
    """
    with torch.no_grad():
        padded = add_zero_col(query)
        unit_rows = torch.nn.functional.normalize(padded, p = 2, dim = 1)
        return index.search(np.array(unit_rows), k)

In [31]:
# Index rebuild schedule used by fit(): rebuild once `steps_since_rebuild` reaches
# `rebuild_delay`, then multiply the delay by `decay_rate` (truncated to int).
rebuild_delay = 50
decay_rate = 1.15
steps_since_rebuild = 0

In [32]:
# Check whether the index has been trained yet.
gpu_index.is_trained

False

In [33]:
def rebuild_index(train_index = True):
    """Rebuild the nearest-neighbour index from the current values of ``W2``.

    Steps: refresh ``config.max_l2`` from ``W2``, empty the index, optionally retrain it
    on the scaled/augmented ``W2`` columns, then add one vector per class with ids
    ``0 .. config.n_classes - 1``.

    Args:
        train_index: retrain the index before adding the vectors. Training is always
            performed when the index reports that it is not trained yet, because vectors
            cannot be added to an untrained index.
    """
    global gpu_index

    config.max_l2 = 3 * 11 * torch.max(torch.abs(W2)).item()

    # Remove the vectors of the previous build. Without this, each rebuild adds another
    # copy of every class id and the outdated vectors remain searchable.
    gpu_index.reset()

    if train_index or not gpu_index.is_trained:
        normalized_init_data = add_items_to_index(gpu_index, W2, torch.arange(config.n_classes), return_data_only=True)
        gpu_index.train(normalized_init_data)
    add_items_to_index(gpu_index, W2, ids = np.arange(config.n_classes))
    print("Index Rebuilt, max_l2: ", config.max_l2)
    print("Rebuild delay", rebuild_delay)

In [34]:
# Initial build: train the index and fill it with the freshly initialised W2 columns.
rebuild_index()

Index Rebuilt, max_l2:  0.1612374451942742
Rebuild delay 50


In [35]:
def fit(train_dg, step):
    """Run one hand-differentiated training step restricted to a subset of output classes.

    The active classes are the union of the ids retrieved from the index for the hidden
    activations and the true labels of the batch. The forward pass, the gradients of both
    weight matrices and the Adam update are all computed without autograd. The index is
    rebuilt according to the notebook-level rebuild schedule.

    Args:
        train_dg: generator yielding ``(idxs_x, idxs_y, vals, Y, labels)`` batches.
        step: global step number (currently unused).

    Returns:
        The batch loss as a float, or ``nan`` when the step failed with an ``Exception``
        (the error is printed, matching the best-effort behaviour of the training loop).
        ``KeyboardInterrupt`` and other non-``Exception`` errors propagate to the caller.
    """
    global steps_since_rebuild, rebuild_delay, decay_rate
    loss_value = float('nan')
    try:
        with torch.no_grad():
            # Zero the gradient buffers in place. The default of recent PyTorch releases
            # replaces .grad with None, which would break the indexed write into W2.grad.
            adam_optim.zero_grad(set_to_none = False)
            idxs_x, idxs_y, vals, Y, labels = next(train_dg)

            # Feed forward
            batch_input = to_device(torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (batch_size, config.feature_dim + 1)), device)
            A1 = torch.sparse.mm(batch_input, W1)
            A1 = add_unity_col(A1)
            Z1 = torch.nn.functional.relu(A1)

            # Retrieve candidate classes for the second layer from the index.
            _, retrieved = query_items_from_index(gpu_index, query = Z1.cpu(), k = int(config.sparsity * config.n_classes))
            retrieved = np.asarray(retrieved).flatten()
            # The index pads missing results with -1; drop those (and any other out-of-range
            # id) instead of wrapping them onto an unrelated class.
            retrieved = retrieved[(retrieved >= 0) & (retrieved < config.n_classes)]
            labels = np.array([x for sub in labels for x in sub], dtype = np.int64)
            layer2_idxs = np.union1d(retrieved, labels)
            active = torch.as_tensor(layer2_idxs, dtype = torch.long, device = W2.device)

            # Sparse feed forward over the active classes only
            W2_active = W2[:, active]
            Y_active = Y[:, active]
            A2 = Z1 @ W2_active

            P = log_softmax(A2)
            L = torch.mean(torch.sum(-P * Y_active, dim = 1))

            # Sparse back propagation
            PL = torch.exp(P)
            temp_B2_grad = PL - Y_active
            W2.grad[:, active] = Z1.T @ temp_B2_grad
            temp_B1_grad = temp_B2_grad @ W2_active.T
            temp_B1_grad[A1 < 0] = 0  # ReLU derivative
            # The last column belongs to the constant bias unit and has no incoming weights.
            W1.grad = torch.sparse.mm(batch_input.t(), temp_B1_grad[:, :-1])

            adam_optim.step()
            loss_value = L.item()

            steps_since_rebuild += 1
            if steps_since_rebuild >= rebuild_delay:
                rebuild_index()
                steps_since_rebuild = 0
                rebuild_delay = int( rebuild_delay * decay_rate )

    except Exception as e:
        print("Exception")
        print(e)
    return loss_value

In [36]:
import numpy as np

In [37]:
def evaluate(n_steps, test_dg):
    """Return the mean top-1 accuracy over ``n_steps`` batches drawn from ``test_dg``.

    For every sample the full (dense) second layer is evaluated and the highest-scoring
    class counts as correct when it is one of the sample's labels. Each batch accuracy
    is ``hits / config.test_batch_size``; the result is the mean over the batches.
    """
    batch_accuracies = []
    for _ in range(n_steps):
        idxs_x, idxs_y, vals, true_labels = next(test_dg)

        with torch.no_grad():
            coo = torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (config.test_batch_size, config.feature_dim + 1))
            features = to_device(coo, device)
            hidden = torch.nn.functional.relu(add_unity_col(torch.sparse.mm(features, W1)))
            scores = hidden @ W2

            _, preds = torch.max(scores, dim = 1)
            hits = sum(
                1
                for row in range(scores.shape[0])
                if len(np.intersect1d(preds[row].cpu(), true_labels[row])) > 0
            )

            batch_accuracies.append(hits / config.test_batch_size)
    return np.mean(batch_accuracies)

In [None]:
# Training driver: runs n_steps calls to fit(), logging test accuracy to config.log_file
# and to stdout. Time spent evaluating is excluded from total_time.
total_time = 0
begin_time = time.time()
# The full evaluation consumes batches of config.test_batch_size, so its step count must be
# derived from that size rather than from the training batch size.
n_steps_val = n_test // config.test_batch_size
n_quick_eval_steps = 20  # batches used by the periodic quick check
with open(config.log_file, 'a') as out:
    losses = []
    for step in range(n_steps):
        if step % n_check == 0:
            total_time += time.time() - begin_time
            test_data_generator = data_generator_tst(config.data_path_test, config.test_batch_size, config.n_classes)
            try:
                accuracy = evaluate(n_quick_eval_steps, test_data_generator)
            finally:
                # Closing the generator releases the test file it keeps open.
                test_data_generator.close()
            print('Step:{}  Total_Time:{}  Test_acc:{}'.format(step, total_time, accuracy), file = out)
            print('Step:{}  Total_Time:{}  Test_acc:{}'.format(step, total_time, accuracy))
            begin_time = time.time()
        if step % steps_per_epoch == (steps_per_epoch - 1):
            total_time += time.time() - begin_time
            test_data_generator = data_generator_tst(config.data_path_test, config.test_batch_size, config.n_classes)
            try:
                # Check accuracy on the complete test data at the end of every epoch.
                accuracy = evaluate(n_steps_val, test_data_generator)
            finally:
                test_data_generator.close()
            print('OVERALL Step : {} Total_Time: {} Test_acc: {}'.format(step, total_time, accuracy), file = out)
            print('OVERALL Step : {} Total_Time: {} Test_acc: {}'.format(step, total_time, accuracy))
            begin_time = time.time()
        loss = fit(train_data_generator, step)
        losses.append(loss)

Step:0  Total_Time:0.0009860992431640625  Test_acc:0.0
Index Rebuilt, max_l2:  0.32244317326694727
Rebuild delay 50
Step:50  Total_Time:26.56253695487976  Test_acc:0.0134765625
Step:100  Total_Time:30.045524835586548  Test_acc:0.00859375
Index Rebuilt, max_l2:  0.6073473505675793
Rebuild delay 57
Step:150  Total_Time:56.38411259651184  Test_acc:0.008203125
Index Rebuilt, max_l2:  0.7743756137788296
Rebuild delay 65
Step:200  Total_Time:83.49229693412781  Test_acc:0.01796875
Index Rebuilt, max_l2:  0.9347339998930693
Rebuild delay 74
Step:250  Total_Time:111.24041438102722  Test_acc:0.0197265625
Step:300  Total_Time:115.46797561645508  Test_acc:0.01796875
Index Rebuilt, max_l2:  1.168583169579506
Rebuild delay 85
Step:350  Total_Time:143.49398612976074  Test_acc:0.0216796875
Step:400  Total_Time:148.12948346138  Test_acc:0.021484375
Index Rebuilt, max_l2:  1.3420818485319614
Rebuild delay 97
Step:450  Total_Time:177.1778860092163  Test_acc:0.0251953125
Step:500  Total_Time:182.118142366