Skip to content
This repository has been archived by the owner on Nov 11, 2023. It is now read-only.

Commit

Permalink
Update VITS2 part (Transformer Flow)
Browse files Browse the repository at this point in the history
  • Loading branch information
ylzz1997 committed Aug 1, 2023
1 parent 39b0bef commit fc8336f
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 6 deletions.
4 changes: 3 additions & 1 deletion configs_template/config_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [16,16, 4, 4, 4],
"n_layers_q": 3,
"n_layers_trans_flow": 3,
"n_flow_layer": 4,
"use_spectral_norm": false,
"gin_channels": 768,
Expand All @@ -65,7 +66,8 @@
"vol_embedding":false,
"use_depthwise_conv":false,
"flow_share_parameter": false,
"use_automatic_f0_prediction": true
"use_automatic_f0_prediction": true,
"use_transformer_flow": false
},
"spk": {
"nyaru": 0,
Expand Down
4 changes: 3 additions & 1 deletion configs_template/config_tiny_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"upsample_initial_channel": 400,
"upsample_kernel_sizes": [16,16, 4, 4, 4],
"n_layers_q": 3,
"n_layers_trans_flow": 3,
"n_flow_layer": 4,
"use_spectral_norm": false,
"gin_channels": 768,
Expand All @@ -65,7 +66,8 @@
"vol_embedding":false,
"use_depthwise_conv":true,
"flow_share_parameter": true,
"use_automatic_f0_prediction": true
"use_automatic_f0_prediction": true,
"use_transformer_flow": false
},
"spk": {
"nyaru": 0,
Expand Down
48 changes: 47 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,46 @@ def forward(self, x, x_mask, g=None, reverse=False):
x = flow(x, x_mask, g=g, reverse=reverse)
return x

class TransformerCouplingBlock(nn.Module):
    """Chain of transformer-based coupling layers, each followed by a Flip.

    The block holds ``n_flows`` pairs of
    ``(modules.TransformerCouplingLayer, modules.Flip)`` in ``self.flows``.
    When ``share_parameter`` is true, one ``attentions.FFT`` instance is
    created here and passed to every coupling layer through
    ``wn_sharing_parameter``; otherwise ``None`` is passed and each layer
    builds its own.
    """

    def __init__(self,
                 channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 n_flows=4,
                 gin_channels=0,
                 share_parameter=False
                 ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Registered before ``wn`` so the submodule order stays stable.
        self.flows = nn.ModuleList()

        # Optional FFT shared by every coupling layer in this block.
        if share_parameter:
            self.wn = attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
        else:
            self.wn = None

        for _ in range(n_flows):
            coupling = modules.TransformerCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                n_layers,
                n_heads,
                p_dropout,
                filter_channels,
                mean_only=True,
                wn_sharing_parameter=self.wn,
                gin_channels=self.gin_channels,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run ``x`` through every flow, or through the reversed chain.

        In the forward direction each flow returns a pair; only the first
        element is carried on to the next flow. With ``reverse`` set, the
        flows are visited last-to-first and each returns the tensor directly.
        ``x_mask`` and ``g`` are handed to every flow unchanged.
        """
        if reverse:
            for step in reversed(self.flows):
                x = step(x, x_mask, g=g, reverse=reverse)
            return x

        for step in self.flows:
            x, _ = step(x, x_mask, g=g, reverse=reverse)
        return x


class Encoder(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -327,6 +367,8 @@ def __init__(self,
use_automatic_f0_prediction = True,
flow_share_parameter = False,
n_flow_layer = 4,
n_layers_trans_flow = 3,
use_transformer_flow = False,
**kwargs):

super().__init__()
Expand All @@ -351,6 +393,7 @@ def __init__(self,
self.emb_g = nn.Embedding(n_speakers, gin_channels)
self.use_depthwise_conv = use_depthwise_conv
self.use_automatic_f0_prediction = use_automatic_f0_prediction
self.n_layers_trans_flow = n_layers_trans_flow
if vol_embedding:
self.emb_vol = nn.Linear(1, hidden_channels)

Expand Down Expand Up @@ -392,7 +435,10 @@ def __init__(self,
self.dec = Generator(h=hps)

self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
if use_transformer_flow:
self.flow = TransformerCouplingBlock(inter_channels, hidden_channels, filter_channels, n_heads, n_layers_trans_flow, 5, p_dropout, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
else:
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels, share_parameter= flow_share_parameter)
if self.use_automatic_f0_prediction:
self.f0_decoder = F0Decoder(
1,
Expand Down
22 changes: 19 additions & 3 deletions modules/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from torch.nn import functional as F

import modules.commons as commons
from modules.DSConv import weight_norm_modules
from modules.modules import LayerNorm


class FFT(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
proximal_bias=False, proximal_init=True, **kwargs):
proximal_bias=False, proximal_init=True, isflow = False, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
Expand All @@ -20,7 +21,11 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init

if isflow:
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
self.gin_channels = kwargs["gin_channels"]
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
Expand All @@ -35,14 +40,25 @@ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
self.norm_layers_1.append(LayerNorm(hidden_channels))

def forward(self, x, x_mask):
def forward(self, x, x_mask, g = None):
"""
x: decoder input
h: encoder output
"""
if g is not None:
g = self.cond_layer(g)

self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
x = commons.fused_add_tanh_sigmoid_multiply(
x,
g_l,
torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
Expand Down
50 changes: 50 additions & 0 deletions modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
from torch.nn import functional as F

import modules.attentions as attentions
import modules.commons as commons
from modules.commons import get_padding, init_weights
from modules.DSConv import (
Expand Down Expand Up @@ -304,3 +305,52 @@ def forward(self, x, x_mask, g=None, reverse=False):
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x

class TransformerCouplingLayer(nn.Module):
    """Coupling layer whose statistics come from an ``attentions.FFT`` stack.

    The input is split in two along dim 1. The first half is left untouched
    and is fed through ``pre`` -> ``enc`` -> ``post`` to predict a shift
    ``m`` (and, unless ``mean_only`` is set, a log-scale ``logs``) that is
    applied to the second half.

    ``post`` is zero-initialised, so the predicted statistics are all zero
    right after construction.
    """

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 n_layers,
                 n_heads,
                 p_dropout=0,
                 filter_channels=0,
                 mean_only=False,
                 wn_sharing_parameter=None,
                 gin_channels=0
                 ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)

        # Reuse the caller-supplied FFT when one is given; otherwise own one.
        if wn_sharing_parameter is None:
            self.enc = attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
        else:
            self.enc = wn_sharing_parameter

        # One output group (shift only) when mean_only, two otherwise.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        for tensor in (self.post.weight, self.post.bias):
            tensor.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Transform the second half of ``x`` conditioned on the first half.

        Args:
            x: tensor with ``channels`` entries along dim 1 (it is fed to
                ``Conv1d``, so presumably (batch, channels, time) - confirm
                against callers).
            x_mask: multiplied into the hidden activations, the predicted
                statistics and the transformed half.
            g: optional conditioning forwarded to ``enc``.
            reverse: selects the inverse transform.

        Returns:
            ``(x, logdet)`` in the forward direction, where ``logdet`` is
            ``logs`` summed over dims 1 and 2; just ``x`` when ``reverse``
            is set.
        """
        half = self.half_channels
        x0, x1 = torch.split(x, [half, half], 1)

        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [half, half], 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([x0, x1], 1), logdet

0 comments on commit fc8336f

Please sign in to comment.