In [1]:
import torch
from i6_experiments.users.rilling.experiments.librispeech.librispeech_joint_training_given_alignments.pytorch_networks.shared.commons import sequence_mask
# Assuming you have your 2D tensor of phoneme sequences and 3D tensor of encodings
# Let's call them phoneme_sequences and encodings respectively

# Phoneme sequences tensor shape: [batch_size, sequence_length]
# Encodings tensor shape: [batch_size, sequence_length, encoding_size]


# Calculate the mean encoding for every phoneme across all batches
def encoding_distance_loss(phoneme_sequences, encodings, seq_lenghts):
    # Reshape the encodings tensor to [batch_size * sequence_length, encoding_size]
    mask = sequence_mask(seq_lenghts, phoneme_sequences.shape[-1])
    phoneme_sequences_masked = ((phoneme_sequences + 1) * mask) - 1
    encodings_masked = ((encodings + 1) * mask) - 1

    encodings_flat = encodings_masked.view(-1, encodings_masked.size(-1))

    # Flatten the phoneme sequences tensor to match the encodings tensor
    phoneme_sequences_flat = phoneme_sequences_masked.view(-1)

    phonemes, inverse, counts = torch.unique(phoneme_sequences_flat, return_counts=True, return_inverse=True)

    # Initialize a tensor to store the sum of encodings for each phoneme
    sum_encodings = torch.zeros(phonemes.shape[0], encodings.size(-1))

    # Use index_add to collect the sum of all phoneme encodings in the current batch
    sum_encodings = sum_encodings.index_add(0, inverse, encodings_flat)

    mean_encodings = (sum_encodings / counts.unsqueeze(1)).unsqueeze(0)

    loss = -1 * torch.cdist(mean_encodings, mean_encodings).sum()
    return loss


# Example usage
batch_size = 2
sequence_lengths = torch.Tensor([4, 5])
encoding_size = 5

# Create example tensors
phoneme_sequences = torch.tensor([[0, 1, 5, 2, 0, 0], [3, 0, 2, 5, 6, 0]], dtype=torch.float32, requires_grad=True)  # Example phoneme sequences
print(phoneme_sequences.shape[-1])
encodings = torch.randn((batch_size, int(phoneme_sequences.shape[-1]), encoding_size), requires_grad=True)  # Example encodings

# Calculate the mean encoding for every phoneme
mean_encodings, distances = encoding_distance_loss(phoneme_sequences, encodings, sequence_lengths)

print("Encodings: \n", encodings)
print("Mean encodings shape:\n", mean_encodings.shape)
print("Mean encodings:\n", mean_encodings)
print("Encoding for 3: \n", encodings[1, 0])
print("Mean encoding for 3:\n", mean_encodings[3])
print("Distance: ", distances)
distances.sum().backward()

6


RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 2

In [32]:
(encodings[0, 0] + encodings[1, 1])/2

tensor([ 0.9826,  1.0340,  0.7152,  0.6475, -0.7005])

In [64]:
test = torch.Tensor([0, 1, 2, 0, 6, 8, 8])
mask = torch.Tensor([1, 1, 1, 1, 0, 0, 0])
test_masked = ((test + 1) * mask) - 1
test_masked

tensor([ 0.,  1.,  2.,  0., -1., -1., -1.])