In [None]:
import torch
import os
import random
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np
from tqdm import tqdm
from autovc_model import AutovcModel
from autovc_dataset import get_dynamic_loader



In [None]:
# Data / training configuration.
root_path = 'data'            # dataset root (the loader output shows it reading data/train)
total_iterations = 500_000    # number of outer training steps
batch_size = 16

# Checkpoint configuration.
output_path = 'output'
os.makedirs(output_path, exist_ok=True)
save_freq = 10000             # save a checkpoint every `save_freq` outer steps

# Seed every RNG used by the notebook (random.shuffle, torch.randperm,
# model weight init) so that runs are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = AutovcModel(device=device)


In [None]:
# Build the training loader; the model's speaker encoder is handed to the loader
# (presumably to compute the speaker embeddings it yields — confirm in autovc_dataset).
train_loader = get_dynamic_loader(
    root_path=root_path,
    batch_size=batch_size,
    encoder=model.speaker_encoder,
)

Processing Swapping dataset from data/train ...
Finished processing 8631 identity folders.


In [None]:
# Generator optimizer.
parameters_G = model.G.parameters()
optimizer_G = torch.optim.Adam(parameters_G, lr=0.0005, betas=(0.0, 0.999))

# One optimizer updates BOTH discriminators, D_1 and D_2, so each of them is
# trained (listing D_1 twice would leave D_2 at its initial weights).
parameters_D = list(model.D_1.parameters()) + list(model.D_2.parameters())
optimizer_D = torch.optim.Adam(parameters_D, lr=0.0001, betas=(0.0, 0.999))


In [None]:
# A shuffled permutation of the batch indices 0..batch_size-1
# (the training loop in this notebook draws its own torch.randperm instead).
randindex = list(range(batch_size))
random.shuffle(randindex)

# Loss modules; the training loop uses the functional F.mse_loss / F.l1_loss forms.
mse_loss, mae_loss = nn.MSELoss(), nn.L1Loss()

In [None]:
# Alternating adversarial training loop.
# Each outer `step` runs two sub-iterations, each on a fresh batch:
#   interval == 0 -> one generator (G) update
#   interval == 1 -> one update of the discriminators (D_1, D_2)
# Even steps use self-reconstruction (target embedding == source embedding);
# odd steps shuffle the speaker embeddings across the batch.

# Feature-matching weights, hoisted out of the loop (values unchanged).
n_layers_D = 5   # NOTE(review): assumed to match the discriminator depth — confirm in autovc_model
num_D = 2        # number of discriminators (D_1, D_2)
feat_weights = 5.0 / (n_layers_D + 1)
D_weights = 1.0 / num_D

data_iter = iter(train_loader)
pbar = tqdm(range(total_iterations), desc="Training", ncols=100)

# Latest value of each loss, so the progress bar keeps showing both instead of
# replacing one of them with '-' on every sub-iteration.
last_loss_G = '-'
last_loss_D = '-'

