# This notebook exists purely to test that aggregation over a 2-dimensional tensor is applied correctly.

In [78]:
import torch
import torch.nn as nn

# Show tensor values in plain decimal notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)

In [112]:
# Create dummy predictions as a 2-D tensor of shape (N_SAMPLES, N_CLASSES):
#   - dim 0 indexes the samples (150 rows)
#   - dim 1 holds the 50-class prediction for each sample
# Aggregation sums across samples, i.e. class index 1 of sample 1
# + class index 1 of sample 2, and so on.
#
# x      : the raw outputs
# relu_x : x with ReLU applied (required before taking logs)
# NOTE(review): ReLU leaves exact zeros and torch.log(0) is -inf; confirm the
# log-based aggregation handles that.
N_SAMPLES = 150
N_CLASSES = 50
SEED = 0

# A dedicated, seeded generator makes the values identical on every
# Restart & Run All without touching the global RNG state.
x_generator = torch.Generator().manual_seed(SEED)

x = torch.randn(size=(N_SAMPLES, N_CLASSES), generator=x_generator)
relu_x = nn.ReLU()(x)

In [114]:
# Inspect the raw values; entries may be negative.
x

tensor([[ 0.1106, -0.6281, -1.2456,  ...,  0.9155,  1.1275,  0.2143],
        [ 1.6122,  0.5506,  0.0572,  ..., -1.2279,  0.8519,  0.0300],
        [ 0.7115,  0.1835, -0.0796,  ...,  0.0617,  1.2123, -0.7905],
        ...,
        [-1.1068, -0.3756,  1.0011,  ..., -1.1951, -0.8403,  1.6624],
        [ 1.3165,  1.2681, -0.0823,  ...,  0.7479, -0.0807,  1.1807],
        [ 0.9953, -1.8364, -0.6370,  ...,  0.0747, -0.4550,  1.0746]])

In [115]:
# Inspect the ReLU output: every negative entry of x has been replaced by 0.
relu_x

tensor([[0.1106, 0.0000, 0.0000,  ..., 0.9155, 1.1275, 0.2143],
        [1.6122, 0.5506, 0.0572,  ..., 0.0000, 0.8519, 0.0300],
        [0.7115, 0.1835, 0.0000,  ..., 0.0617, 1.2123, 0.0000],
        ...,
        [0.0000, 0.0000, 1.0011,  ..., 0.0000, 0.0000, 1.6624],
        [1.3165, 1.2681, 0.0000,  ..., 0.7479, 0.0000, 1.1807],
        [0.9953, 0.0000, 0.0000,  ..., 0.0747, 0.0000, 1.0746]])

In [None]:
# Planning notes only: a bare string literal that is not assigned to any name.
'''
numerous sequential aggregation types.

- normalisation and sum of all outputs
- normalisation and sum of top 5 outputs
- normalisation and sum of top 1 outputs
- logarithm and sum of all outputs
- logarithm and sum of top 5 outputs
- logarithm and sum of top 1 outputs
- ranking of full outputs
- ranking of top 5 outputs
'''

## sum and count of all/top 5/top 1 outputs

In [164]:
# Notes for this section, kept as a bare string literal; being the cell's only
# expression, it is echoed back verbatim as the cell output.
'''
- sum of all outputs
- sum of top 5 outputs
- sum of top 1 outputs
- count of each index in top 5 outputs (if count = same, order by higher sum)
- count of each index in top 1 outputs (if count = same, order by higher sum)
'''

'\n- sum of all outputs\n- sum of top 5 outputs\n- sum of top 1 outputs\n- count of each index in top 5 outputs (if count = same, order by higher sum)\n- count of each index in top 1 outputs (if count = same, order by higher sum)\n'

In [119]:
# Sum of all outputs: reduce over dim 0 (the samples) to get one total per class index.
x.sum(dim=0)

