In [1]:
import os
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from scipy.io.wavfile import write
import commons
import utils
from data_utils import (
  TextAudioLoader,
  TextAudioCollate,
  DistributedBucketSampler
)
from models import NFTAudio
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from torchmetrics import SignalNoiseRatio
import math

################################################################################
###          (please add 'export KALDI_ROOT=<your_path>' in your $HOME/.profile)
###          (or run as: KALDI_ROOT=<your_path> python <your_script>.py)
################################################################################



In [2]:
run_name = 'noise_agmt_0.02_no_decay_6july'
step = 10000
model_path = './logs/{}/NFTAudio_{}.pth'.format(run_name, step)
wave_length = 8192
msg_length = 100
nft_model = NFTAudio(wave_length, msg_length).cuda(0)
check_point = torch.load(model_path)
# print(check_point)
nft_model.load_state_dict(check_point['model'])
nft_model.eval()

NFTAudio(
  (watermarker): Message_Encoder(
    (conv_pre): Conv1d(192, 512, kernel_size=(7,), stride=(1,), padding=(3,))
    (ups): ModuleList(
      (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
      (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
      (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (resblocks): ModuleList(
      (0): ResBlock2(
        (convs): ModuleList(
          (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        )
      )
      (1): ResBlock2(
        (convs): ModuleList(
          (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
          (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        )
      )
      (2): ResBlo

In [3]:
config_save_path = './configs/ljs_base.json'
with open(config_save_path, "r") as f:
    data = f.read()
import json
from utils import HParams
config = json.loads(data)
hps = HParams(**config)
hps.train.segment_size = 8192
hps.train.eval_interval = 10000
hps.train.log_interval = 500
hps.train.batch_size = 2
msg_dim = 100
collate_fn = TextAudioCollate(msg_dim)
eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data, msg_dim)
eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=True,
    batch_size=hps.train.batch_size, pin_memory=True,
    drop_last=False, collate_fn=collate_fn)
print()

In [5]:

l1_loss  = torch.nn.L1Loss()
wav_index = 0
amplitude = np.iinfo(np.int16).max
Total_errors = []
SNR_list = []

def audio_ptb(input_audio, ptb_type):
    if ptb_type=='none':
        new_audio = input_audio + 0.02*torch.randn_like(input_audio)
        new_audio = new_audio.unsqueeze(1)
        new_audio = torch.clamp(new_audio, -1, 1)
        
    return new_audio
    

for batch_idx, (message,spec, spec_lengths, wav, wav_lengths) in enumerate(eval_loader):
    spec    = spec.cuda(0)
    wav, wav_lengths = wav.cuda(0), wav_lengths.cuda(0)
    message = message.cuda(0)
    wav = wav.squeeze()

    watermarked_wav = torch.zeros_like(wav)# don't use empty like function
    watermarked_wav.copy_(wav)
    ptb_watermarked_wav = torch.zeros_like(wav)
    ptb_watermarked_wav.copy_(wav)
    
    error_rate = 0
    ptb_error_rate = 0
    watermark_times = wav_lengths[wav_index].item()//8192
    for i in range(watermark_times):
        wav_slice = torch.zeros_like(wav[wav_index,i*8192:(i+1)*8192])
        wav_slice.copy_(wav[wav_index,i*8192:(i+1)*8192])
        wav_slice = wav_slice.repeat(wav.size()[0],1)

        message_new = torch.zeros_like(message[wav_index,:])
        message_new.copy_(message[wav_index,:])
        message_new = message_new.repeat(wav.size()[0],1)

        recon_wav = torch.zeros_like(wav_slice)
        recon_msg = torch.zeros_like(message_new)
        recon_wav, recon_msg = nft_model(wav_slice, message_new)
        watermarked_wav[wav_index, i*8192:(i+1)*8192] = recon_wav[wav_index,:]
        
        ptb_type = 'none' # 'none', 'noise', 'rir'
        ptb_recon_wav = audio_ptb(recon_wav, ptb_type)
        ptb_watermarked_wav[wav_index, i*8192:(i+1)*8192] = recon_wav[wav_index,:]
        ptb_recon_msg = nft_model.reconstructor(ptb_recon_wav)
        

        err = torch.abs(message[wav_index].squeeze().round() -recon_msg[wav_index].squeeze().round()).sum()
        ptb_err = torch.abs(message[wav_index].squeeze().round() -ptb_recon_msg[wav_index].squeeze().round()).sum()
        
        error_rate += err.item()
        ptb_error_rate += ptb_err.item()
        print(err.item(), ptb_err.item())
        Total_errors.append(err.item())

    entire_ori_mel = mel_spectrogram_torch(
          wav[wav_index:wav_index+1,:].float(), 
          hps.data.filter_length, 
          hps.data.n_mel_channels, 
          hps.data.sampling_rate, 
          hps.data.hop_length, 
          hps.data.win_length, 
          hps.data.mel_fmin, 
          hps.data.mel_fmax
        )

    recon_mel = mel_spectrogram_torch(
          watermarked_wav[wav_index:wav_index+1,:].float(), 
          hps.data.filter_length, 
          hps.data.n_mel_channels, 
          hps.data.sampling_rate, 
          hps.data.hop_length, 
          hps.data.win_length, 
          hps.data.mel_fmin, 
          hps.data.mel_fmax
        )

    delta_mel = mel_spectrogram_torch(
          watermarked_wav[wav_index:wav_index+1,:].float()-wav[wav_index:wav_index+1,:].float(), 
          hps.data.filter_length, 
          hps.data.n_mel_channels, 
          hps.data.sampling_rate, 
          hps.data.hop_length, 
          hps.data.win_length, 
          hps.data.mel_fmin, 
          hps.data.mel_fmax
        )
snr = SNR(watermarked_wav[wav_index,:], wav[wav_index,:])
SNR_list.append(snr.item())
print(snr.item())
save_ori_wav = np.asarray(wav[wav_index:wav_index+1,:].cpu().detach().numpy()*amplitude)[0]
save_wtm_wav = np.asarray(watermarked_wav[wav_index:wav_index+1,:].cpu().detach().numpy()*amplitude)[0]
delta = save_wtm_wav - save_ori_wav
log_path = os.path.join(hps.save_dir, 'tensorboard/clean_evaluate')
if not os.path.exists(log_path):
    os.mkdir(log_path)
write("{}/original_{}.wav".format(log_path, batch_idx), 22050, save_ori_wav.astype(np.int16))
write("{}/watermarked_{}.wav".format(log_path, batch_idx), 22050, save_wtm_wav.astype(np.int16))
write("{}/delta_{}.wav".format(log_path, batch_idx), 22050, delta.astype(np.int16))

x = np.arange(0, delta.shape[0])
plt.figure()
plt.plot(x, save_ori_wav, label = 'Original')
plt.plot(x, save_wtm_wav, label = 'Watermarked')
plt.plot(x, delta, label = 'delta')
plt.legend()
plt.savefig('{}/wav_comparison_{}.pdf'.format(log_path, batch_idx),bbox_inches='tight', dpi=600)
plt.close() 

print('{} error: {}'.format(batch_idx, error_rate/watermark_times))
writer_eval.add_scalar('error_rate',      error_rate/watermark_times, batch_idx)
writer_eval.add_image('input/ori_mel',     utils.plot_spectrogram_to_numpy(entire_ori_mel[0].data.cpu().numpy()), batch_idx, dataformats='HWC')
writer_eval.add_image('input/recon_mel',   utils.plot_spectrogram_to_numpy(recon_mel[0].data.cpu().numpy()), batch_idx, dataformats='HWC')
writer_eval.add_image('input/recon_mel',   utils.plot_spectrogram_to_numpy(delta_mel[0].data.cpu().numpy()), batch_idx, dataformats='HWC')
writer_eval.add_audio('audio/original',    wav[wav_index:wav_index+1,:], batch_idx, sample_rate=22050)
writer_eval.add_audio('audio/watermarked', watermarked_wav[wav_index:wav_index+1,:], batch_idx, sample_rate=22050)
writer_eval.add_audio('audio/delta', watermarked_wav[wav_index:wav_index+1,:]-wav[wav_index:wav_index+1,:], batch_idx, sample_rate=22050)
print('Mean: {}  STD: {}'.format(np.mean(Total_errors), np.std(Total_errors)))
np.savetxt("{}/accuracy.csv".format(log_path), Total_errors, delimiter =", ", fmt ='%1.9f')
np.savetxt("{}/snr.csv".format(log_path), SNR_list, delimiter =", ", fmt ='%1.9f')

0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0


NameError: name 'SNR' is not defined