for step in pbar:
    for interval in range(2):
        pbar.set_description(f"Step {step} | {'D' if interval else 'G'}")

        # Restart the loader when exhausted so training can span many epochs.
        try:
            x_real, emb_org = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            x_real, emb_org = next(data_iter)

        x_real = x_real.to(device)
        emb_org = emb_org.to(device)

        if step % 2 == 0:
            # Self-reconstruction: target speaker == source speaker.
            emb_trg = emb_org
        else:
            # Conversion: shuffle embeddings across the batch
            # (a sample may still be paired with its own embedding).
            rand_idx = torch.randperm(emb_org.size(0))
            emb_trg = emb_org[rand_idx]

        if interval == 1:  # train D
            model.G.eval()
            model.D_1.train()
            model.D_2.train()

            # Fakes are generated without gradient tracking: this step only updates D.
            with torch.no_grad():
                _, x_fake, _ = model.G(x_real, emb_org, emb_trg)

            pred_fake_1 = model.D_1(x_fake)
            pred_fake_2 = model.D_2(x_fake)
            pred_real_1 = model.D_1(x_real)
            pred_real_2 = model.D_2(x_real)

            # Least-squares GAN objective: real -> 1, fake -> 0.
            loss_D_fake = (F.mse_loss(pred_fake_1, torch.zeros_like(pred_fake_1))
                           + F.mse_loss(pred_fake_2, torch.zeros_like(pred_fake_2)))
            loss_D_real = (F.mse_loss(pred_real_1, torch.ones_like(pred_real_1))
                           + F.mse_loss(pred_real_2, torch.ones_like(pred_real_2)))
            loss_D = loss_D_fake + loss_D_real

            # zero_grad also clears gradients the G step left in the
            # discriminator parameters this optimizer owns.
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()
            last_loss_D = round(loss_D.item(), 4)

        else:  # train G
            model.G.train()
            model.D_1.eval()
            model.D_2.eval()

            x_identic, x_identic_psnt, code_real = model.G(x_real, emb_org, emb_trg)

            # Adversarial loss. It must be computed WITH gradients so they flow
            # through D back into G; under torch.no_grad() it would be a constant
            # that contributes nothing to G's update.
            pred_fake_1 = model.D_1(x_identic_psnt)
            pred_fake_2 = model.D_2(x_identic_psnt)
            loss_G_GAN = (F.mse_loss(pred_fake_1, torch.ones_like(pred_fake_1))
                          + F.mse_loss(pred_fake_2, torch.ones_like(pred_fake_2)))

            # Identity (reconstruction) losses against the input x_real.
            # NOTE(review): on odd steps emb_trg is a shuffled embedding, yet these
            # losses still compare the converted output to x_real — confirm intended.
            loss_G_id = F.mse_loss(x_real, x_identic)
            loss_G_id_psnt = F.mse_loss(x_real, x_identic_psnt)

            # Code consistency: pass the post-net output back through G with
            # emb_trg=None (presumably returns only the content code) and compare.
            code_reconst = model.G(x_identic_psnt, emb_org, None)
            loss_G_cd = F.l1_loss(code_real, code_reconst)

            # Feature-matching loss: real features are fixed targets (no grad);
            # fake features keep their graph so the loss actually trains G.
            with torch.no_grad():
                real_feat_1 = model.D_1.get_feature(x_real)
                real_feat_2 = model.D_2.get_feature(x_real)
            fake_feat_1 = model.D_1.get_feature(x_identic_psnt)
            fake_feat_2 = model.D_2.get_feature(x_identic_psnt)

            fea_real = [real_feat_1, real_feat_2]
            fea_fake = [fake_feat_1, fake_feat_2]

            loss_G_Feat = 0
            for i in range(num_D):
                # The last feature entry is skipped (presumably D's final output,
                # which the adversarial loss already covers).
                for j in range(len(fea_fake[i]) - 1):
                    loss_G_Feat += D_weights * feat_weights * F.l1_loss(fea_fake[i][j], fea_real[i][j])

            # Weighted sum of all generator terms.
            loss_G = loss_G_GAN + loss_G_id * 10 + loss_G_id_psnt * 10 + loss_G_cd + loss_G_Feat * 10.0

            # Extra L1 reconstruction term on self-reconstruction steps only.
            if step % 2 == 0:
                loss_G_Rec = F.l1_loss(x_identic_psnt, x_real) * 10.0
                loss_G += loss_G_Rec

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()
            last_loss_G = round(loss_G.item(), 4)

        # Show the most recent value of both losses.
        pbar.set_postfix({
            'loss_G': last_loss_G,
            'loss_D': last_loss_D
        })

        # Checkpoint after the D update of every `save_freq`-th step.
        if step % save_freq == 0 and interval == 1 and step > 1:
            torch.save({
                'step': step,
                'G': model.G.state_dict(),
                'D_1': model.D_1.state_dict(),
                'D_2': model.D_2.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict()
            }, f'{output_path}/autovc_model_step_{step}.pth')
            print(f"[Step {step}] Model checkpoint saved.")


Step 10000 | D:   2%|▏        | 10000/500000 [2:21:55<115:36:35,  1.18it/s, loss_G=-, loss_D=0.0038]

[Step 10000] Model checkpoint saved.


Step 10001 | G:   2%|▏        | 10001/500000 [2:21:59<261:05:48,  1.92s/it, loss_G=-, loss_D=0.0038]

[Step 10000] Output samples saved.


Step 20000 | D:   4%|▎        | 20000/500000 [4:43:58<113:31:28,  1.17it/s, loss_G=-, loss_D=0.0071]

[Step 20000] Model checkpoint saved.


Step 20001 | G:   4%|▎        | 20001/500000 [4:44:02<255:04:56,  1.91s/it, loss_G=-, loss_D=0.0071]

[Step 20000] Output samples saved.


Step 30000 | D:   6%|▌        | 30000/500000 [7:06:08<111:33:06,  1.17it/s, loss_G=-, loss_D=0.0047]

[Step 30000] Model checkpoint saved.


Step 30001 | G:   6%|▌        | 30001/500000 [7:06:11<250:27:38,  1.92s/it, loss_G=-, loss_D=0.0047]

[Step 30000] Output samples saved.


