In [None]:
import torch
import numpy as np
from itertools import islice

## Config and Data Generators

In [None]:
class config:
    """Static experiment settings, read everywhere as plain class attributes."""

    # --- Data and log files (the data paths are overridden for Colab in a later cell) ---
    data_path_train = 'Datasets/Amazon/amazon_train.txt'
    data_path_test = 'Datasets/Amazon/amazon_test.txt'
    log_file = 'log_amz_torch_sparse_matrix_index'

    # --- Dataset dimensions ---
    feature_dim = 135909  # input width; the data generators append one bias column at this index
    n_classes = 670091    # width of the output layer / target matrix
    n_train = 490449      # used to derive the number of training steps per epoch
    n_test = 153025       # used to derive the number of evaluation steps

    # --- Hardware ---
    GPUs = True       # boolean switch: CUDA is used only when this is truthy and CUDA is available
    num_threads = 44  # NOTE(review): not read anywhere in the visible cells

    # --- Model and optimisation ---
    hidden_dim = 128
    sparsity = 0.005  # fraction of output classes sampled for the sparse second layer
    lr = 0.0001       # Adam learning rate
    max_l2 = 50       # NOTE(review): not read anywhere in the visible cells
    n_epochs = 20
    batch_size = 256
    test_batch_size = 256

In [None]:
def get_default_device():
    """Return the CUDA device when CUDA is available and enabled via ``config.GPUs``, else the CPU device."""
    # Availability is checked first, so config.GPUs is only consulted on CUDA-capable machines.
    use_cuda = torch.cuda.is_available() and config.GPUs
    return torch.device("cuda" if use_cuda else "cpu")

In [None]:
import os
# If the runtime is a Colab hosted runtime (detected through the COLAB_GPU
# environment variable), point both dataset paths at Google Drive and mount
# the drive so that those paths resolve.
if "COLAB_GPU" in os.environ:
    config.data_path_train = '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_train.txt'
    config.data_path_test = '/content/drive/MyDrive/Colab Datasets/Amazon/amazon_test.txt'
    # Imported here because the google.colab package is only needed on this branch.
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
# Resolve the compute device once; later cells (data generator, fit, evaluate)
# read this module-level `device`.
device = get_default_device()

In [None]:
device

device(type='cuda')

In [None]:
# Outside Colab, prefer the second GPU (cuda:1) -- but only when the machine
# actually has more than one. On a single-GPU machine the default CUDA device
# is kept instead of selecting an ordinal that does not exist (which would
# only fail later, on the first tensor transfer).
if device.type == "cuda" and ("COLAB_GPU" not in os.environ):
    if torch.cuda.device_count() > 1:
        device = torch.device("cuda:1")

In [None]:
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    A list or tuple is processed element by element and always comes back as a
    list; anything else is moved with a non-blocking ``.to`` call.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking = True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

In [None]:
# ! cat /content/drive/MyDrive/Colab\ Datasets/Amazon/amazon_train_sample.txt > sample.txt

In [None]:
# Module-level shortcuts for two config values; `batch_size` is read directly
# by later cells (steps-per-epoch computation, fit, training loop).
batch_size = config.batch_size
n_classes = config.n_classes

