In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Shapes of the two inputs built in the cells below:
#   attn_idc:     [batch_size, num_classes]
#   attn_weights: [batch_size, max_seg_length, n_attns]

In [3]:
# Small constant added to the selected attention weights before the log in the
# loss cell, so that a weight of exactly 0 does not evaluate log(0).
eps = 1e-6

In [4]:
# Toy dimensions used by every cell below.
batch_size = 2
num_classes = 6
max_seg_length = 7
max_seg_label_length = 6
n_attns = 3

In [5]:
# Per-slot class labels, shape [batch_size, max_seg_label_length].
# A label of 0 is masked out before the loss, i.e. it acts as padding.
seg_labels = torch.tensor(
    [[6, 6, 4, 1, 1, 4],
     [2, 2, 1, 2, 0, 0]],
    dtype=torch.long,
)
# Per-slot index used later to pick one attention weight, same shape as
# seg_labels. Looks 1-based with 0 reserved for padding -- TODO confirm.
seg_times = torch.tensor(
    [[3, 7, 2, 1, 2, 3],
     [2, 6, 3, 4, 0, 0]],
    dtype=torch.long,
)

In [6]:
# Raw (not yet normalised) attention scores drawn from a standard normal,
# shape [batch_size, max_seg_length, n_attns].
# NOTE(review): no random seed is set, so the values printed below will differ
# on every fresh run.
attn_weights = torch.randn(batch_size, max_seg_length, n_attns)
attn_weights

tensor([[[-1.0851, -0.3736, -0.1189],
         [ 0.7270,  1.0704,  0.3717],
         [-0.3377,  1.1548, -0.1457],
         [-0.2145,  1.4151, -1.9948],
         [ 0.1310, -0.1755,  0.1703],
         [-1.0556, -0.5131, -0.5723],
         [ 1.0823,  0.0350, -1.4357]],

        [[-1.0630,  0.5343, -1.3635],
         [ 0.8366,  3.1758,  0.6835],
         [ 1.2344,  0.9112, -1.3197],
         [-0.6442,  2.0242, -1.2948],
         [ 0.0755, -0.8215, -1.4158],
         [ 0.2398,  0.0094, -0.4789],
         [ 1.0044, -0.1572,  0.8612]]])

In [7]:
# Normalise over dim=1 (the segment-position axis) so that, for every batch
# element and every attention head, the weights across positions sum to 1.
# attn_weights: [batch_size, max_seg_length, n_attns]
attn_weights = attn_weights.softmax(dim=1)
attn_weights

tensor([[[0.0404, 0.0515, 0.1667],
         [0.2473, 0.2182, 0.2723],
         [0.0853, 0.2374, 0.1623],
         [0.0964, 0.3080, 0.0255],
         [0.1363, 0.0628, 0.2226],
         [0.0416, 0.0448, 0.1059],
         [0.3528, 0.0775, 0.0447]],

        [[0.0295, 0.0449, 0.0426],
         [0.1974, 0.6299, 0.3298],
         [0.2938, 0.0654, 0.0445],
         [0.0449, 0.1992, 0.0456],
         [0.0922, 0.0116, 0.0404],
         [0.1087, 0.0266, 0.1031],
         [0.2334, 0.0225, 0.3940]]])

In [8]:
# Random attention-head index in [0, n_attns) for every (batch element, class).
# attn_idc: [batch_size, num_classes]
attn_idc = torch.randint(low=0, high=n_attns, size=(batch_size, num_classes))
attn_idc

tensor([[2, 2, 0, 2, 2, 0],
        [2, 2, 0, 1, 1, 2]])

In [9]:
# Integer column of zeros, shape [batch_size, 1]; it is prepended to attn_idc
# in the following cell.
zeros = torch.zeros(batch_size, 1, dtype=torch.long)
zeros

tensor([[0],
        [0]])

In [10]:
# Prepend the zero column so that label 0 gets its own slot (pointing at head 0)
# and labels 1..num_classes can be used directly as column indices.
# attn_idc: [batch_size, num_classes + 1]
attn_idc = torch.cat([zeros, attn_idc], dim=1)
attn_idc

tensor([[0, 2, 2, 0, 2, 2, 0],
        [0, 2, 2, 0, 1, 1, 2]])

In [11]:
# Reminder of the inputs defined in cell 5 (commented out; for reference only):
#   seg_labels = torch.LongTensor([[6, 6, 4, 1, 1, 4], [2, 2, 1, 2, 0, 0]])
#   seg_times  = torch.LongTensor([[3, 7, 2, 1, 2, 3], [2, 6, 3, 4, 0, 0]])

In [12]:
# For every label slot, look up the attention head assigned to that label.
# selected_attn_idc: [batch_size, max_seg_label_length]
selected_attn_idc = attn_idc.gather(1, seg_labels)
selected_attn_idc

tensor([[0, 0, 2, 2, 2, 2],
        [2, 2, 2, 2, 0, 0]])

In [13]:
# True where a real label is present; label 0 is treated as padding.
# seg_labels holds integers, so a direct comparison gives the same boolean mask
# as casting to float and testing >= 0.5, without the extra float copy.
mask = seg_labels > 0
mask

tensor([[ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True, False, False]])

In [14]:
# Swap the last two axes so that the head axis comes before the position axis.
# attn_weights: [batch_size, n_attns, max_seg_length]
attn_weights = attn_weights.permute(0, 2, 1)
attn_weights

