In [44]:
import torch
import torch.optim as optim
import torch.nn as nn
import model.dataset as ds
import model.models

In [2]:
# Dataset read from the spectrogram directory; batches drawn from it are
# indexed by the keys 'aggregate' and 'ground_truths' further down.
d = ds.SignalDataset(root_dir='data/a1_spectrograms/')

In [3]:
# Shuffled mini-batches of 32 items.
# NOTE(review): the same 32 is repeated as `bs` when the model is built; keep the two in sync.
dataloader = torch.utils.data.DataLoader(d, batch_size=32, shuffle=True)

In [4]:
# Take a single batch from the loader and stop; `aggregate` and
# `ground_truths` keep that batch's values for the experiments in later cells.
for i, info in enumerate(dataloader):
    aggregate = info['aggregate']
    # Materialise the entries of info['ground_truths'] as a plain Python list.
    ground_truths = [gt for gt in info['ground_truths']]
    break


In [5]:
# Batch size handed to the model.
# NOTE(review): duplicates the DataLoader's batch_size=32.
bs = 32
# 1025 and 173 equal the last and middle dimensions of the aggregate batch
# printed in this notebook (torch.Size([32, 173, 1025])); presumably the
# feature dimension and the sequence length — confirm against model.models.Baseline.
net = model.models.Baseline(1025, bs, seq_len=173, num_sources=2)

In [6]:

# Sanity check for the batch-first -> time-first conversion: slice the
# [bs, seq_len, dim] batch per time step, concatenate the slices, and view the
# result as [seq_len, bs, dim]. The two printed rows should be identical,
# because x[t, b, :] must equal aggregate[b, t, :].
print(aggregate.size())
print(aggregate[1, 1, :])
seq_len = aggregate.size()[1]
aggs = [aggregate[:, t, :] for t in range(seq_len)]
# Use the measured seq_len instead of a hard-coded 173 so the check keeps
# working when the spectrogram length changes.
x = torch.cat(aggs).view(seq_len, bs, -1)
print(x[1, 1, :])
print(x.size())

torch.Size([32, 173, 1025])
tensor([1.4197e-19, 9.8981e-19, 1.4690e-17,  ..., 8.6916e-19, 1.9786e-18,
        1.6256e-18])
tensor([1.4197e-19, 9.8981e-19, 1.4690e-17,  ..., 8.6916e-19, 1.9786e-18,
        1.6256e-18])
torch.Size([173, 32, 1025])


In [34]:
def reshape(x, seq_len, bs):
    """Convert a batch-first tensor into time-first layout.

    Args:
        x: tensor whose first two dimensions are (batch, time), e.g.
            [bs, T, input_dim], with T >= seq_len.
        seq_len: number of leading time steps to keep.
        bs: batch size used for the second dimension of the result.

    Returns:
        A new contiguous tensor of shape [seq_len, bs, -1] such that
        out[t, b] holds x[b, t] (any trailing dimensions are flattened).

    Raises:
        An error from torch when seq_len exceeds the time dimension of x or
        when the requested view does not match the number of elements.
    """
    # narrow() keeps the first seq_len time steps and rejects an out-of-range
    # seq_len, mirroring the former per-step indexing. A single
    # transpose + copy replaces the Python loop of seq_len slices and torch.cat.
    time_major = x.narrow(1, 0, seq_len).transpose(0, 1)
    # clone(...) always yields fresh contiguous storage, as torch.cat did,
    # so the result never aliases the input.
    return time_major.clone(memory_format=torch.contiguous_format).view(seq_len, bs, -1)


def calc_dists(preds, gts):
    """
    Per-sample L2 distance between each predicted source and the ground
    truth stored at the same list position (no matching between sources).

    Args:
        preds: n_sources * [seq_len, bs, input_dim]
        gts: n_sources * [seq_len, bs, input_dim]

    Returns:
        dists: [bs, n_sources]; dists[b, s] is the 2-norm of the flattened
            difference preds[s][:, b, :] - gts[s][:, b, :]. The result keeps
            the dtype and device of the inputs.
    """
    n_sources = len(preds)
    assert n_sources == len(gts)

    # TODO: greedy assignment
    per_source = []
    for pred, gt in zip(preds, gts):
        # [seq_len, bs, input_dim] -> [bs, seq_len * input_dim], so one norm
        # call over dim=1 covers the whole batch instead of a Python loop.
        diff = (pred - gt).transpose(0, 1).flatten(start_dim=1)
        per_source.append(torch.norm(diff, p=2, dim=1))

    # Stacking (rather than writing into a torch.zeros buffer) avoids forcing
    # the default dtype/device onto the distances.
    return torch.stack(per_source, dim=1)


class MinLoss(nn.Module):
    """Custom loss function #1

    Sums, over every sample and every source, the L2 distance between a
    predicted source and the ground truth at the same list position.

    NOTE(review): matching each output with its *closest* ground truth is not
    implemented; calc_dists still carries a "greedy assignment" TODO.
    """
    def __init__(self):
        super().__init__()

    def forward(self, predictions, ground_truths):
        """
        Args:
            predictions: num_sources * [seq_len, bs, input_dim]
            ground_truths: num_sources * [bs, seq_len, input_dim]
        Returns:
            loss: scalar tensor, the sum of all entries of the
                [bs, num_sources] distance matrix
        """
        first = predictions[0]
        seq_len, bs = first.size(0), first.size(1)

        # ground truths arrive batch-first; bring them into the time-first
        # layout (seq_len, bs, input_dim) used by the predictions
        time_major_gts = [reshape(gt, seq_len, bs) for gt in ground_truths]

        # distance matrix is (bs * num_sources); collapse it to one number
        return torch.sum(calc_dists(predictions, time_major_gts))


In [22]:
import numpy as np
# Load a saved array from disk; concat() unpacks its shape as
# (num_features, nrows, ncols), so a 3-D array is expected here.
agg = np.load('./new_spect.npy')

def concat(m):
    """Stack the 2-D slices of a 3-D array on top of each other.

    Args:
        m: array of shape (num_features, nrows, ncols).

    Returns:
        A new float64 array of shape (num_features * nrows, ncols) in which
        rows [i * nrows, (i + 1) * nrows) hold m[i].
    """
    num_features, nrows, ncols = m.shape
    # Read from the argument `m` (the loop used to read the notebook-global
    # `agg`, so any other input was silently ignored). A C-order reshape puts
    # slice i exactly in rows i*nrows .. (i+1)*nrows-1.
    stacked = m.reshape(num_features * nrows, ncols)
    # astype keeps the float64 result of the former np.zeros buffer and
    # always returns a copy, never a view of the input.
    return stacked.astype(np.float64)

# Flatten the loaded stack into one tall 2-D matrix.
result = concat(agg)

[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]
[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]


In [24]:
# Forward pass of the model on the sampled batch.
preds = net(aggregate)

In [25]:
# Shape of the first element of the model output (displayed as [173, 32, 1025]).
preds[0].size()

torch.Size([173, 32, 1025])

In [37]:
# Instantiate the custom loss defined in this notebook.
criterion = MinLoss()

In [38]:
# Loss of the untrained model on the sampled batch.
loss = criterion(preds, ground_truths)

In [39]:
# Display the loss tensor.
loss

tensor(797412.7500, grad_fn=<SumBackward0>)

In [40]:
# Back-propagate; the optimizer step in the next cell consumes these gradients.
loss.backward()

In [45]:
# One manual SGD update using the gradients produced by the preceding
# loss.backward(), followed by a fresh forward pass on the same batch.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer.step()
# Clear the consumed gradients so that re-running this cell or calling
# backward() again does not re-apply or accumulate onto stale values.
optimizer.zero_grad()
preds2 = net(aggregate)

In [46]:
# Loss on the same batch after the single optimizer step, for comparison
# with the value obtained before the step.
loss2 = criterion(preds2, ground_truths)
loss2

tensor(796459.9375, grad_fn=<SumBackward0>)

In [180]:
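# Each batch has a [d x d] distance matrix.
# get_orders replaces every cell with its rank among that batch's cells after sorting (0 = smallest distance).
# get_seq then picks d cells greedily: lowest remaining rank first, masking the chosen row and column.
# TODO: decide whether a different (e.g. optimal-assignment) algorithm should pick the d cells.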
import numpy as np
# Toy input: 2 batches, each a 3 x 3 matrix of distinct distances.
dists = np.array([[[1, 3, 4],
                   [2, 5, 6],
                   [7, 8, 9]],
                  [[5, 9, 8],
                   [4, 2, 6],
                   [3, 1, 7]]])
print(dists)

def get_orders(dists):
    """Rank the cells of each batch's distance matrix.

    Args:
        dists: array of shape [num_batches, rows, cols].

    Returns:
        Integer array of the same shape; each cell holds the 0-based rank of
        that cell's distance among all cells of the same batch (0 = smallest).
        `dists` itself is not modified.
    """
    num_batches = dists.shape[0]
    cells_per_batch = int(np.prod(dists.shape[1:]))
    flat = dists.reshape(num_batches, cells_per_batch)
    # The first argsort lists, per batch, the cell indices in ascending
    # distance order. The argsort of that permutation is its inverse, i.e.
    # the rank of every cell, which replaces the per-batch scatter loop.
    ranks = np.argsort(np.argsort(flat, axis=1), axis=1)
    return ranks.reshape(dists.shape).astype(int)

def get_seq(orders):
    """Greedily pick one cell per row and column from each batch's rank matrix.

    In every round, and for every batch, the cell with the smallest remaining
    rank is selected and its whole row and column are masked out so they
    cannot be chosen again. Each picked [row col] is printed as it is found.

    Args:
        orders: array of shape [num_batches, rows, cols], e.g. the output of
            get_orders. It is left unmodified.

    Returns:
        A list with one entry per batch; each entry is the list of
        (row, col) tuples in the order they were picked.
    """
    # Take the dimensions from the argument itself, not from a notebook global.
    num_batches, n_rows, n_cols = orders.shape
    pairs = [[] for _ in range(num_batches)]
    if orders.size == 0:
        return pairs

    # Work on a copy so the caller's rank matrix survives the masking.
    remaining = np.array(orders, copy=True)
    # Any value larger than every rank marks a cell as unavailable.
    mask = np.max(remaining) + 1
    for _ in range(min(n_rows, n_cols)):
        for batch in range(num_batches):
            flat_idx = np.argmin(remaining[batch])
            row, col = np.unravel_index(flat_idx, (n_rows, n_cols))
            print(np.array([row, col]))
            pairs[batch].append((int(row), int(col)))

            # mask row & col
            remaining[batch, row, :] = mask
            remaining[batch, :, col] = mask
    return pairs

# Rank the toy distances, show the ranks, then run the greedy selection,
# which prints each picked [row col] pair.
orders = get_orders(dists)
print(orders)
get_seq(orders)

[[[1 3 4]
  [2 5 6]
  [7 8 9]]

 [[5 9 8]
  [4 2 6]
  [3 1 7]]]
[[[0 2 3]
  [1 4 5]
  [6 7 8]]

 [[4 8 7]
  [3 1 5]
  [2 0 6]]]
[0 0]
[2 1]
[1 1]
[1 0]
[2 2]
[0 2]


In [16]:
# Inspect the targets: number of entries and the shape of the first one.
print(len(ground_truths))
print(ground_truths[0].size())
# NOTE(review): relies on a `reshape` method of the model object, which is not
# defined in this notebook — presumably it converts to time-first layout; confirm.
gt = net.reshape(ground_truths[0])
# Slice out batch element 0; the recorded output shows shape [173, 1025].
print(torch.squeeze(gt[:, 0, :], dim=1).size())

2
torch.Size([32, 173, 1025])
torch.Size([173, 1025])