tensor([    21.7677,     21.8678,    -22.9294,     23.7354,      5.0388,
           -27.7955,     21.3355,     -8.5228,     -8.8538,      6.4981,
            -3.6364,      9.3998,    -14.4198,     -8.0694,     -1.4318,
           -41.2709,      4.7210,     -9.1966,      8.8066,     -4.5832,
            -0.4920,     10.3356,     -6.2136,     30.2712,    -14.4335,
           -12.3083,     -7.4460,    -34.5201,      8.7070,     -8.5028,
            12.3348,      0.1847,     -2.5948,      8.5503,      3.9403,
             2.1135,      7.3480,     24.7046,      5.9763,      1.0319,
            -1.9839,     -8.4600,      0.0785,     -7.9150,     -6.9130,
             0.0909,    -19.8926,     -2.0044,     -0.0351,     -3.6663])

In [163]:
def tally_topk(topk_result, n_classes):
    """Count and sum the top-k entries per class index.

    Args:
        topk_result: the (values, indices) result of ``torch.topk`` taken
            along the last dimension of a 2-D tensor.
        n_classes: number of class indices; every index in
            ``range(n_classes)`` gets an entry, even if it was never selected.

    Returns:
        dict of ``{index: (count, total)}`` where ``count`` is a Python int
        (how many rows had ``index`` in their top k) and ``total`` is a 0-dim
        tensor holding the sum of the selected values. Indices that never
        appear get ``(0, tensor(0.))`` so every total has the same type.
    """
    flat_indices = topk_result.indices.reshape(-1)
    flat_values = topk_result.values.reshape(-1)

    # How many times each class index was selected.
    counts = torch.bincount(flat_indices, minlength=n_classes)
    # Accumulate each selected value into the slot of its class index.
    totals = flat_values.new_zeros(n_classes).index_add_(0, flat_indices, flat_values)

    return {
        index: (count, total)
        for index, (count, total) in enumerate(zip(counts.tolist(), totals.unbind()))
    }


# topk works along the last dimension, i.e. across the classes of each sample.
top_1 = torch.topk(x, k=1)
top_5 = torch.topk(x, k=5)

top_1_index_sum = tally_topk(top_1, n_classes=x.shape[1])
top_5_index_sum = tally_topk(top_5, n_classes=x.shape[1])

# format:
# {index: (count within top X, sum of values within top X)}

top_1_index_sum

{0: (4, tensor(9.4854)),
 1: (2, tensor(5.2710)),
 2: (4, tensor(7.9448)),
 3: (3, tensor(6.8340)),
 4: (2, tensor(4.7406)),
 5: (1, tensor(2.1647)),
 6: (2, tensor(3.7034)),
 7: (3, tensor(5.8296)),
 8: (0, 0),
 9: (1, tensor(3.0801)),
 10: (4, tensor(8.8410)),
 11: (3, tensor(6.9272)),
 12: (2, tensor(5.4251)),
 13: (4, tensor(10.0090)),
 14: (7, tensor(13.5303)),
 15: (2, tensor(4.2658)),
 16: (3, tensor(6.2397)),
 17: (2, tensor(4.9392)),
 18: (2, tensor(4.8260)),
 19: (3, tensor(7.3131)),
 20: (3, tensor(6.5062)),
 21: (5, tensor(11.0374)),
 22: (4, tensor(10.9704)),
 23: (2, tensor(4.5935)),
 24: (2, tensor(5.8893)),
 25: (5, tensor(11.7143)),
 26: (4, tensor(7.7399)),
 27: (3, tensor(6.9936)),
 28: (4, tensor(9.4896)),
 29: (2, tensor(4.6367)),
 30: (4, tensor(9.9389)),
 31: (4, tensor(10.2029)),
 32: (4, tensor(7.9811)),
 33: (4, tensor(9.9205)),
 34: (3, tensor(5.7725)),
 35: (2, tensor(3.8735)),
 36: (1, tensor(1.9098)),
 37: (5, tensor(11.7681)),
 38: (4, tensor(9.0101)),
 3