In [1]:
import torch
from torch import nn

In [2]:
class ReferenceEncoder(nn.Module):
    """Convolutional + GRU encoder applied to a spectrogram-like input.

    Six ``Conv2d -> BatchNorm2d -> ReLU`` stages (kernel 3, stride 2,
    padding 1) each halve, rounding up, both the time axis and the mel axis.
    The conv output is then laid out as one feature vector per remaining time
    step and passed through a single-layer GRU.

    Args:
        batch_size: Kept for backward compatibility and stored on the
            instance. ``forward`` takes the batch size from its input, so any
            batch size (including 1 or a final partial batch) works.
        seq_len: Kept for backward compatibility; it is not needed to build
            the layers because the GRU accepts any number of time steps.
        n_mel: Size of the last (mel) axis of the input.
        conv_channels1, ..., conv_channels6: Output channels of the six conv
            stages, in order.
        n_unit: Hidden size of the GRU.
    """

    def __init__(
        self,
        batch_size,
        seq_len = 16000,
        n_mel = 80,
        conv_channels1 = 32,
        conv_channels2 = 32,
        conv_channels3 = 64,
        conv_channels4 = 64,
        conv_channels5 = 128,
        conv_channels6 = 128,
        n_unit = 128
    ):
        super().__init__()
        self.batch_size = batch_size
        conv_channels_list = [
            conv_channels1,
            conv_channels2,
            conv_channels3,
            conv_channels4,
            conv_channels5,
            conv_channels6,
        ]

        layers = []
        in_channels = 1
        for out_channels in conv_channels_list:
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            in_channels = out_channels
            # kernel 3 / stride 2 / padding 1 maps a length n to ceil(n / 2).
            n_mel = (n_mel + 1) // 2
        self.convs = nn.Sequential(*layers)

        # After the conv stack every time step carries channels * reduced_n_mel features.
        self.gru = nn.GRU(n_mel * conv_channels_list[-1], n_unit, batch_first=True)

    def forward(self, seqs):
        """Encode a batch of inputs.

        Args:
            seqs: Tensor of shape (B, 1, seq_len, n_mel).

        Returns:
            Tuple ``(output, hidden)`` from the GRU: ``output`` has shape
            (B, new_seq_len, n_unit) and ``hidden`` has shape (1, B, n_unit).
        """
        out = self.convs(seqs)  # (B, 1, seq_len, n_mel) -> (B, channels, new_seq_len, new_n_mel)
        # Bring the time axis in front of the channel axis *before* flattening.
        # Reshaping (B, C, T, M) directly would interleave channel and time
        # data, so a GRU step would not correspond to a single time frame.
        out = out.permute(0, 2, 1, 3).flatten(start_dim=2)  # -> (B, new_seq_len, channels * new_n_mel)
        output, hidden = self.gru(out)  # -> (B, new_seq_len, n_unit), (1, B, n_unit)
        return output, hidden

In [3]:
# Smoke test: run an all-zero batch of shape (B, 1, seq_len, n_mel) through
# the encoder and show the shape of the GRU output.
ref = ReferenceEncoder(10)
# Named `mel_batch` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
mel_batch = torch.zeros(10, 1, 16000, 80)
ref_output, hidden_state = ref(mel_batch)
print(ref_output.shape)

torch.Size([10, 250, 128])


In [4]:
# https://qiita.com/m__k/items/646044788c5f94eadc8d

class StyleTokenLayer(nn.Module):
    """Scaled dot-product attention over a bank of learnable style tokens.

    The query attends over ``tanh(tokens)``, which serve as both keys and
    values. The original author's note says the paper used cosine similarity
    and that dot-product is used here to keep the implementation simple.

    Args:
        batch_size: Kept for backward compatibility and stored on the
            instance. ``forward`` takes the batch size from the query.
        emb_size: Embedding dimension of each style token.
        n_tokens: Number of style tokens.
    """

    def __init__(
        self,
        batch_size,
        emb_size = 256,
        n_tokens = 10,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.tanh = nn.Tanh()
        # Plain int: a stray trailing comma previously turned this into a 1-tuple.
        self.emb_size = emb_size
        self.softmax = nn.Softmax(dim=2)  # softmax along the n_tokens axis
        # Registered as a Parameter so the tokens are trained, follow
        # ``.to(device)`` and are saved in ``state_dict``.
        self.tokens = nn.Parameter(torch.randn(n_tokens, emb_size))

    def forward(self, query):
        """Attend over the style tokens.

        Args:
            query: Tensor of shape (B, seq_len, query_dim), where
                ``emb_size`` must be a positive multiple of ``query_dim``.

        Returns:
            Tuple ``(out, attention_weight)`` with shapes
            (B, seq_len, emb_size) and (B, seq_len, n_tokens).

        Raises:
            ValueError: If ``query`` is not 3-D or ``emb_size`` is not a
                positive multiple of its last dimension.
        """
        if query.dim() != 3:
            raise ValueError(
                f"query must have shape (B, seq_len, query_dim), got {tuple(query.shape)}"
            )
        repeats, remainder = divmod(self.emb_size, query.shape[-1])
        if repeats == 0 or remainder != 0:
            raise ValueError(
                f"emb_size ({self.emb_size}) must be a positive multiple of the "
                f"query feature size ({query.shape[-1]})"
            )
        # Original author's note: unsure whether this is the right approach.
        # The ReferenceEncoder output is 128-dim and the token embedding is
        # 256-dim, so the query is tiled along the feature axis to match
        # (for 128 -> 256 this is the query concatenated with itself).
        query = query.repeat(1, 1, repeats)

        tokens = self.tanh(self.tokens)
        key = tokens.unsqueeze(0).expand(query.shape[0], -1, -1)  # (n_tokens, emb_size) -> (B, n_tokens, emb_size)
        value = key

        # (B, seq_len, emb_size) @ (B, emb_size, n_tokens) -> (B, seq_len, n_tokens)
        scores = torch.bmm(query, key.transpose(1, 2)) / self.emb_size ** 0.5
        attention_weight = self.softmax(scores)  # (B, seq_len, n_tokens)
        # (B, seq_len, n_tokens) @ (B, n_tokens, emb_size) -> (B, seq_len, emb_size)
        out = torch.bmm(attention_weight, value)

        return out, attention_weight

In [6]:
# Smoke test for the style-token layer: attend over the tokens using the
# encoder output `ref_output` produced by an earlier cell, then show the
# shape of the attended result.
style = StyleTokenLayer(batch_size=10)
style_output, attention_weight = style(ref_output)
print(style_output.size())

torch.Size([10, 250, 256])
