In [1]:
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from scipy.io.wavfile import write
import commons
import utils
from data_utils import (
  TextAudioLoader,
  TextAudioCollate,
  DistributedBucketSampler
)
from models import NFTAudio
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from torchmetrics import SignalNoiseRatio
import math
from torch_audiomentations import Compose, Gain, PolarityInversion, AddBackgroundNoise, PitchShift, ApplyImpulseResponse, AddColoredNoise,HighPassFilter,LowPassFilter, Shift
from pathlib import Path
import torchaudio

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################

  '"torchaudio.USE_SOUNDFILE_LEGACY_INTERFACE" flag is deprecated and will be removed in 0.9.0. '


In [2]:
# --- Which checkpoint to evaluate -------------------------------------------
run_name = 'five_agmt_run'
step = 70000
model_path = './logs/{}/NFTAudio_{}.pth'.format(run_name, step)

# --- Dimensions handed to the NFTAudio constructor ---------------------------
wave_length = 8192
msg_length = 100

# Build the model on GPU 0, then restore the trained weights into it.
nft_model = NFTAudio(wave_length, msg_length).cuda(0)
# map_location pins every stored tensor to GPU 0 while the file is read. Without
# it torch.load restores tensors onto the device they were saved from, which
# raises when that device is not present on this machine.
check_point = torch.load(model_path, map_location='cuda:0')
nft_model.load_state_dict(check_point['model'])
# eval() puts the module in inference mode and returns the module itself, so as
# the last expression of the cell it also displays the architecture.
nft_model.eval()

