In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler

import sys
sys.path.append("bert_vits2/")

import bert_vits2.utils as utils
from bert_vits2.data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
)
from bert_vits2.losses import WavLMLoss
from toolbox import build_models_noise, build_optims

In [None]:
# Shell escape: run bert_gen.py with --mode clean before the dataset is built.
# NOTE(review): presumably this produces the BERT features the data loader reads
# for the unperturbed audio -- confirm against bert_gen.py.
!python bert_gen.py --mode clean

In [2]:
# --- Experiment configuration ---------------------------------------------
device = torch.device("cuda")
model_name = "BERT_VITS2"
dataset_name = "LibriTTS"
mode = "SPEC"

# Hyper-parameters are read from the JSON config named after the
# dataset/model pair, then adjusted for the perturbation phase.
config_path = f"bert_vits2/configs/{dataset_name.lower()}_{model_name.lower()}.json"
hps = utils.get_hparams_from_file(config_path)
hps.train.batch_size = 27
# Repo-relative checkpoint directory, matching the one the training section
# of this notebook uses, instead of a machine-specific absolute path.
hps.model_dir = "checkpoints/base_models"

# Seed both the CPU and the CUDA random generators.
torch.manual_seed(hps.train.seed)
torch.cuda.manual_seed(hps.train.seed)

# shuffle=False / drop_last=False: batches arrive in a fixed order and none is
# dropped; the perturbation cell stores one noise entry per batch index.
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
    train_dataset,
    num_workers=4,
    shuffle=False,
    collate_fn=collate_fn,
    batch_size=hps.train.batch_size,
    pin_memory=True,
    drop_last=False,
)


