In [17]:
# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
import torch.nn as nn
import torch
import math

## Tokenized Text input

In [9]:
# Synthetic batch of tokenized text: random token ids in [0, vocab_size)
# plus the unpadded length of each of the `batch_size` sequences.
vocab_size, batch_size, max_length = 128, 3, 184
x = torch.randint(low=0, high=vocab_size, size=(batch_size, max_length))
x_lens = torch.tensor([163, 184, 152])


torch.Size([3, 184])


In [40]:
# audio features
# Synthetic stand-in for a batch of audio features: standard-normal noise
# scaled by 800, plus the number of valid frames for each batch item.
batch_size, encodec_size, max_y_lens = 3, 8, 775
y = 800 * torch.randn(batch_size, max_y_lens, encodec_size)
y_lens = torch.tensor([770, 775, 768])

In [10]:
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean padding mask from per-sequence lengths.

    Args:
      lengths:
        A 1-D integer tensor containing sentence lengths. May be empty.
      max_len:
        The minimum width of the mask. The mask is widened to the longest
        entry of ``lengths`` when that is larger.
    Returns:
      Return a 2-D bool tensor of shape ``(len(lengths), width)``, where
      masked (padding) positions are filled with `True` and non-masked
      positions are filled with `False`. The result lives on the same
      device as ``lengths``.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim

    # `Tensor.max()` raises on an empty tensor, so treat "no sequences" as
    # longest == 0 and return an empty (0, max_len) mask instead of crashing.
    longest = int(lengths.max().item()) if lengths.numel() > 0 else 0
    # Work with a plain Python int from here on rather than a 0-dim tensor.
    max_len = max(max_len, longest)

    # Build the position index directly with the dtype/device of `lengths`
    # instead of creating it on the CPU and copying it over afterwards.
    positions = torch.arange(
        max_len, dtype=lengths.dtype, device=lengths.device
    )

    # Broadcast (1, max_len) against (n, 1): a position is padding when its
    # index is at or beyond the sequence length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