tensor([[[0.0404, 0.2473, 0.0853, 0.0964, 0.1363, 0.0416, 0.3528],
         [0.0515, 0.2182, 0.2374, 0.3080, 0.0628, 0.0448, 0.0775],
         [0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447]],

        [[0.0295, 0.1974, 0.2938, 0.0449, 0.0922, 0.1087, 0.2334],
         [0.0449, 0.6299, 0.0654, 0.1992, 0.0116, 0.0266, 0.0225],
         [0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940]]])

In [15]:
# Batched index_select
def batched_index_select(t, dim, inds):
    """Per-batch index_select: pick rows of ``t`` using a different index list per batch element.

    Args:
        t: tensor of shape ``[b, n, *rest]`` (``rest`` may be empty or have any
            number of trailing dimensions).
        dim: dimension passed to ``gather``. The index tensor is laid out for
            ``dim=1``; other values are forwarded to ``gather`` unchanged and do
            not perform a per-batch row selection.
        inds: integer tensor of shape ``[b, e]`` with values in ``[0, n)``.

    Returns:
        For ``dim=1``, a tensor of shape ``[b, e, *rest]`` where
        ``out[i, j] == t[i, inds[i, j]]``.
    """
    # gather needs an index with the same rank as `t`, so give `inds` one
    # singleton axis per trailing dimension of `t` and broadcast it over them.
    trailing = tuple(t.shape[2:])
    lead = (inds.size(0), inds.size(1))
    dummy = inds.reshape(*lead, *([1] * len(trailing))).expand(*lead, *trailing)
    out = t.gather(dim, dummy)  # b x e x *rest
    return out

In [16]:
# For each label slot, pull the full row of position weights of its assigned head.
# selected_attn_weights: [batch_size, max_seg_label_length, max_seg_length]
selected_attn_weights = batched_index_select(t=attn_weights, dim=1, inds=selected_attn_idc)
selected_attn_weights

tensor([[[0.0404, 0.2473, 0.0853, 0.0964, 0.1363, 0.0416, 0.3528],
         [0.0404, 0.2473, 0.0853, 0.0964, 0.1363, 0.0416, 0.3528],
         [0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447]],

        [[0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0295, 0.1974, 0.2938, 0.0449, 0.0922, 0.1087, 0.2334],
         [0.0295, 0.1974, 0.2938, 0.0449, 0.0922, 0.1087, 0.2334]]])

In [17]:
# Reminder of the inputs defined in cell 5 (commented out; for reference only):
#   seg_labels = torch.LongTensor([[6, 6, 4, 1, 1, 4], [2, 2, 1, 2, 0, 0]])
#   seg_times  = torch.LongTensor([[3, 7, 2, 1, 2, 3], [2, 6, 3, 4, 0, 0]])

In [18]:
# Float column of zeros, one entry per label slot:
# [batch_size, max_seg_label_length, 1].
# Note: this rebinds `zeros`, which earlier held the integer [batch_size, 1] column.
zeros = torch.zeros((batch_size, max_seg_label_length, 1))

In [19]:
# Prepend the zero column along the position axis, so index 0 reads a weight
# of 0 and index k reads the weight at position k - 1.
# selected_attn_weights: [batch_size, max_seg_label_length, max_seg_length + 1]
selected_attn_weights = torch.cat([zeros, selected_attn_weights], dim=-1)
selected_attn_weights

tensor([[[0.0000, 0.0404, 0.2473, 0.0853, 0.0964, 0.1363, 0.0416, 0.3528],
         [0.0000, 0.0404, 0.2473, 0.0853, 0.0964, 0.1363, 0.0416, 0.3528],
         [0.0000, 0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.0000, 0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.0000, 0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447],
         [0.0000, 0.1667, 0.2723, 0.1623, 0.0255, 0.2226, 0.1059, 0.0447]],

        [[0.0000, 0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0000, 0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0000, 0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0000, 0.0426, 0.3298, 0.0445, 0.0456, 0.0404, 0.1031, 0.3940],
         [0.0000, 0.0295, 0.1974, 0.2938, 0.0449, 0.0922, 0.1087, 0.2334],
         [0.0000, 0.0295, 0.1974, 0.2938, 0.0449, 0.0922, 0.1087, 0.2334]]])

In [20]:
# Add a trailing singleton axis so seg_times can be used as a gather index on dim 2.
# seg_times: [batch_size, max_seg_label_length, 1]
seg_times = seg_times.unsqueeze(-1)
seg_times

tensor([[[3],
         [7],
         [2],
         [1],
         [2],
         [3]],

        [[2],
         [6],
         [3],
         [4],
         [0],
         [0]]])

In [21]:
# Keep, for each label slot, only the single weight found at its seg_times index.
# selected_attn_weights: [batch_size, max_seg_label_length, 1]
selected_attn_weights = selected_attn_weights.gather(2, seg_times)
selected_attn_weights

tensor([[[0.0853],
         [0.3528],
         [0.2723],
         [0.1667],
         [0.2723],
         [0.1623]],

        [[0.3298],
         [0.1031],
         [0.0445],
         [0.0456],
         [0.0000],
         [0.0000]]])

In [22]:
# Drop the trailing singleton axis -> [batch_size, max_seg_label_length].
selected_attn_weights = selected_attn_weights.squeeze(-1)
selected_attn_weights

tensor([[0.0853, 0.3528, 0.2723, 0.1667, 0.2723, 0.1623],
        [0.3298, 0.1031, 0.0445, 0.0456, 0.0000, 0.0000]])

In [23]:
# Negative log of every selected weight (eps keeps the log finite at 0),
# restricted to the non-padding slots; the result is a flat 1-D tensor.
loss = -torch.log(selected_attn_weights + eps)[mask]
loss

tensor([2.4619, 1.0420, 1.3010, 1.7916, 1.3010, 1.8184, 1.1093, 2.2716, 3.1124,
        3.0875])