In [None]:
def data_generator(file_name, batch_size, n_classes):
    """Yield training batches parsed from a sparse-format text file, cycling forever.

    Each pass skips the first line of the file (header) and then consumes
    ``batch_size`` lines at a time. Every data line is parsed as
    ``label[,label...] idx:val idx:val ...`` (space separated). A trailing
    group of fewer than ``batch_size`` lines is dropped, after which the file
    is reopened and read again from the start.

    Reads the module-level ``config.feature_dim`` (column index of the bias
    entry appended to every row) and ``device`` (where the target matrix is
    allocated).

    Args:
        file_name: Path of the UTF-8 text file to read.
        batch_size: Number of data lines per yielded batch.
        n_classes: Number of columns of the dense target matrix.

    Yields:
        Tuple ``(idxs_x, idxs_y, vals, y_batch, labels_batch)``.
        ``idxs_x`` / ``idxs_y`` / ``vals`` are parallel lists holding row index,
        column index and value of every non-zero input, including one bias
        entry (column ``config.feature_dim``, value 1.0) per row. ``y_batch`` is
        a ``(batch_size, n_classes)`` float32 tensor in which each row puts
        ``1 / number_of_labels`` on its label columns. ``labels_batch`` is the
        list of per-row label-index lists.

    Raises:
        ValueError: If a complete pass over the file produces no batch (the
            file holds fewer than ``batch_size`` data lines). Without this
            check the generator would re-read the file forever and never yield.
    """
    bias_col = config.feature_dim
    while True:
        batches_this_pass = 0
        with open(file_name, 'r', encoding='utf-8') as f:
            f.readline()  # ignore the header
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    # End of file (or incomplete last batch): start a new pass.
                    break
                idxs_x, idxs_y, vals = [], [], []
                labels_batch = []
                y_batch = torch.zeros([batch_size, n_classes], dtype = torch.float32, device = device)
                for row, line in enumerate(lines):
                    itms = line.strip().split(' ')
                    y_idxs = [int(itm) for itm in itms[0].split(',')]
                    labels_batch.append(y_idxs)
                    # Spread the target mass uniformly over the row's labels.
                    y_batch[row, y_idxs] = 1.0 / len(y_idxs)
                    for itm in itms[1:]:
                        pair = itm.split(':')
                        idxs_x.append(row)
                        idxs_y.append(int(pair[0]))
                        vals.append(float(pair[1]))
                    # Constant-1 bias feature in the extra last column.
                    idxs_x.append(row)
                    idxs_y.append(bias_col)
                    vals.append(1.0)
                batches_this_pass += 1
                yield (idxs_x, idxs_y, vals, y_batch, labels_batch)
        if batches_this_pass == 0:
            raise ValueError(
                '{} holds fewer than batch_size={} data lines; no batch can be produced'.format(
                    file_name, batch_size))

In [None]:
def data_generator_tst(file_name, batch_size, n_classes):
    """Yield evaluation batches parsed from a sparse-format text file, cycling forever.

    Each pass skips the first line of the file (header) and then consumes
    ``batch_size`` lines at a time. Every data line is parsed as
    ``label[,label...] idx:val idx:val ...`` (space separated). A trailing
    group of fewer than ``batch_size`` lines is dropped, after which the file
    is reopened and read again from the start. No dense target matrix is
    built; only the label lists are returned.

    Reads the module-level ``config.feature_dim`` (column index of the bias
    entry appended to every row).

    Args:
        file_name: Path of the UTF-8 text file to read.
        batch_size: Number of data lines per yielded batch.
        n_classes: Unused; kept so the signature matches the training generator.

    Yields:
        Tuple ``(idxs_x, idxs_y, vals, labels_batch)``. ``idxs_x`` / ``idxs_y`` /
        ``vals`` are parallel lists holding row index, column index and value
        of every non-zero input, including one bias entry (column
        ``config.feature_dim``, value 1.0) per row. ``labels_batch`` is the list
        of per-row label-index lists.

    Raises:
        ValueError: If a complete pass over the file produces no batch (the
            file holds fewer than ``batch_size`` data lines). Without this
            check the generator would re-read the file forever and never yield.
    """
    bias_col = config.feature_dim
    while True:
        batches_this_pass = 0
        with open(file_name, 'r', encoding='utf-8') as f:
            f.readline()  # ignore the header
            while True:
                lines = list(islice(f, batch_size))
                if len(lines) != batch_size:
                    # End of file (or incomplete last batch): start a new pass.
                    break
                idxs_x, idxs_y, vals = [], [], []
                labels_batch = []
                for row, line in enumerate(lines):
                    itms = line.strip().split(' ')
                    labels_batch.append([int(itm) for itm in itms[0].split(',')])
                    for itm in itms[1:]:
                        pair = itm.split(':')
                        idxs_x.append(row)
                        idxs_y.append(int(pair[0]))
                        vals.append(float(pair[1]))
                    # Constant-1 bias feature in the extra last column.
                    idxs_x.append(row)
                    idxs_y.append(bias_col)
                    vals.append(1.0)
                batches_this_pass += 1
                yield (idxs_x, idxs_y, vals, labels_batch)
        if batches_this_pass == 0:
            raise ValueError(
                '{} holds fewer than batch_size={} data lines; no batch can be produced'.format(
                    file_name, batch_size))

