In [21]:
import numpy as np
import glob

from tqdm import tqdm
from torchaudio import load, save
import torch
import os
from argparse import ArgumentParser

from sgmse.backbones.shared import BackboneRegistry
from sgmse.data_module import SpecsDataModule
from sgmse.sdes import SDERegistry
from sgmse.sampling import PredictorRegistry, CorrectorRegistry, OperatorRegistry, PosteriorRegistry, SchedulerRegistry, PosteriorRegistry
from sgmse.model import ScoreModel
from sgmse.util.other import *
from sgmse.util.graphics import *

import matplotlib.pyplot as plt

test_dir = "/data/lemercier/databases/wsj0_derev_with_rir/audio/tt/clean"

# ckpt = "/export/home/lemercier/code/_public_repos/derevdps/.logs/sde=EDM_backbone=ncsnpp_data=wsj0_ch=1/version_6/checkpoints/epoch=63.ckpt"
# enhanced_dir = ".exp/test_edm"
# # kwargs = dict(sampler_type="karras", probability_flow=True, N=50, scheduler="edm",
# kwargs = dict(sampler_type="song", probability_flow=True, N=50, scheduler="edm",
#         # predictor="euler-heun",
#         predictor="euler-maruyama",
#         posterior="none", operator=None, A=None, zeta=0., zeta_schedule="none",
#         # corrector="none", r=0., corrector_steps=0, 
#         corrector="ald", r=0.4, corrector_steps=1, 
#         noise_std=1.007, smin=0., smax=0., churn=0.
# )

ckpt = "/export/home/lemercier/code/_public_repos/derevdps/.logs/sde=VESDE_backbone=ncsnpp_data=wsj0_ch=1/version_3/checkpoints/epoch=297.ckpt"
# enhanced_dir = ".exp/test_song"
# enhanced_dir = ".exp/test_song_karrassampler_nocorr_em"
# enhanced_dir = ".exp/test_song_songsampler_corr_em"
enhanced_dir = ".exp/test_song_karrassampler_nocorr_eh"
kwargs = dict(sampler_type="karras", probability_flow=True, N=50, scheduler="ve",
# kwargs = dict(sampler_type="song", probability_flow=True, N=50, scheduler="ve",
        predictor="euler-heun",
        # predictor="euler-maruyama",
        posterior="none", operator=None, A=None, zeta=0., zeta_schedule="none",
        # corrector="ald", r=0.4, corrector_steps=1, 
        corrector="none", r=0., corrector_steps=0, 
)

os.makedirs(enhanced_dir, exist_ok=True)

# Load score model
model = ScoreModel.load_from_checkpoint(ckpt)
model.eval(no_ema=False)
torch.cuda.set_device(f'cuda:0')
model.cuda()

ScoreModel(
  (dnn): NCSNpp(
    (act): SiLU()
    (output_layer): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1))
    (pyramid_upsample): Upsample()
    (pyramid_downsample): Downsample()
    (all_modules): ModuleList(
      (0): GaussianFourierProjection()
      (1): Linear(in_features=256, out_features=512, bias=True)
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ResnetBlockBigGANpp(
        (GroupNorm_0): GroupNorm(32, 128, eps=1e-06, affine=True)
        (Conv_0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Dense_0): Linear(in_features=512, out_features=128, bias=True)
        (GroupNorm_1): GroupNorm(32, 128, eps=1e-06, affine=True)
        (Dropout_0): Dropout(p=0.0, inplace=False)
        (Conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act): SiLU()
      )
      (5): ResnetBlockBigGANpp(
        (GroupN

In [6]:
### Estimating Tweedie Denoiser ###

files = sorted(glob.glob(os.path.join(test_dir, "*.wav")))[: 1]

with torch.no_grad():
    for f in files:
        x_audio, _ = torchaudio.load(f)
        norm_factor = x_audio.abs().max()
        x = model._forward_transform(model._stft(x_audio/norm_factor)).unsqueeze(0).cuda()
        x = pad_spec(x)

        for t in torch.linspace(0.5, 1, 10).cuda():
            sigma = model.sde.scheduler.continuous_step(t)
            # sigma = t
            mean, std = model.sde.marginal_prob(x, sigma, None)
            print(t, mean.abs().max(), std)
            z = torch.randn_like(x)
            if std.ndim < x.ndim:
                std = std.view(*std.size(), *((1,)*(x.ndim - std.ndim)))
            y = mean + std * z

            tweedie = model(y, sigma, score_conditioning=[])
            tweedie_audio = model.to_audio(tweedie.squeeze(0), x_audio.size(-1))
            y_audio = model.to_audio(y.squeeze(0), x_audio.size(-1))

            visualize_one(x, spec_path=enhanced_dir, name=f"{t:.2f}_x")
            visualize_one(tweedie, spec_path=enhanced_dir, name=f"{t:.2f}_tweedie")
            visualize_one(y, spec_path=enhanced_dir, name=f"{t:.2f}_noisy")

            torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)[: -4]}_{t:.2f}_tweedie.wav', tweedie_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)
            torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)[: -4]}_{t:.2f}_noisy.wav', y_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)
        
        torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)}', x_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)
        

