# This notebook is purely to test the aggregation of a 2 dimension tensor is applied correctly.

In [9]:
import torch
import torch.nn as nn

torch.set_printoptions(sci_mode=False)

In [46]:
# create values in 2 dimensions 150x50 (150 rows of 50 class predictions)
# first dimension is number of samples
# second dimension is the 50-class prediction for the sample.
# aggregation is to be summed across each sample. e.g. index 1 in sample 1 + index 1 in sample 2 etc.
# x is normal outputs
# relu_x is x with relu applied (required for logs)

torch.manual_seed(1368)

x = torch.randn(size=(150, 50))
relu_x = nn.ReLU()(x)
norm_x = nn.functional.normalize(x, dim=0)
log_x = torch.log(relu_x + 1) # relu_x + 1 has absolute minimum of 1, which = 0 for any log.

In [47]:
x

tensor([[-0.6936,  0.2504, -1.7355,  ...,  0.3461, -1.5154,  0.3922],
        [-1.5361,  2.0293, -0.1356,  ..., -1.2285,  1.4470, -2.5613],
        [ 0.9282,  0.8371,  0.7941,  ..., -0.8950, -0.7631, -1.7213],
        ...,
        [-1.4624, -0.8111, -0.3618,  ..., -0.5492,  0.2543, -1.1599],
        [-0.5550, -0.1048, -1.2397,  ...,  0.8518,  0.4345,  1.2991],
        [ 0.1281,  2.5001,  0.5547,  ...,  0.6481,  0.5052, -0.1094]])

In [48]:
relu_x

tensor([[0.0000, 0.2504, 0.0000,  ..., 0.3461, 0.0000, 0.3922],
        [0.0000, 2.0293, 0.0000,  ..., 0.0000, 1.4470, 0.0000],
        [0.9282, 0.8371, 0.7941,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2543, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.8518, 0.4345, 1.2991],
        [0.1281, 2.5001, 0.5547,  ..., 0.6481, 0.5052, 0.0000]])

In [49]:
norm_x

tensor([[-0.0599,  0.0201, -0.1452,  ...,  0.0277, -0.1317,  0.0326],
        [-0.1327,  0.1628, -0.0113,  ..., -0.0984,  0.1258, -0.2127],
        [ 0.0802,  0.0672,  0.0664,  ..., -0.0717, -0.0663, -0.1430],
        ...,
        [-0.1263, -0.0651, -0.0303,  ..., -0.0440,  0.0221, -0.0963],
        [-0.0479, -0.0084, -0.1037,  ...,  0.0682,  0.0378,  0.1079],
        [ 0.0111,  0.2006,  0.0464,  ...,  0.0519,  0.0439, -0.0091]])

In [50]:
log_x

tensor([[0.0000, 0.2234, 0.0000,  ..., 0.2972, 0.0000, 0.3309],
        [0.0000, 1.1083, 0.0000,  ..., 0.0000, 0.8949, 0.0000],
        [0.6566, 0.6082, 0.5845,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2266, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.6162, 0.3608, 0.8325],
        [0.1205, 1.2528, 0.4413,  ..., 0.4996, 0.4089, 0.0000]])

In [13]:
'''
numerous sequential aggregation types.

- normalisation and sum of all outputs
- normalisation and sum of top 5 outputs
- normalisation and sum of top 1 outputs
- logarithm and sum of all outputs
- logarithm and sum of top 5 outputs
- logarithm and sum of top 1 outputs
- ranking of full outputs
- ranking of top 5 outputs
'''

'\nnumerous sequential aggregation types.\n\n- normalisation and sum of all outputs\n- normalisation and sum of top 5 outputs\n- normalisation and sum of top 1 outputs\n- logarithm and sum of all outputs\n- logarithm and sum of top 5 outputs\n- logarithm and sum of top 1 outputs\n- ranking of full outputs\n- ranking of top 5 outputs\n'

## sum and count of all/top 5/top 1 outputs

In [14]:
'''
- sum of all outputs
- sum of top 5 outputs
- sum of top 1 outputs
- count of each index in top 5 outputs (if count = same, order by higher sum)
- count of each index in top 1 outputs (if count = same, order by higher sum)
'''

'\n- sum of all outputs\n- sum of top 5 outputs\n- sum of top 1 outputs\n- count of each index in top 5 outputs (if count = same, order by higher sum)\n- count of each index in top 1 outputs (if count = same, order by higher sum)\n'

In [15]:
torch.sum(x, dim=0)

tensor([ 24.4113,  -2.4563, -16.7200, -27.3724,   3.8512,  17.0184,  -1.2498,
          9.7502,  15.1852,  -5.0573, -19.5892,  17.6351,  -9.8900, -11.5701,
         -8.4774,  12.9137,  12.3836, -17.2963, -15.0070,  11.4427, -19.4745,
         12.2673,  -1.2924, -13.6465,  10.1143, -15.6040,   0.0392,  14.4163,
         30.2627,  10.7200,   8.6177,   5.8099,  13.8965,   9.7357,   4.1462,
         -4.0059,   3.8048,  -1.7831,   9.4026,  -9.5103,  10.1094,  29.6169,
          2.2838,  -1.7679, -15.6548,  -0.2831, -26.4770, -12.0819,  -4.8699,
        -13.2558])

In [16]:
top_1_index_sum = {}
top_5_index_sum = {}
for i in range(x.shape[1]):
    top_1_index_sum[i] = (0, 0)
    top_5_index_sum[i] = (0, 0)

top_1 = torch.topk(x, k=1)
top_5 = torch.topk(x, k=5)

