In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!cp /content/drive/MyDrive/autovc/autovc_dataset.py /content/autovc_dataset.py
!cp /content/drive/MyDrive/autovc/autovc_model.py /content/autovc_model.py
!cp /content/drive/MyDrive/autovc/autovc_network.py /content/autovc_network.py
!cp /content/drive/MyDrive/autovc/autovc_vcoder.py /content/autovc_vcoder.py
!cp /content/drive/MyDrive/autovc/hparams.py /content/hparams.py


In [3]:
import torch
import os
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from autovc_model import AutovcModel
from autovc_dataset import get_dynamic_loader
import torch.multiprocessing as mp

mp.set_start_method('spawn')


In [4]:
root_path = '/content/drive/MyDrive/autovc/data_unzip'
total_iterations = 100_000
batch_size = 16

output_path = '/content/drive/MyDrive/autovc/output'
os.makedirs(output_path, exist_ok=True)
save_freq = 10000

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AutovcModel(device=device)


In [6]:
train_loader = get_dynamic_loader(root_dir = root_path, batch_size=batch_size, encoder=model.speaker_encoder, num_workers= 4)

In [7]:
parameters_G = model.G.parameters()
optimizer_G = torch.optim.Adam(parameters_G, lr=0.0005, betas=(0.0, 0.999))
parameters_D = list(model.D_1.parameters()) + list(model.D_1.parameters())
optimizer_D = torch.optim.Adam(parameters_D, lr=0.0001, betas=(0.0, 0.999))


  return disable_fn(*args, **kwargs)


In [8]:

mse_loss = nn.MSELoss()
mae_loss = nn.L1Loss()

In [9]:
data_iter = iter(train_loader)
pbar = tqdm(range(total_iterations), desc="Training", ncols=100)

for step in pbar:
    for interval in range(2):
        pbar.set_description(f"Step {step} | {'D' if interval else 'G'}")

        try:
            x_real, emb_org = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            x_real, emb_org = next(data_iter)

        x_real = x_real.to(device)
        emb_org = emb_org.to(device)

        if step % 2 == 0:
            # 自我重建
            emb_trg = emb_org
        else:
            # 選擇不同 speaker
            rand_idx = torch.randperm(emb_org.size(0))
            emb_trg = emb_org[rand_idx]

        if interval == 1:  # 訓練 D
            model.G.eval()
            model.D_1.train()
            model.D_2.train()

            with torch.no_grad():
                _, x_fake, _ = model.G(x_real, emb_org, emb_trg)

            x_real_d = x_real.unsqueeze(1) if x_real.dim() == 3 else x_real
            x_fake_d = x_fake.unsqueeze(1) if x_fake.dim() == 3 else x_fake

            pred_fake_1 = model.D_1(x_fake_d.detach())
            pred_fake_2 = model.D_2(x_fake_d.detach())
            pred_real_1 = model.D_1(x_real_d)
            pred_real_2 = model.D_2(x_real_d)

            loss_D_fake = F.mse_loss(pred_fake_1, torch.zeros_like(pred_fake_1)) + \
                            F.mse_loss(pred_fake_2, torch.zeros_like(pred_fake_2))
            loss_D_real = F.mse_loss(pred_real_1, torch.ones_like(pred_real_1)) + \
                            F.mse_loss(pred_real_2, torch.ones_like(pred_real_2))
            loss_D = loss_D_fake + loss_D_real

            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

        else:  # 訓練 G
            model.G.train()
            model.D_1.eval()
            model.D_2.eval()

            x_identic, x_identic_psnt, code_real = model.G(x_real, emb_org, emb_trg)

            x_identic_d = x_identic.unsqueeze(1) if x_identic.dim() == 3 else x_identic
            x_identic_psnt_d = x_identic_psnt.unsqueeze(1) if x_identic_psnt.dim() == 3 else x_identic_psnt
            x_real_d = x_real.unsqueeze(1) if x_real.dim() == 3 else x_real

            # GAN loss
            with torch.no_grad():
                pred_fake_1 = model.D_1(x_identic_psnt_d)
                pred_fake_2 = model.D_2(x_identic_psnt_d)
                loss_G_GAN = F.mse_loss(pred_fake_1, torch.ones_like(pred_fake_1)) + \
                            F.mse_loss(pred_fake_2, torch.ones_like(pred_fake_2))

            # Identity losses
            loss_G_id = F.mse_loss(x_real, x_identic.squeeze(1))
            loss_G_id_psnt = F.mse_loss(x_real, x_identic_psnt.squeeze(1))

            # Code semantic consistency
            code_reconst = model.G(x_identic_psnt, emb_org, None)
            loss_G_cd = F.l1_loss(code_real, code_reconst)

            # 預設
            n_layers_D = 5
            num_D = 2
            feat_weights = 5.0 / (n_layers_D + 1)
            D_weights = 1.0 / num_D
            loss_G_Feat = 0


            with torch.no_grad():
              real_feat_1 = model.D_1.get_feature(x_real_d)
              fake_feat_1 = model.D_1.get_feature(x_identic_psnt_d)
              real_feat_2 = model.D_2.get_feature(x_real_d)
              fake_feat_2 = model.D_2.get_feature(x_identic_psnt_d)

            fea_real = [real_feat_1, real_feat_2]
            fea_fake = [fake_feat_1, fake_feat_2]

            for i in range(num_D):
                for j in range(0, len(fea_fake[i]) - 1):
                    loss_G_Feat += D_weights * feat_weights * F.l1_loss(fea_fake[i][j], fea_real[i][j].detach())

            # 將其加到總 loss 中
            loss_G = loss_G_GAN + loss_G_id*10 + loss_G_id_psnt*10 + loss_G_cd + loss_G_Feat * 10.0


            if step % 2 == 0:
                loss_G_Rec = F.l1_loss(x_identic_psnt,  x_real.unsqueeze(1)) * 10.0
                loss_G += loss_G_Rec

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

        # 更新進度條
        pbar.set_postfix({
            'loss_G': round(loss_G.item(), 4) if interval == 0 else '-',
            'loss_D': round(loss_D.item(), 4) if interval == 1 else '-'
        })

        if step % save_freq == 0 and interval == 1 and step > 1:
            torch.save({
                'step': step,
                'G': model.G.state_dict(),
                'D_1': model.D_1.state_dict(),
                'D_2': model.D_2.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict()
            }, f'{output_path}/autovc_model_step_{step}.pth')
            print(f"[Step {step}] Model checkpoint saved.")


Step 4796 | G:   5%|▌           | 4796/100000 [20:39:03<409:56:06, 15.50s/it, loss_G=-, loss_D=1.19]


KeyboardInterrupt: 