tensor(0.5000, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(0.3102, device='cuda:0')
tensor(0.5556, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(0.5517, device='cuda:0')
tensor(0.6111, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(0.9393, device='cuda:0')
tensor(0.6667, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(1.5401, device='cuda:0')
tensor(0.7222, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(2.4442, device='cuda:0')
tensor(0.7778, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(3.7697, device='cuda:0')
tensor(0.8333, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(5.6690, device='cuda:0')
tensor(0.8889, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(8.3356, device='cuda:0')
tensor(0.9444, device='cuda:0') tensor(2.4248, device='cuda:0') tensor(12.0124, device='cuda:0')
tensor(1., device='cuda:0') tensor(2.4248, device='cuda:0') tensor(17.0000, device='cuda:0')


In [22]:
### Unconditional Sampling ###

N = 1
len_signal = 4.

with torch.no_grad():
    for n in range(N):
        y = torch.zeros(1, int(len_signal*model.data_module.sample_rate)).cuda()
        # for t in torch.linspace(0.5, 1, 10):
        for t in torch.linspace(0.82, 0.84, 10):
        # for t in torch.linspace(0.8, 0.82, 10):
        # for t in torch.linspace(0.85, 0.9, 5):
            model.sde.set_T(t)
            print(model.sde.T)
            x = model.unconditional_sampling(y, **kwargs)
            print(x.abs().max())
            x /= x.abs().max()


            visualize_one(model._stft(x), spec_path=enhanced_dir, name=f"{n}_{t:.3f}")
            torchaudio.save(f'{enhanced_dir}/{n}_{t:.3f}.wav', x.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)

  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8200)
tensor(0.8200) tensor([1.2852], device='cuda:0')


100%|██████████| 50/50 [00:12<00:00,  4.02it/s, distance=0]


tensor(0.0141)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8222)
tensor(0.8222) tensor([1.3268], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.75it/s, distance=0]


tensor(0.0332)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8244)
tensor(0.8244) tensor([1.3698], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.81it/s, distance=0]


tensor(0.0178)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8267)
tensor(0.8267) tensor([1.4142], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.82it/s, distance=0]


tensor(1.3523)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8289)
tensor(0.8289) tensor([1.4600], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.75it/s, distance=0]


tensor(0.0935)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8311)
tensor(0.8311) tensor([1.5073], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.79it/s, distance=0]


tensor(0.0183)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8333)
tensor(0.8333) tensor([1.5561], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.69it/s, distance=0]


tensor(0.3224)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8356)
tensor(0.8356) tensor([1.6065], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.73it/s, distance=0]


tensor(0.7442)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8378)
tensor(0.8378) tensor([1.6586], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.69it/s, distance=0]


tensor(0.3733)


  0%|          | 0/50 [00:00<?, ?it/s]

tensor(0.8400)
tensor(0.8400) tensor([1.7123], device='cuda:0')


100%|██████████| 50/50 [00:13<00:00,  3.65it/s, distance=0]


tensor(0.5454)
