Adjust code as per original iSTFTNet #1

Open
wants to merge 1 commit into base: master
config_v1.json: 4 changes (2 additions, 2 deletions)

@@ -8,8 +8,8 @@
     "lr_decay": 0.999,
     "seed": 1234,

-    "upsample_rates": [8,8,2],
-    "upsample_kernel_sizes": [16,16,4],
+    "upsample_rates": [8,8],
+    "upsample_kernel_sizes": [16,16],
     "upsample_initial_channel": 512,
     "resblock_kernel_sizes": [3,7,11],
     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
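Dropping the third upsample stage means the convolutional stack now upsamples by 8×8 = 64 instead of 8×8×2 = 128; the remaining factor is supplied by the iSTFT head's hop length, as in the original iSTFTNet (the C8C8I configuration). A minimal sanity-check sketch of that bookkeeping; the gen_istft_hop_size and hop_size values below are assumptions, not read from this repository's config:

```python
# Sanity-check sketch for the total upsampling factor after this change.
upsample_rates = [8, 8]       # was [8, 8, 2] before this PR
gen_istft_hop_size = 4        # hop length of the iSTFT head (assumed value)
hop_size = 256                # mel-spectrogram hop length (assumed value)

learned_upsampling = 1
for r in upsample_rates:
    learned_upsampling *= r   # 8 * 8 = 64

# Under the assumed values, the iSTFT head supplies the remaining factor:
# 64 * 4 = 256 samples per mel frame.
assert learned_upsampling * gen_istft_hop_size == hop_size
```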
models.py: 43 changes (20 additions, 23 deletions)

@@ -101,7 +101,7 @@ def __init__(self, h):
         self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))

         self.out_proj_x1 = weight_norm(Conv1d(h.upsample_initial_channel // 4, 1, 7, 1, padding=3))
-        self.out_proj_x2 = weight_norm(Conv1d(h.upsample_initial_channel // 8, 1, 7, 1, padding=3))
+        #self.out_proj_x2 = weight_norm(Conv1d(h.upsample_initial_channel // 8, 1, 7, 1, padding=3))

     def forward(self, x):
         x = self.conv_pre(x)
@@ -117,15 +117,15 @@ def forward(self, x):
             x = xs / self.num_kernels
             if i == 1:
                 x1 = self.out_proj_x1(x)
-            elif i == 2:
-                x2 = self.out_proj_x2(x)
+            # elif i == 2:
+            #     x2 = self.out_proj_x2(x)
         x = F.leaky_relu(x)
         x = self.reflection_pad(x)
         x = self.conv_post(x)
         spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
         phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])

-        return spec, phase, x1, x2
+        return spec, phase, x1 #, x2

     def remove_weight_norm(self):
         print('Removing weight norm...')
@@ -272,7 +272,7 @@ def __init__(self, kernels, channels, groups, strides):
         self.pqmf_2 = PQMF(N=2, taps=256, cutoff=0.25, beta=10.0)
         self.pqmf_4 = PQMF(N=4, taps=192, cutoff=0.13, beta=10.0)

-    def forward(self, x, x_hat, x2_hat, x1_hat):
+    def forward(self, x, x_hat, x1_hat):
         y = []
         y_hat = []
         fmap = []
@@ -286,19 +286,19 @@ def forward(self, x, x_hat, x2_hat, x1_hat):
         y_hat.append(p3_hat)
         fmap_hat.append(p3_fmap_hat)

-        x2_ = self.pqmf_2(x)[:, :1, :] # Select first band
+        #x2_ = self.pqmf_2(x)[:, :1, :] # Select first band
         x1_ = self.pqmf_4(x)[:, :1, :] # Select first band

-        x2_hat_ = self.pqmf_2(x_hat)[:, :1, :]
+        #x2_hat_ = self.pqmf_2(x_hat)[:, :1, :]
         x1_hat_ = self.pqmf_4(x_hat)[:, :1, :]

-        p2_, p2_fmap_ = self.combd_2(x2_)
-        y.append(p2_)
-        fmap.append(p2_fmap_)
+        # p2_, p2_fmap_ = self.combd_2(x2_)
+        # y.append(p2_)
+        # fmap.append(p2_fmap_)

-        p2_hat_, p2_fmap_hat_ = self.combd_2(x2_hat)
-        y_hat.append(p2_hat_)
-        fmap_hat.append(p2_fmap_hat_)
+        # p2_hat_, p2_fmap_hat_ = self.combd_2(x2_hat)
+        # y_hat.append(p2_hat_)
+        # fmap_hat.append(p2_fmap_hat_)

         p1_, p1_fmap_ = self.combd_1(x1_)
         y.append(p1_)
@@ -309,16 +309,13 @@ def forward(self, x, x_hat, x2_hat, x1_hat):
         fmap_hat.append(p1_fmap_hat_)


-
-
-
-        p2, p2_fmap = self.combd_2(x2_)
-        y.append(p2)
-        fmap.append(p2_fmap)
-
-        p2_hat, p2_fmap_hat = self.combd_2(x2_hat_)
-        y_hat.append(p2_hat)
-        fmap_hat.append(p2_fmap_hat)
+        # p2, p2_fmap = self.combd_2(x2_)
+        # y.append(p2)
+        # fmap.append(p2_fmap)
+        #
+        # p2_hat, p2_fmap_hat = self.combd_2(x2_hat_)
+        # y_hat.append(p2_hat)
+        # fmap_hat.append(p2_fmap_hat)

         p1, p1_fmap = self.combd_1(x1_)
         y.append(p1)
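With the x2 tap and the PQMF N=2 branch disabled, the Generator exposes only the full-band iSTFT output plus one intermediate projection (x1), and the multi-band discriminator compares that projection against the first PQMF N=4 sub-band of the real and generated waveforms. A rough sketch of the remaining sub-band selection, with the PQMF constructor arguments copied from the diff above; the import path and tensor shapes are assumptions:

```python
import torch
from pqmf import PQMF  # analysis filterbank used in models.py (import path assumed)

pqmf_4 = PQMF(N=4, taps=192, cutoff=0.13, beta=10.0)

y = torch.randn(1, 1, 8192)   # real waveform, (batch, 1, samples) layout assumed
x1_ = pqmf_4(y)[:, :1, :]     # keep only the lowest of the 4 sub-bands, as in forward()
print(x1_.shape)              # roughly (1, 1, samples / 4) after PQMF decimation (assumed)
```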
train.py: 8 changes (4 additions, 4 deletions)

@@ -125,7 +125,7 @@ def train(rank, a, h):
             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
             y = y.unsqueeze(1)
             # y_g_hat = generator(x)
-            spec, phase, x1, x2 = generator(x)
+            spec, phase, x1 = generator(x)

             y_g_hat = stft.inverse(spec, phase)

@@ -135,7 +135,7 @@ def train(rank, a, h):
             optim_d.zero_grad()

             # MPD
-            y_df_hat_r, y_df_hat_g, _, _ = mcmbd(y, y_g_hat.detach(), x2.detach(), x1.detach())
+            y_df_hat_r, y_df_hat_g, _, _ = mcmbd(y, y_g_hat.detach(), x1.detach())
             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

             # MSD
@@ -153,7 +153,7 @@ def train(rank, a, h):
             # L1 Mel-Spectrogram Loss
             loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

-            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mcmbd(y, y_g_hat, x2, x1)
+            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mcmbd(y, y_g_hat, x1)
             y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msbd(y, y_g_hat)
             loss_fm_f = 2 * feature_loss(fmap_f_r, fmap_f_g)
             loss_fm_s = 2 * feature_loss(fmap_s_r, fmap_s_g)
@@ -201,7 +201,7 @@ def train(rank, a, h):
                    for j, batch in enumerate(validation_loader):
                        x, y, _, y_mel = batch
                        # y_g_hat = generator(x.to(device))
-                        spec, phase, x1, x2 = generator(x.to(device))
+                        spec, phase, x1 = generator(x.to(device))

                        y_g_hat = stft.inverse(spec, phase)

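For downstream code, the practical consequence is the shorter call contract: the generator now returns (spec, phase, x1), and the waveform is reconstructed with the inverse STFT exactly as in the training loop above. A minimal inference sketch under that contract; the checkpoint path and layout, h, device, stft, and the mel input are assumptions, not part of this PR:

```python
import torch
from models import Generator

generator = Generator(h).to(device)                      # h: hyperparameter namespace (assumed)
state = torch.load("g_latest.pt", map_location=device)   # checkpoint path and dict layout assumed
generator.load_state_dict(state["generator"])
generator.eval()
generator.remove_weight_norm()

with torch.no_grad():
    spec, phase, _ = generator(mel)      # x2 is no longer returned after this PR
    audio = stft.inverse(spec, phase)    # same iSTFT helper used in train.py
```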