[32m12-24 13:01:34[0m [1mINFO     [0m| data_utils.py:61 | Init dataset...


100%|██████████| 108/108 [00:00<00:00, 35777.97it/s]

[32m12-24 13:01:34[0m [1mINFO     [0m| data_utils.py:76 | skipped: 0, total: 108





1. Initialize models and generate perturbation

In [3]:
### Build models and optimizers
nets = build_models_noise(hps, device)
net_g, net_d, net_wd, net_dur_disc = nets

optims = build_optims(hps, nets)
optim_g, optim_d, optim_wd, optim_dur_disc = optims

# Optimizer state is skipped unless the config explicitly provides the flag.
skip_optimizer = hps.train.skip_optimizer if "skip_optimizer" in hps.train else True

# Fallback learning rates, overwritten by the values stored in the checkpoints.
dur_resume_lr = hps.train.learning_rate
wd_resume_lr = hps.train.learning_rate

### Restore the latest checkpoint of each network from hps.model_dir
# Duration discriminator: the returned optimizer is deliberately not rebound.
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
    net_dur_disc,
    optim_dur_disc,
    skip_optimizer=skip_optimizer,
)
if not optim_dur_disc.param_groups[0].get("initial_lr"):
    optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr

# Generator and discriminator.
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
    net_g,
    optim_g,
    skip_optimizer=skip_optimizer,
)
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
    net_d,
    optim_d,
    skip_optimizer=skip_optimizer,
)
if not optim_g.param_groups[0].get("initial_lr"):
    optim_g.param_groups[0]["initial_lr"] = g_resume_lr
if not optim_d.param_groups[0].get("initial_lr"):
    optim_d.param_groups[0]["initial_lr"] = d_resume_lr

# The step counter is parsed from the generator checkpoint's file name.
global_step = int(utils.get_steps(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth")))

# WavLM discriminator.
_, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
    net_wd,
    optim_wd,
    skip_optimizer=skip_optimizer,
)
if not optim_wd.param_groups[0].get("initial_lr"):
    optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr

# Clamp only after the final load_checkpoint call: every call rebinds
# epoch_str, so an earlier clamp would be undone by the WD load.
epoch_str = max(epoch_str, 1)

wl = WavLMLoss(
    hps.model.slm.model,
    net_wd,
    hps.data.sampling_rate,
    hps.model.slm.sr,
).to(device)

ERROR:bert_vits2.utils:emb_g.weight is not in the checkpoint


In [4]:
from protect import perturb

# One noise slot per batch of the (unshuffled) training loader.
noises = [None] * len(train_loader)
max_epoch = 200
epsilon = 8 / 255
alpha = epsilon / 10

# Freeze every network: gradients are not needed for the model weights here.
for frozen_module in (net_g, net_d, net_dur_disc, net_wd, wl):
    for param in frozen_module.parameters():
        param.requires_grad = False

# Pass the models explicitly rather than relying on whatever value an earlier
# cell's throw-away `_` unpacking left behind.
# NOTE(review): the order mirrors the list given to train() later in this
# notebook; perturb's own unpacking is not visible here -- confirm.
perturb_nets = [net_g, net_d, net_dur_disc, net_wd, wl]

### Begin to generate perturbation...
for batch_index, batch_data in enumerate(train_loader):
    # NOTE(review): the meaning of the positional constant 10 is defined by
    # perturb's signature, which is not visible in this notebook.
    loss, noises[batch_index] = perturb(hps, perturb_nets, batch_data,
                                        epsilon, alpha, max_epoch, 10, device)

    print(f"Batch {batch_index}: Loss {loss}")

    # Release cached GPU memory between batches.
    torch.cuda.empty_cache()

100%|██████████| 200/200 [03:01<00:00,  1.10it/s]


Batch 0: Loss {'loss_mel': '18.554979', 'loss_nr': '172.343292', 'loss_kl': '18.774075'}


100%|██████████| 200/200 [03:04<00:00,  1.09it/s]


Batch 1: Loss {'loss_mel': '18.611267', 'loss_nr': '170.562912', 'loss_kl': '18.865513'}


100%|██████████| 200/200 [03:12<00:00,  1.04it/s]


Batch 2: Loss {'loss_mel': '18.689875', 'loss_nr': '172.696991', 'loss_kl': '18.187258'}


100%|██████████| 200/200 [03:05<00:00,  1.08it/s]

Batch 3: Loss {'loss_mel': '18.502092', 'loss_nr': '175.578339', 'loss_kl': '19.883799'}





In [5]:
# Persist the list of per-batch noises so later steps can load it from disk.
noise_dir = f"./checkpoints/{dataset_name}/noises/"
noise_save_path = f"./checkpoints/{dataset_name}/noises/{model_name}_{mode}_{dataset_name}.noise"

os.makedirs(noise_dir, exist_ok=True)  # create the target folder on first use
torch.save(noises, noise_save_path)
print(f"Save the noise to {noise_save_path}!")

Save the noise to ./checkpoints/LibriTTS/noises/BERT_VITS2_SPEC_LibriTTS.noise!


2. Save audio and preprocess

In [7]:
# Save audio
!python save_audio.py --mode SPEC --model BERT_VITS2 --dataset LibriTTS

[32m12-24 13:15:55[0m [1mINFO     [0m| data_utils.py:61 | Init dataset...
100%|██████████████████████████████████████| 108/108 [00:00<00:00, 40165.35it/s]
[32m12-24 13:15:55[0m [1mINFO     [0m| data_utils.py:76 | skipped: 0, total: 108
The noise path is checkpoints/LibriTTS/noises/BERT_VITS2_SPEC_LibriTTS.noise
[0m

In [None]:
!python bert_gen.py --mode SPEC

3. Train models on the protected dataset

In [9]:
# Training-phase overrides of the hyper-parameters.
hps.train.batch_size = 64
if mode != "clean":
    hps.data.training_files = "filelists/libritts_train_asr.txt.cleaned"
hps.model_dir = "checkpoints/base_models"

# Fail early, with a readable message, if any of the four checkpoint families
# loaded by the following cell is missing from the model directory.
checkpoint_files = os.listdir(hps.model_dir)
for prefix in ("G_", "D_", "DUR_", "WD_"):
    assert any(
        name.startswith(prefix) and name.endswith(".pth") for name in checkpoint_files
    ), f"No {prefix}*.pth checkpoint found in {hps.model_dir}"

torch.manual_seed(hps.train.seed)
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
    train_dataset,
    num_workers=4,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
    batch_size=hps.train.batch_size,
    drop_last=False
)

[32m12-24 13:16:27[0m [1mINFO     [0m| data_utils.py:61 | Init dataset...


100%|██████████| 108/108 [00:00<00:00, 18512.60it/s]

[32m12-24 13:16:27[0m [1mINFO     [0m| data_utils.py:76 | skipped: 0, total: 108





In [10]:
from toolbox import build_models, build_optims, build_schedulers

### Build fresh models and optimizers for training on the protected data
models = build_models(hps, device)
net_g, net_d, net_wd, net_dur_disc = models

optims = build_optims(hps, models)
optim_g, optim_d, optim_wd, optim_dur_disc = optims

# Optimizer state is skipped unless the config explicitly provides the flag.
skip_optimizer = hps.train.skip_optimizer if "skip_optimizer" in hps.train else True

# Fallback learning rates, overwritten by the values stored in the checkpoints.
dur_resume_lr = hps.train.learning_rate
wd_resume_lr = hps.train.learning_rate

### Restore the latest checkpoint of each network from hps.model_dir
# Duration discriminator: the returned optimizer is deliberately not rebound.
_, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
    net_dur_disc,
    optim_dur_disc,
    skip_optimizer=skip_optimizer,
)
if not optim_dur_disc.param_groups[0].get("initial_lr"):
    optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr

# Generator and discriminator.
_, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
    net_g,
    optim_g,
    skip_optimizer=skip_optimizer,
)
_, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
    net_d,
    optim_d,
    skip_optimizer=skip_optimizer,
)
if not optim_g.param_groups[0].get("initial_lr"):
    optim_g.param_groups[0]["initial_lr"] = g_resume_lr
if not optim_d.param_groups[0].get("initial_lr"):
    optim_d.param_groups[0]["initial_lr"] = d_resume_lr

# The step counter is parsed from the generator checkpoint's file name.
global_step = int(utils.get_steps(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth")))

# WavLM discriminator.
_, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
    utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
    net_wd,
    optim_wd,
    skip_optimizer=skip_optimizer,
)
if not optim_wd.param_groups[0].get("initial_lr"):
    optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr

# Clamp only after the final load_checkpoint call: every call rebinds
# epoch_str, so an earlier clamp would be undone by the WD load and an
# unclamped value would reach build_schedulers below.
epoch_str = max(epoch_str, 1)

schedulers = build_schedulers(hps, optims, epoch_str)
scheduler_g, scheduler_d, scheduler_wd, scheduler_dur_disc = schedulers

scaler = GradScaler(enabled=hps.train.bf16_run)

wl = WavLMLoss(
    hps.model.slm.model,
    net_wd,
    hps.data.sampling_rate,
    hps.model.slm.sr,
).to(device)

ERROR:bert_vits2.utils:emb_g.weight is not in the checkpoint


In [11]:
import time
from train import train

# Train for hps.train.epochs epochs, logging the four losses and the elapsed
# wall-clock time after each one, then save the final generator weights.
start_time = time.time()
for epoch in range(1, hps.train.epochs + 1):
    loss = train(
        hps,
        [net_g, net_d, net_dur_disc, net_wd, wl],
        [optim_g, optim_d, optim_dur_disc, optim_wd],
        train_loader,
        scaler,
        device
    )
    loss_gen_all, loss_disc_all, loss_dur_disc_all, loss_slm = loss

    # Advance every learning-rate scheduler once per epoch.
    for scheduler in (scheduler_g, scheduler_d, scheduler_wd, scheduler_dur_disc):
        scheduler.step()

    # Elapsed time since the loop started, rendered as HH:MM:SS.
    elapsed = time.time() - start_time
    whole_hours, leftover = divmod(elapsed, 3600)
    whole_minutes, whole_seconds = divmod(leftover, 60)
    formatted_time = "{:02d}:{:02d}:{:02d}".format(
        int(whole_hours), int(whole_minutes), int(whole_seconds)
    )
    print(f"[{formatted_time}] Epoch {epoch}: G {loss_gen_all:.6f}, D {loss_disc_all:.6f} "
            f"Dur {loss_dur_disc_all:.6f}, Sim {loss_slm:.6f}")

    # Every 10th epoch of a "clean" run also stores an intermediate checkpoint.
    # NOTE(review): the previous comment said these are "for SEP mode" although
    # the condition tests for "clean" -- confirm which is intended.
    if mode == "clean" and epoch % 10 == 0:
        os.makedirs(f"checkpoints/{dataset_name}/clean", exist_ok=True)
        save_path = f"checkpoints/{dataset_name}/clean/{model_name}_{mode}_{dataset_name}_{epoch}.pth"
        torch.save(net_g.state_dict(), save_path)

# Final generator weights, tagged with the last epoch number.
os.makedirs(f"checkpoints/{dataset_name}", exist_ok=True)
save_path = f"checkpoints/{dataset_name}/{model_name}_{mode}_{dataset_name}_{epoch}.pth"
torch.save(net_g.state_dict(), save_path)

[00:00:05] Epoch 1: G 93.337967, D 4.087115 Dur 0.006502, Sim 0.151331
[00:00:10] Epoch 2: G 88.848785, D 3.143491 Dur 0.002086, Sim 0.072668
[00:00:14] Epoch 3: G 87.998581, D 3.001774 Dur 0.000998, Sim 0.173233
[00:00:19] Epoch 4: G 89.167023, D 2.664698 Dur 0.001019, Sim 0.093984
[00:00:24] Epoch 5: G 88.649010, D 2.619440 Dur 0.000481, Sim 0.098045
[00:00:28] Epoch 6: G 87.663193, D 2.584833 Dur 0.000377, Sim 0.118725
[00:00:32] Epoch 7: G 86.267006, D 2.621395 Dur 0.000244, Sim 0.086182
[00:00:37] Epoch 8: G 84.361595, D 2.902445 Dur 0.000157, Sim 0.067962
[00:00:41] Epoch 9: G 85.292183, D 2.858322 Dur 0.000112, Sim 0.051358
[00:00:46] Epoch 10: G 83.716187, D 2.767478 Dur 0.000097, Sim 0.095324
[00:00:50] Epoch 11: G 85.001938, D 2.339326 Dur 0.000045, Sim 0.058034
[00:00:55] Epoch 12: G 84.627625, D 2.229191 Dur 0.000061, Sim 0.062847
[00:00:59] Epoch 13: G 83.859947, D 2.435069 Dur 0.000020, Sim 0.058307
[00:01:04] Epoch 14: G 82.848976, D 2.857036 Dur 0.000040, Sim 0.088750
[

4. Evaluation

In [12]:
from evaluate import evaluation

# Evaluate the trained generator on the files listed in hps.data.testing_files.
# The captured output of this cell reports MCD, ground-truth vs. synthesized
# WER, and SIM / ASR for the current mode and dataset.
evaluation(hps.data.testing_files, net_g, model_name, dataset_name, mode, device)

  0%|          | 0/26 [00:00<?, ?it/s]100%|██████████| 26/26 [00:07<00:00,  3.55it/s]


Mode SPEC, MCD:  {14.771323384971254}


100%|██████████| 52/52 [00:49<00:00,  1.04it/s]


Mode SPEC: GT WER is 0.000000, Syn WER is 0.996104


100%|██████████| 26/26 [00:01<00:00, 24.35it/s]


Mode SPEC on LibriTTS, SIM 0.204721, ASR 0.23076923.