NFTAudio(
  (watermarker): Message_Encoder(
    (conv_pre): Conv1d(192, 512, kernel_size=(7,), stride=(1,), padding=(3,))
    (ups): ModuleList(
      (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
      (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
      (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (resblocks): ModuleList(
      (0): ResBlock2(
        (convs): ModuleList(
          (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        )
      )
      (1): ResBlock2(
        (convs): ModuleList(
          (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
          (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        )
      )
      (2): ResBlo

In [3]:
# Names needed by this cell only.
import json
from utils import HParams

# --- Hyper-parameters: read the JSON file and wrap it in an HParams object ----
config_save_path = './configs/ljs_base.json'
with open(config_save_path, "r") as f:
    data = f.read()
config = json.loads(data)
hps = HParams(**config)

# --- Overrides applied on top of the file for this evaluation run ------------
hps.train.batch_size = 2
hps.train.segment_size = 8192
hps.train.log_interval = 500
hps.train.eval_interval = 10000

# --- Validation data pipeline --------------------------------------------------
# msg_dim is given to both the dataset and the collate function.
msg_dim = 100
collate_fn = TextAudioCollate(msg_dim)
eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data, msg_dim)
# NOTE(review): shuffle=True with no fixed seed means each run draws the
# validation items in a different order.
eval_loader = DataLoader(
    eval_dataset,
    batch_size=hps.train.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [25]:

# Select torchaudio's "sox_io" backend for audio file I/O.
# NOTE(review): presumably required so the file-based augmentations used below
# (background noise, impulse responses) can read their audio — TODO confirm.
torchaudio.set_audio_backend("sox_io")

def audio_ptb(input_audio, ptb_type, sample_rate=22050):
    """Apply one named perturbation to a batch of waveforms.

    A new axis is inserted at dim 1 before the perturbation is applied, and the
    result is clamped to [-1, 1] whatever the perturbation.

    Args:
        input_audio: waveform tensor. Presumably (batch, samples), so that the
            unsqueeze yields a (batch, channels, samples) layout — confirm
            against callers.
        ptb_type: one of 'none', 'gain', 'noise', 'inversion', 'shift',
            'env_background', 'music_background', 'rir'.
        sample_rate: sample rate in Hz handed to the augmentation transforms.
            Defaults to 22050, the value that used to be hard-coded.

    Returns:
        The perturbed tensor, still carrying the axis inserted at dim 1.

    Raises:
        ValueError: if ptb_type is not one of the supported names.
    """
    # Directories scanned by AddBackgroundNoise for each background variant.
    background_dirs = {
        'env_background': './background_noise/environment/',
        'music_background': './background_noise/music/',
    }

    input_audio = input_audio.unsqueeze(1)

    if ptb_type == 'none':
        new_audio = input_audio
    elif ptb_type == 'gain':
        agmt = Gain(min_gain_in_db=9, max_gain_in_db=10, p=1)
        new_audio = agmt(input_audio, sample_rate=sample_rate)
    elif ptb_type == 'noise':
        # Additive Gaussian noise with standard deviation 0.01.
        new_audio = input_audio + 0.01 * torch.randn_like(input_audio)
    elif ptb_type == 'inversion':
        agmt = PolarityInversion(p=1)
        new_audio = agmt(input_audio, sample_rate=sample_rate)
    elif ptb_type == 'shift':
        agmt = Shift(min_shift=0.1, max_shift=0.5, shift_unit="fraction", rollover=True, p=1)
        new_audio = agmt(input_audio, sample_rate=sample_rate)
    elif ptb_type in background_dirs:
        # The second positional argument is kept exactly as before
        # (presumably min_snr_in_db — confirm for the installed version).
        agmt = AddBackgroundNoise(background_dirs[ptb_type], 1, p=1.0)
        new_audio = agmt(input_audio, sample_rate=sample_rate)
    elif ptb_type == 'rir':
        rir_dir = './background_noise/rir_audios/'
        agmt = ApplyImpulseResponse(p=1, ir_paths=rir_dir, sample_rate=sample_rate)
        new_audio = agmt(input_audio, sample_rate=sample_rate)
    else:
        # Previously an unknown name fell through every `if` and crashed on the
        # return statement with an unbound local variable.
        raise ValueError("unsupported ptb_type: {!r}".format(ptb_type))

    return torch.clamp(new_audio, -1, 1)

def SNR(watermarked, original):
    """Signal-to-noise ratio of `watermarked` measured against `original`.

    Args:
        watermarked: tensor passed as the metric's first argument.
        original: tensor passed as the metric's second argument.
            The call sites in this notebook pass two 1-D waveforms of equal
            length taken from the same batch item.

    Returns:
        The tensor produced by torchmetrics' SignalNoiseRatio for this pair
        (a value in dB according to the torchmetrics documentation).
    """
    # Put the metric on the same device as the input instead of unconditionally
    # on CUDA, so the helper also works for tensors that live elsewhere.
    snr = SignalNoiseRatio().to(watermarked.device)
    return snr(watermarked, original)

# Evaluation loop.
# For up to 21 batches, the item at `wav_index` is watermarked segment by
# segment, the message is decoded from both the clean and the perturbed
# watermarked audio, SNR is measured, and three wav files are written.

l1_loss  = torch.nn.L1Loss()
wav_index = 0                          # which item of each batch is evaluated
amplitude = np.iinfo(np.int16).max     # scale factor from float audio to int16 PCM
ori_errors = []                        # per-segment message error, clean watermarked audio
ptb_errors = []                        # per-segment message error, perturbed watermarked audio
SNR_list = []                          # per-item SNR, watermarked vs. original
ptb_SNR_list = []                      # per-item SNR, perturbed watermarked vs. original

# --- Loop-invariant settings (previously recomputed on every batch) ----------
ptb_type = 'rir' # 'none', 'noise', 'rir'
segment_len = wave_length              # samples per watermark segment; the model was built with this length
sample_rate = 22050                    # sample rate written into the wav headers
model_dir = os.path.join("./results", run_name)
log_path = '{}/evaluations/{}_ptb/'.format(model_dir, ptb_type)
os.makedirs(log_path, exist_ok=True)

_pcm_limits = np.iinfo(np.int16)


def _to_int16(scaled):
    """Clip an amplitude-scaled float array to the int16 range, then cast.

    The watermarked signal is raw model output and is not clamped to [-1, 1];
    casting out-of-range floats straight to int16 would overflow.
    """
    return np.clip(scaled, _pcm_limits.min, _pcm_limits.max).astype(np.int16)


def _mel(waveform):
    """Mel spectrogram of `waveform` using the settings stored in hps.data."""
    return mel_spectrogram_torch(
        waveform.float(),
        hps.data.filter_length,
        hps.data.n_mel_channels,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        hps.data.mel_fmin,
        hps.data.mel_fmax,
    )


with torch.no_grad():
    for batch_idx, (message, spec, spec_lengths, wav, wav_lengths) in enumerate(eval_loader):
        wav, wav_lengths = wav.cuda(0), wav_lengths.cuda(0)
        message = message.cuda(0)
        # Flatten to (batch, samples) while always keeping the batch axis; an
        # argument-less squeeze() would also drop it when a batch holds one item.
        # assumes the collate yields wav as (batch, 1, samples) or (batch, samples) — TODO confirm
        wav = wav.reshape(wav.size(0), -1)

        # Start from a copy of the input; segments are overwritten below.
        watermarked_wav = wav.clone()
        # Rounded reference message of the evaluated item.
        target_bits = message[wav_index].squeeze().round()

        # Number of whole segments inside the evaluated item's length.
        watermark_times = wav_lengths[wav_index].item() // segment_len
        for i in range(watermark_times):
            seg = slice(i * segment_len, (i + 1) * segment_len)
            # The same segment / message is tiled across the batch dimension.
            wav_slice = wav[wav_index, seg].repeat(wav.size(0), 1)
            message_new = message[wav_index, :].repeat(wav.size(0), 1)

            recon_wav, recon_msg, _, _ = nft_model(wav_slice, message_new)
            watermarked_wav[wav_index, seg] = recon_wav[wav_index, :]

            # Sum of absolute differences between the rounded messages.
            err = torch.abs(target_bits - recon_msg[wav_index].squeeze().round()).sum()
            ori_errors.append(err.item())

        # Mel spectrograms of the original, the watermarked signal and their
        # difference. They are not consumed inside this cell; they stay
        # available as notebook variables after the loop.
        ori_row = wav[wav_index:wav_index+1, :]
        wtm_row = watermarked_wav[wav_index:wav_index+1, :]
        entire_ori_mel = _mel(ori_row)
        recon_mel = _mel(wtm_row)
        delta_mel = _mel(wtm_row.float() - ori_row.float())

        # Perturb the watermarked batch, then decode the message from each segment.
        ptb_watermarked_wav = audio_ptb(watermarked_wav, ptb_type)
        for i in range(watermark_times):
            ptb_wav_slice = ptb_watermarked_wav[wav_index:wav_index+1, :, i*segment_len:(i+1)*segment_len]
            ptb_recon_msg = nft_model.reconstructor(ptb_wav_slice)
            # ptb_wav_slice holds a single item, so its decoded message sits at
            # row 0 regardless of wav_index.
            # assumes the reconstructor output is batch-first — TODO confirm
            ptb_err = torch.abs(target_bits - ptb_recon_msg[0].squeeze().round()).sum()
            ptb_errors.append(ptb_err.item())
        # Drop only the axis audio_ptb inserted at dim 1.
        ptb_watermarked_wav = ptb_watermarked_wav.squeeze(1)

        # Running means over all segments seen so far.
        print(np.mean(ori_errors), np.mean(ptb_errors))

        snr = SNR(watermarked_wav[wav_index, :], wav[wav_index, :])
        ptb_snr = SNR(ptb_watermarked_wav[wav_index, :], wav[wav_index, :])
        SNR_list.append(snr.item())
        ptb_SNR_list.append(ptb_snr.item())

        # Float arrays on the int16 amplitude scale; left unclipped so that
        # `delta` is the exact difference.
        save_ori_wav = wav[wav_index].cpu().detach().numpy() * amplitude
        save_wtm_wav = watermarked_wav[wav_index].cpu().detach().numpy() * amplitude
        save_ptb_wtm_wav = ptb_watermarked_wav[wav_index].cpu().detach().numpy() * amplitude
        delta = save_wtm_wav - save_ori_wav

        write("{}/original_{}.wav".format(log_path, batch_idx), sample_rate, _to_int16(save_ori_wav))
        write("{}/watermarked_{}.wav".format(log_path, batch_idx), sample_rate, _to_int16(save_wtm_wav))
        write("{}/{}_ptb_watermarked_{}.wav".format(log_path, ptb_type,batch_idx), sample_rate, _to_int16(save_ptb_wtm_wav))

        # Stop after batch index 20.
        if batch_idx>=20:
            break
# Waveform overlays for the LAST item processed by the evaluation loop: these
# statements sit outside the loop, so they use whatever the loop variables
# held when it ended.
x = np.arange(0, delta.shape[0])

# (output path template, third curve, legend label of the third curve)
figure_specs = [
    ('{}/wav_comparison_{}.pdf', delta, 'delta'),
    ('{}/ptb_vs_nonptb_{}.pdf', save_ptb_wtm_wav, '{}_Watermarked'.format(ptb_type)),
]
for path_template, third_curve, third_label in figure_specs:
    fig, ax = plt.subplots()
    ax.plot(x, save_ori_wav, label = 'Original')
    ax.plot(x, save_wtm_wav, label = 'Watermarked')
    ax.plot(x, third_curve, label = third_label)
    # Axis labels so each saved figure can be read on its own.
    ax.set_xlabel('Sample index')
    ax.set_ylabel('Amplitude (int16 scale)')
    ax.legend()
    fig.savefig(path_template.format(log_path, batch_idx),bbox_inches='tight', dpi=600)
    # Close explicitly so figures do not accumulate in the kernel.
    plt.close(fig)

# Write the collected metrics next to the audio files, one value per line.
np.savetxt("{}/ber.csv".format(log_path), ori_errors, delimiter =", ", fmt ='%1.9f')
np.savetxt("{}/{}_ptb_ber.csv".format(log_path, ptb_type), ptb_errors, delimiter =", ", fmt ='%1.9f')
np.savetxt("{}/snr.csv".format(log_path), SNR_list, delimiter =", ", fmt ='%1.9f')
# The SNR of the perturbed audio is collected in the loop as well, but was never
# written out; save it alongside the other metrics.
np.savetxt("{}/{}_ptb_snr.csv".format(log_path, ptb_type), ptb_SNR_list, delimiter =", ", fmt ='%1.9f')

3.0 46.0
3.409090909090909 48.20454545454545
3.2535211267605635 48.690140845070424
2.6842105263157894 47.14736842105263
3.1875 48.848214285714285
3.473684210526316 48.88721804511278
3.698717948717949 48.333333333333336
3.9943181818181817 49.00568181818182
3.837696335078534 49.146596858638745
4.033333333333333 50.01904761904762
4.097777777777778 50.08
4.171428571428572 49.76326530612245
4.07089552238806 49.649253731343286
4.132404181184669 49.57142857142857
4.287539936102236 49.20447284345048
4.27027027027027 49.174174174174176
4.111731843575419 48.98882681564246
4.0557029177718835 48.644562334217504
4.002518891687657 48.40302267002519
3.9683698296836982 48.57664233576642
4.0233644859813085 48.36682242990654