for i in range(x.shape[0]):
    top_1_indices = top_1.indices[i]
    top_1_values = top_1.values[i]

    for index, c in enumerate(top_1_indices):
        top_1_index_sum[c.item()] = (top_1_index_sum[c.item()][0] + 1, top_1_index_sum[c.item()][1] + top_1_values[index])

    top_5_indices = top_5.indices[i]
    top_5_values = top_5.values[i]

    for index, c in enumerate(top_5_indices):
        top_5_index_sum[c.item()] = (top_5_index_sum[c.item()][0] + 1, top_5_index_sum[c.item()][1] + top_5_values[index])

# format:
# {index: (count within top X, sum of values within top X)}

top_1_index_sum

{0: (3, tensor(7.4754)),
 1: (4, tensor(9.3746)),
 2: (2, tensor(4.7576)),
 3: (2, tensor(6.5418)),
 4: (4, tensor(9.4000)),
 5: (5, tensor(10.5135)),
 6: (4, tensor(9.1883)),
 7: (3, tensor(5.9449)),
 8: (4, tensor(8.7903)),
 9: (1, tensor(1.9332)),
 10: (2, tensor(4.2875)),
 11: (6, tensor(14.2233)),
 12: (1, tensor(2.5518)),
 13: (4, tensor(10.2946)),
 14: (6, tensor(12.8990)),
 15: (5, tensor(11.4788)),
 16: (4, tensor(7.6930)),
 17: (4, tensor(7.2684)),
 18: (2, tensor(4.9027)),
 19: (3, tensor(7.2423)),
 20: (0, 0),
 21: (2, tensor(4.3348)),
 22: (6, tensor(16.1078)),
 23: (0, 0),
 24: (4, tensor(8.9296)),
 25: (2, tensor(4.8293)),
 26: (1, tensor(2.8202)),
 27: (2, tensor(4.6864)),
 28: (1, tensor(1.8333)),
 29: (1, tensor(2.0958)),
 30: (4, tensor(10.2733)),
 31: (3, tensor(7.1980)),
 32: (5, tensor(12.5287)),
 33: (5, tensor(10.9097)),
 34: (4, tensor(9.2821)),
 35: (1, tensor(1.4322)),
 36: (2, tensor(4.6496)),
 37: (2, tensor(3.9213)),
 38: (0, 0),
 39: (6, tensor(12.2138)),

## logarithm count/sum for all/top 5/top 1

In [51]:
torch.sum(log_x, dim=0)

tensor([45.2204, 40.2728, 33.3566, 33.5029, 44.7026, 46.9850, 40.3652, 42.7795,
        43.9073, 37.1153, 33.5561, 45.5006, 34.3956, 36.5887, 42.5320, 43.9331,
        40.7051, 32.1633, 37.9877, 43.7769, 34.7570, 47.4297, 41.2366, 35.8988,
        47.0249, 37.0641, 40.8261, 41.6866, 51.3250, 39.9896, 43.4366, 42.3180,
        45.3651, 46.2926, 41.2046, 39.5955, 42.2172, 42.3869, 40.6792, 40.3541,
        42.4828, 52.0804, 40.7880, 41.2062, 36.9697, 40.0917, 31.8729, 37.7731,
        37.9598, 33.7015])

In [52]:
top_1_log_index_sum = {}
top_5_log_index_sum = {}
for i in range(x.shape[1]):
    top_1_log_index_sum[i] = (0, 0)
    top_5_log_index_sum[i] = (0, 0)

top_1 = torch.topk(log_x, k=1)
top_5 = torch.topk(log_x, k=5)

for i in range(x.shape[0]):
    top_1_indices = top_1.indices[i]
    top_1_values = top_1.values[i]

    for index, c in enumerate(top_1_indices):
        top_1_log_index_sum[c.item()] = (top_1_log_index_sum[c.item()][0] + 1, top_1_log_index_sum[c.item()][1] + top_1_values[index])

    top_5_indices = top_5.indices[i]
    top_5_values = top_5.values[i]

    for index, c in enumerate(top_5_indices):
        top_5_log_index_sum[c.item()] = (top_5_log_index_sum[c.item()][0] + 1, top_5_log_index_sum[c.item()][1] + top_5_values[index])

# format:
# {index: (count within top X, sum of values within top X)}

top_1_log_index_sum

{0: (3, tensor(3.7392)),
 1: (4, tensor(4.8187)),
 2: (2, tensor(2.4226)),
 3: (2, tensor(2.8393)),
 4: (4, tensor(4.8184)),
 5: (5, tensor(5.6425)),
 6: (4, tensor(4.7578)),
 7: (3, tensor(3.2752)),
 8: (4, tensor(4.6343)),
 9: (1, tensor(1.0761)),
 10: (2, tensor(2.2546)),
 11: (6, tensor(7.2818)),
 12: (1, tensor(1.2675)),
 13: (4, tensor(5.0583)),
 14: (6, tensor(6.8529)),
 15: (5, tensor(5.9534)),
 16: (4, tensor(4.2849)),
 17: (4, tensor(4.1417)),
 18: (2, tensor(2.4744)),
 19: (3, tensor(3.6836)),
 20: (0, 0),
 21: (2, tensor(2.2989)),
 22: (6, tensor(7.7820)),
 23: (0, 0),
 24: (4, tensor(4.6314)),
 25: (2, tensor(2.4529)),
 26: (1, tensor(1.3403)),
 27: (2, tensor(2.4036)),
 28: (1, tensor(1.0414)),
 29: (1, tensor(1.1301)),
 30: (4, tensor(5.0856)),
 31: (3, tensor(3.6584)),
 32: (5, tensor(6.2345)),
 33: (5, tensor(5.7630)),
 34: (4, tensor(4.7845)),
 35: (1, tensor(0.8888)),
 36: (2, tensor(2.3996)),
 37: (2, tensor(2.1703)),
 38: (0, 0),
 39: (6, tensor(6.6264)),
 40: (2, 