## Torch Training

In [None]:
import torch
import time
import numpy as np
import math
import torch.nn as nn

In [None]:
# Endless generator of training batches.
# NOTE(review): the identical assignment is repeated two cells further down,
# so one of the two cells is redundant.
train_data_generator = data_generator(config.data_path_train, batch_size = config.batch_size, n_classes = config.n_classes)

In [None]:
# Step bookkeeping for the training loop.
n_epochs = config.n_epochs
n_train = config.n_train
n_test = config.n_test
n_check = 50  # a short test evaluation is run every n_check training steps
steps_per_epoch = n_train // batch_size  # integer division: only full batches are counted
n_steps = config.n_epochs * steps_per_epoch  # total number of training steps

In [None]:
# (Re)create the training batch generator so that iteration starts from the
# beginning of the training file.
# NOTE(review): duplicates an identical assignment made two cells above.
train_data_generator = data_generator(config.data_path_train, batch_size = config.batch_size, n_classes = config.n_classes)

In [None]:
# Layer weights as plain tensors with requires_grad = False: the gradients are
# computed by hand in fit() instead of by autograd.
# The "+ 1" row of each matrix is the bias row; it is matched by the constant-1
# column that is appended to the corresponding layer input.
W1 = torch.randn(config.feature_dim + 1, config.hidden_dim, requires_grad = False)

W2 = torch.randn(config.hidden_dim + 1, config.n_classes, requires_grad = False)

In [None]:
# Overwrite both matrices in place with truncated-normal samples whose standard
# deviation is 2 / sqrt(rows + columns) of the respective matrix.
W1 = nn.init.trunc_normal_(W1, std = 2.0/math.sqrt(config.feature_dim + 1 + config.hidden_dim))
W2 = nn.init.trunc_normal_(W2, std = 2.0/math.sqrt(config.hidden_dim + 1 + config.n_classes))

In [None]:
# Move both weight matrices to the selected device. to_device returns a list
# for tuple input, which is unpacked straight back into W1 and W2.
with torch.no_grad():
    (W1, W2) = to_device((W1, W2), device)

In [None]:
# Pre-allocate zero-filled gradient buffers with the shapes of the weights.
# fit() writes the manually computed gradients into .grad, and the optimizer
# reads them from there.
W1.grad = W1.new_zeros(W1.shape)
W2.grad = W2.new_zeros(W2.shape)

In [None]:
W1.grad

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')

In [None]:
# Sanity check: both flags are expected to be False, because the gradients are
# filled in manually rather than tracked by autograd.
print(W1.requires_grad, W2.requires_grad)

False False


In [None]:
# Adam consumes whatever is stored in W1.grad / W2.grad when step() is called.
adam_optim = torch.optim.Adam(params = (W1, W2), lr = config.lr)

# Helper modules shared by fit() and evaluate().
log_softmax = torch.nn.LogSoftmax(dim = 1)  # log-softmax across dimension 1 (the class columns)
add_unity_col = torch.nn.ConstantPad1d((0, 1), value = 1.0)  # appends one column of ones (bias input)
add_zero_col = torch.nn.ConstantPad1d((0, 1), value = 0.0)  # appends one column of zeros; NOTE(review): not referenced in the visible cells