In [14]:
# Padding mask for the text batch; show the mask row of the first sequence
# (True marks positions at or beyond that sequence's length).
x_mask = make_pad_mask(lengths=x_lens)
print(x_mask[0, :])

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [4]:
class TokenEmbedding(nn.Module):
    """Token-id lookup table followed by dropout.

    Args:
      dim_model:
        Size of each embedding vector.
      vocab_size:
        Number of rows in the embedding table.
      dropout:
        Dropout probability applied to the looked-up embeddings.
    """

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The embedding matrix of shape ``(vocab_size, dim_model)``."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return the embedding row for ``index`` as a ``(1, dim_model)`` slice.

        Negative indices count from the end of the table. No dropout is
        applied here.
        """
        # A plain `index : index + 1` slice is empty for index == -1 (it
        # becomes `-1:0`), although every other negative index works. Use an
        # open-ended stop in that one case so the last row is returned.
        stop = None if index == -1 else index + 1
        return self.word_embeddings.weight[index:stop]

    def forward(self, x: torch.Tensor):
        """Embed the token ids in ``x`` and apply dropout.

        The output has the shape of ``x`` with a trailing ``dim_model`` axis.
        """
        embedded = self.word_embeddings(x)
        return self.dropout(embedded)

In [5]:
# Hyper-parameters for the token embedding; vocab_size matches the range of
# token ids generated for `x`.
cfg = dict(dim_model=128, vocab_size=128, dropout=0.0)
t = TokenEmbedding(**cfg)

In [16]:
# Embed every token id, then compare the shapes before and after the lookup.
embedded = t(x)
print(x.size(), embedded.size())

torch.Size([3, 184]) torch.Size([3, 184, 128])


In [18]:
class SinePositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to an input sequence.

    The output is ``x * x_scale + alpha * pe`` followed by dropout, where
    ``pe`` is a cached sine/cosine table and ``alpha`` is a learnable scalar
    initialised to 1.

    Args:
      dim_model:
        Size of the encoding's last axis. Both even and odd sizes work.
      dropout:
        Dropout probability applied to the output.
      scale:
        If True the input is multiplied by ``sqrt(dim_model)`` before the
        encoding is added; otherwise it is left unscaled.
    """

    def __init__(
        self, dim_model: int, dropout: float = 0.0, scale: bool = False
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1))
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        # `pe` is a plain attribute rather than a registered buffer, so
        # `extend_pe` moves it to the input's dtype/device on demand.
        self.pe = None
        # Pre-compute the table for 4000 positions up front.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings.

        Ensures ``self.pe`` covers at least ``x.size(1)`` positions and has
        the dtype and device of ``x``. The table is rebuilt only when the
        cached one is missing or too short.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                # The cache is long enough; only align dtype/device if needed.
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        # One frequency per sine column: ceil(dim_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only dim_model // 2 cosine columns. For an odd dim_model
        # that is one fewer than the sine columns, so trim the frequencies
        # to avoid a shape mismatch (this is a no-op for even sizes).
        pe[:, 1::2] = torch.cos(position * div_term[: self.dim_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with position encodings added, after dropout.

        A 2-D input gets a trailing unit axis first so that it broadcasts
        against the ``(1, length, dim_model)`` encoding table.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)


In [21]:
# Add sinusoidal position encodings to the embedded tokens; the encoding width
# must match the embedding width used for `embedded`.
dim_model = 128
pos = SinePositionalEmbedding(dim_model)
positioned = pos(embedded)
print(positioned.size())

torch.Size([3, 184, 128])


In [38]:
# Inspect the raw (float) audio features.
# NOTE(review): the saved output below shows unit-scale values, while the cell
# that defines `y` multiplies by 800 and carries a later execution count —
# the output looks stale; re-run to refresh.
print(y)

tensor([[[ 2.4003e-01, -1.1978e-01,  6.9747e-01,  ...,  7.1296e-01,
          -1.2161e+00,  7.1315e-02],
         [-6.4140e-01,  4.0792e-01,  2.2685e-01,  ..., -6.5513e-01,
          -1.3306e+00,  1.5161e+00],
         [ 1.3071e+00,  6.3669e-01, -1.7437e-01,  ...,  1.6955e-01,
          -1.2495e+00, -1.7102e-01],
         ...,
         [-4.2391e-01, -1.5363e-03, -3.8377e-01,  ..., -1.8966e+00,
           4.3150e-01, -2.7159e+00],
         [-1.1792e-01,  2.4326e-02,  8.5095e-01,  ...,  9.6202e-01,
          -9.1284e-01,  1.3664e+00],
         [ 1.2822e-01,  4.0712e-01, -4.7922e-01,  ..., -9.1959e-01,
          -1.0553e-01, -7.6964e-01]],

        [[ 4.9819e-01, -1.2775e-01,  1.1892e+00,  ...,  9.5521e-01,
          -9.6966e-02, -2.2252e+00],
         [ 3.4125e-01,  1.3086e+00,  2.9771e-01,  ...,  2.6346e-01,
           5.6558e-01, -4.2933e-01],
         [-1.6659e+00,  1.6146e-01, -1.7946e+00,  ..., -1.3131e+00,
           2.4257e+00,  5.1549e-01],
         ...,
         [-1.6473e+00, -1

In [41]:
# Cast the audio features to integer codes and zero out the padded frames.
y_mask = make_pad_mask(y_lens)
y_mask_int = y_mask.to(torch.int64)  # 1 on padded frames, 0 on valid ones
print(y_mask_int)
# (1 - mask) is 1 on valid frames; the trailing unit axis broadcasts it over
# the last axis of `y`.
codes = (1 - y_mask_int.unsqueeze(-1)) * y.to(torch.int64)
print(codes, codes.shape)


tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 1, 1]])
tensor([[[ 1032,   -45,   575,  ...,    78,  -448,   116],
         [ -176,  2239,     7,  ...,  -381, -1074,  1045],
         [  957,   176,   110,  ..., -1003,  1221,   868],
         ...,
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0],
         [    0,     0,     0,  ...,     0,     0,     0]],

        [[  175,   -55,  1597,  ...,  -161,   380,  1447],
         [  291,   454,    56,  ...,   692,  -814,   294],
         [  715,  1215,    -5,  ...,   543,  -600,   827],
         ...,
         [ -500,  -207,    21,  ...,  -381,  1081,   266],
         [   98,    15,   216,  ..., -1397,   359,  1247],
         [  124,   567,  1259,  ..., -1200,  -667,   334]],

        [[  270,   -29,   840,  ...,  -505,   206,  1017],
         [ -210,     4,   449,  ...,  -936,  -242,  -714],
         [ -118,   811,    59,  ..., -1

In [42]:
# Preview the first five frames of the first batch item.
codes[0, :5]


tensor([[ 1032,   -45,   575,  -696,  -959,    78,  -448,   116],
        [ -176,  2239,     7,   333,   -39,  -381, -1074,  1045],
        [  957,   176,   110,  -136, -2441, -1003,  1221,   868],
        [ 1454,   164, -2637,  -119,  -365,  -957,  -293,   813],
        [-1610,   248,   414, -1230,   197,    71,   -52,   519]])