Step 40000 | D:   8%|▊         | 40000/500000 [9:28:21<109:24:45,  1.17it/s, loss_G=-, loss_D=0.001]

[Step 40000] Model checkpoint saved.


Step 40001 | G:   8%|▊         | 40001/500000 [9:28:25<250:13:19,  1.96s/it, loss_G=-, loss_D=0.001]

[Step 40000] Output samples saved.


Step 50000 | D:  10%|▊       | 50000/500000 [11:50:42<107:10:13,  1.17it/s, loss_G=-, loss_D=0.0012]

[Step 50000] Model checkpoint saved.


Step 50001 | G:  10%|▊       | 50001/500000 [11:50:46<245:02:59,  1.96s/it, loss_G=-, loss_D=0.0012]

[Step 50000] Output samples saved.


Step 60000 | D:  12%|▉       | 60000/500000 [14:13:11<105:19:06,  1.16it/s, loss_G=-, loss_D=0.0004]

[Step 60000] Model checkpoint saved.


Step 60001 | G:  12%|▉       | 60001/500000 [14:13:15<234:42:24,  1.92s/it, loss_G=-, loss_D=0.0004]

[Step 60000] Output samples saved.


Step 70000 | D:  14%|█       | 70000/500000 [16:35:50<104:04:18,  1.15it/s, loss_G=-, loss_D=0.0005]

[Step 70000] Model checkpoint saved.


Step 70001 | G:  14%|█       | 70001/500000 [16:35:54<231:55:42,  1.94s/it, loss_G=-, loss_D=0.0005]

[Step 70000] Output samples saved.


Step 80000 | D:  16%|█▍       | 80000/500000 [18:58:37<99:55:19,  1.17it/s, loss_G=-, loss_D=0.0008]

[Step 80000] Model checkpoint saved.


Step 80001 | G:  16%|█▎      | 80001/500000 [18:58:41<222:01:43,  1.90s/it, loss_G=-, loss_D=0.0008]

[Step 80000] Output samples saved.


Step 90000 | D:  18%|█▌       | 90000/500000 [21:21:14<98:06:42,  1.16it/s, loss_G=-, loss_D=0.0002]

[Step 90000] Model checkpoint saved.


Step 90001 | G:  18%|█▍      | 90001/500000 [21:21:17<221:59:21,  1.95s/it, loss_G=-, loss_D=0.0002]

[Step 90000] Output samples saved.


Step 100000 | D:  20%|█▍     | 100000/500000 [23:44:28<95:28:00,  1.16it/s, loss_G=-, loss_D=0.0002]

[Step 100000] Model checkpoint saved.


Step 100001 | G:  20%|█▏    | 100001/500000 [23:44:32<217:04:02,  1.95s/it, loss_G=-, loss_D=0.0002]

[Step 100000] Output samples saved.


Step 110000 | D:  22%|█▌     | 110000/500000 [26:07:18<93:01:30,  1.16it/s, loss_G=-, loss_D=0.0002]

[Step 110000] Model checkpoint saved.


Step 110001 | G:  22%|█▎    | 110001/500000 [26:07:22<211:22:23,  1.95s/it, loss_G=-, loss_D=0.0002]

[Step 110000] Output samples saved.


Step 120000 | D:  24%|█▋     | 120000/500000 [28:30:01<91:27:05,  1.15it/s, loss_G=-, loss_D=0.0002]

[Step 120000] Model checkpoint saved.


Step 120001 | G:  24%|█▍    | 120001/500000 [28:30:04<203:02:11,  1.92s/it, loss_G=-, loss_D=0.0002]

[Step 120000] Output samples saved.


Step 130000 | D:  26%|█▊     | 130000/500000 [30:52:50<88:32:31,  1.16it/s, loss_G=-, loss_D=0.0001]

[Step 130000] Model checkpoint saved.


Step 130001 | G:  26%|█▌    | 130001/500000 [30:52:54<195:17:45,  1.90s/it, loss_G=-, loss_D=0.0001]

[Step 130000] Output samples saved.


Step 140000 | D:  28%|█▉     | 140000/500000 [33:15:29<85:17:21,  1.17it/s, loss_G=-, loss_D=0.0001]

[Step 140000] Model checkpoint saved.


Step 140001 | G:  28%|█▋    | 140001/500000 [33:15:33<193:53:47,  1.94s/it, loss_G=-, loss_D=0.0001]

[Step 140000] Output samples saved.


Step 150000 | D:  30%|██     | 150000/500000 [35:38:15<83:32:45,  1.16it/s, loss_G=-, loss_D=0.0001]

[Step 150000] Model checkpoint saved.