In [None]:
# Fixed random subset of output classes used by the sparse second layer in fit().
# Sampled WITHOUT replacement: with the default (replace = True) duplicate class
# indices are likely, which would repeat a class inside the softmax and make the
# indexed write into W2.grad ambiguous for the repeated columns.
layer2_idxs = np.random.choice(np.arange(config.n_classes), int(config.sparsity * config.n_classes), replace = False)

In [None]:
def fit(train_dg):
    """Run one training step: manual forward pass, manual backward pass, Adam update.

    Pulls one batch from ``train_dg``, pushes it through the two-layer network
    (sparse input @ W1 -> bias column -> ReLU -> sampled columns of W2 ->
    log-softmax), writes hand-computed gradients into ``W1.grad`` / ``W2.grad``
    and calls ``adam_optim.step()``. Only the output classes listed in the
    module-level ``layer2_idxs`` take part in the second layer.

    Uses the module-level ``W1``, ``W2``, ``adam_optim``, ``layer2_idxs``,
    ``log_softmax``, ``add_unity_col``, ``batch_size``, ``config`` and ``device``.

    Args:
        train_dg: Iterator yielding ``(idxs_x, idxs_y, vals, Y, labels)`` tuples,
            as produced by ``data_generator``.

    Returns:
        float: Batch mean of the cross-entropy between the log-softmax over the
        sampled classes and the target restricted to those classes.
    """
    with torch.no_grad():
        # Zero the existing gradient tensors instead of dropping them: recent
        # PyTorch versions default to set_to_none = True, which would leave
        # W2.grad as None and break the indexed assignment below.
        adam_optim.zero_grad(set_to_none = False)
        idxs_x, idxs_y, vals, Y, _ = next(train_dg)

        # Feed Forward
        X = torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (batch_size, config.feature_dim + 1), device = device)
        A1 = add_unity_col(torch.sparse.mm(X, W1))  # constant-1 column feeds the bias row of W2
        Z1 = torch.nn.functional.relu(A1)

        # Sparse Feed Forward: gather the sampled columns once and reuse them.
        W2_active = W2[:, layer2_idxs]
        Y_active = Y[:, layer2_idxs]
        P = log_softmax(Z1 @ W2_active)
        L = torch.mean(torch.sum(-P * Y_active, dim = 1))

        # Sparse Back Propagation
        # NOTE(review): softmax - target is used as the gradient w.r.t. the
        # logits; it is not divided by the batch size and matches the loss
        # above exactly only when each row of Y_active sums to 1.
        grad_A2 = torch.exp(P) - Y_active
        W2.grad[:, layer2_idxs] = Z1.T @ grad_A2
        grad_A1 = grad_A2 @ W2_active.T
        grad_A1[A1 < 0] = 0  # ReLU derivative: no gradient through inactive units
        # Drop the bias column, which has no incoming weights in W1.
        W1.grad = torch.sparse.mm(X.t(), grad_A1[:, :-1])

        # Update Parameters
        adam_optim.step()

    # Previously nothing was returned, so callers collecting the loss got None.
    return L.item()


In [None]:
def evaluate(n_steps, test_dg):
    """Estimate top-1 accuracy over ``n_steps`` batches from ``test_dg``.

    Each batch goes through the full (dense) network: sparse input @ W1 ->
    bias column -> ReLU -> @ W2. A row counts as correct when its arg-max class
    is one of the row's true labels.

    Uses the module-level ``W1``, ``W2``, ``add_unity_col``, ``config`` and
    ``device``.

    Args:
        n_steps: Number of batches to draw from ``test_dg``.
        test_dg: Iterator yielding ``(idxs_x, idxs_y, vals, labels)`` tuples, as
            produced by ``data_generator_tst``; ``labels`` holds one list of
            label indices per row.

    Returns:
        Mean of the per-batch accuracies (``numpy.mean`` over the batches), where
        each batch accuracy is ``correct_rows / config.test_batch_size``.
    """
    accuracies = []
    for _ in range(n_steps):
        idxs_x, idxs_y, vals, Y = next(test_dg)

        with torch.no_grad():
            X = torch.sparse_coo_tensor([idxs_x, idxs_y], vals, size = (config.test_batch_size, config.feature_dim + 1), device = device)
            A1 = add_unity_col(torch.sparse.mm(X, W1))
            Z1 = torch.nn.functional.relu(A1)
            A2 = Z1 @ W2

            _, preds = torch.max(A2, dim = 1)
            # One device-to-host transfer for the whole batch instead of one per row.
            preds = preds.tolist()

        num_correct = sum(1 for pred, true_labels in zip(preds, Y) if pred in true_labels)
        accuracies.append(num_correct / config.test_batch_size)
    return np.mean(accuracies)

