In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import CNNFeatureExtractor, Song2Vec

In [2]:
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and random probe/toy inputs in later cells

# NOTE(review): the comment on cnn_input_channels assumes 1-channel input, yet every tensor
# passed to `model` later in this notebook has 3 channels, e.g. shape (1, 3, 1024, 2048).
# Confirm how Song2Vec actually uses this argument.
cnn_input_channels = 1  # Assuming the input is a 1D array
cnn_output_channels = 512  # Number of output channels after CNN
transformer_input_dim = cnn_output_channels  # Should match the output channels from the CNN
embed_dim = 16  # Output embedding size for transformer
num_heads = 1
num_layers = 1
ff_dim = 16

model = Song2Vec(
    cnn_input_channels,
    cnn_output_channels,
    transformer_input_dim,
    embed_dim,
    num_heads,
    num_layers,
    ff_dim,
)



In [3]:
# Standalone CNN feature extractor, instantiated only to probe its output shape.
# NOTE(review): the positional args look like (in_channels, out_channels) — 3 matches the
# 3-channel probe input in the next cell; confirm against model.CNNFeatureExtractor.
cnn = CNNFeatureExtractor(3, 128)

In [4]:
# Shape probe only: no_grad avoids building an autograd graph for a 1x3x1024x2048 input.
with torch.no_grad():
    cnn_out_shape = cnn(torch.randn(1, 3, 1024, 2048)).shape
cnn_out_shape

torch.Size([1, 1024, 512])

In [3]:
# Shape probe only: no_grad avoids building an autograd graph for a 1x3x1024x2048 input.
with torch.no_grad():
    embed_shape = model(torch.randn(1, 3, 1024, 2048)).shape
embed_shape

torch.Size([1, 16])

In [4]:
# Triplet loss over Euclidean (p=2) embedding distances with a margin of 1.0:
# per triplet, max(d(anchor, positive) - d(anchor, negative) + margin, 0).
triplet_loss_fn = nn.TripletMarginLoss(p=2, margin=1.0)

In [16]:
# Hard-triplet mining on a toy batch: each anchor has K candidate positives and K
# candidate negatives; only the hardest of each is fed to the triplet loss.
anchor = torch.randn(1, 3, 1024, 2048)       # (batch_size, n_channels, height, width)
positive = torch.randn(1, 4, 3, 1024, 2048)  # (batch_size, K, n_channels, height, width)
negative = torch.randn(1, 4, 3, 1024, 2048)  # (batch_size, K, n_channels, height, width)
batch_size, n_candidates = positive.shape[:2]
batch_idx = torch.arange(batch_size)

anchor_embed = model(anchor)  # (batch_size, embed_dim); keeps gradients for the loss

with torch.no_grad():
    # Fold (batch_size, K) into one batch axis so the model sees ordinary 4-D input.
    positive_embed_flat = model(positive.flatten(0, 1))  # (batch_size * K, embed_dim)
    negative_embed_flat = model(negative.flatten(0, 1))

    anchor_ref = anchor_embed.detach().unsqueeze(1)  # (batch_size, 1, embed_dim)
    # Rank candidates with the same p-norm distance triplet_loss_fn uses, so the mined
    # triplets are the highest-loss ones (cosine similarity can rank them differently).
    dist_pos = torch.linalg.vector_norm(
        anchor_ref - positive_embed_flat.reshape(batch_size, n_candidates, -1),
        ord=triplet_loss_fn.p,
        dim=-1,
    )  # (batch_size, K)
    dist_neg = torch.linalg.vector_norm(
        anchor_ref - negative_embed_flat.reshape(batch_size, n_candidates, -1),
        ord=triplet_loss_fn.p,
        dim=-1,
    )  # (batch_size, K)

print("embedding shape: ", positive_embed_flat.shape)

p = dist_pos.argmax(dim=1)  # hardest positive = farthest from anchor, shape (batch_size,)
n = dist_neg.argmin(dim=1)  # hardest negative = closest to anchor, shape (batch_size,)

# Re-embed only the selected candidates, this time with gradients.
positive_embed = model(positive[batch_idx, p])  # (batch_size, embed_dim)
negative_embed = model(negative[batch_idx, n])

loss = triplet_loss_fn(anchor_embed, positive_embed, negative_embed)

embedding shape:  torch.Size([4, 16])
p:  torch.Size([1])
torch.Size([1, 3, 1024, 2048])