In [None]:
# Main training loop.
# total_time accumulates training time only: the clock is read before each
# evaluation and restarted after it, so evaluation time is excluded.
total_time = 0
begin_time = time.time()

# Number of full test batches for the end-of-epoch evaluation. Derived from the
# TEST batch size, because evaluate() consumes batches of
# config.test_batch_size rows (the training batch size was used before).
n_steps_val = n_test // config.test_batch_size
# Number of batches for the short periodic evaluation.
n_steps_quick = 20

with open(config.log_file, 'a') as out:
    losses = []
    for step in range(n_steps):
        if step % n_check == 0:
            # Periodic quick check on the first n_steps_quick test batches.
            total_time += time.time() - begin_time
            test_data_generator = data_generator_tst(config.data_path_test, config.test_batch_size, config.n_classes)

            accuracy = evaluate(n_steps_quick, test_data_generator)
            test_data_generator.close()  # release the open test file right away
            message = 'Step:{}  Total_Time:{}  Test_acc:{}'.format(step, total_time, accuracy)
            print(message, file = out, flush = True)  # flush so the log can be followed while training
            print(message)
            begin_time = time.time()

        if step % steps_per_epoch == (steps_per_epoch - 1):
            # Last step of an epoch: evaluate on all full test batches.
            total_time += time.time() - begin_time
            test_data_generator = data_generator_tst(config.data_path_test, config.test_batch_size, config.n_classes)

            accuracy = evaluate(n_steps_val, test_data_generator)
            test_data_generator.close()  # release the open test file right away
            message = 'OVERALL Step : {}  Total_Time: {}  Test_acc: {}'.format(step, total_time, accuracy)
            print(message, file = out, flush = True)
            print(message)
            begin_time = time.time()

        # The value returned by fit() is recorded as-is.
        loss = fit(train_data_generator)
        losses.append(loss)

Step:0  Total_Time:0.0004982948303222656  Test_acc:0.0
Step:50  Total_Time:2.305743932723999  Test_acc:0.0
Step:100  Total_Time:4.603775501251221  Test_acc:0.0
Step:150  Total_Time:6.941100120544434  Test_acc:0.0
Step:200  Total_Time:9.254879713058472  Test_acc:0.0
Step:250  Total_Time:11.540871620178223  Test_acc:0.0
Step:300  Total_Time:13.816704034805298  Test_acc:0.0
Step:350  Total_Time:16.08738660812378  Test_acc:0.0
Step:400  Total_Time:18.374639987945557  Test_acc:0.0
Step:450  Total_Time:20.682270288467407  Test_acc:0.0
Step:500  Total_Time:22.96458101272583  Test_acc:0.0
Step:550  Total_Time:25.244523525238037  Test_acc:0.0
Step:600  Total_Time:27.523837566375732  Test_acc:0.0
Step:650  Total_Time:29.793498516082764  Test_acc:0.0
Step:700  Total_Time:32.048781394958496  Test_acc:0.0
Step:750  Total_Time:34.31655263900757  Test_acc:0.0
Step:800  Total_Time:36.60205912590027  Test_acc:0.0
Step:850  Total_Time:38.86902379989624  Test_acc:0.0
Step:900  Total_Time:41.1608390808105