In [1]:
import numpy as np
import glob

from tqdm import tqdm
from torchaudio import load, save
import torch
import os
from argparse import ArgumentParser

from sgmse.backbones.shared import BackboneRegistry
from sgmse.data_module import SpecsDataModule
from sgmse.sdes import SDERegistry
from sgmse.sampling import PredictorRegistry, CorrectorRegistry, OperatorRegistry, PosteriorRegistry, SchedulerRegistry, PosteriorRegistry
from sgmse.model import ScoreModel
from sgmse.util.other import *
from sgmse.util.graphics import *

import matplotlib.pyplot as plt


ckpt = "/export/home/lemercier/code/_public_repos/derevdps/.logs/sde=EDM_backbone=ncsnpp_data=wsj0_ch=1/version_5/checkpoints/epoch=146.ckpt"
test_dir = "/data/lemercier/databases/wsj0_derev_with_rir/audio/tt/clean"
enhanced_dir = ".exp/test"
os.makedirs(enhanced_dir, exist_ok=True)

# Load score model
model = ScoreModel.load_from_checkpoint(ckpt)
model.eval(no_ema=False)
torch.cuda.set_device(f'cuda:0')
model.cuda()

ScoreModel(
  (dnn): NCSNpp(
    (act): SiLU()
    (output_layer): Conv2d(2, 2, kernel_size=(1, 1), stride=(1, 1))
    (pyramid_upsample): Upsample()
    (pyramid_downsample): Downsample()
    (all_modules): ModuleList(
      (0): GaussianFourierProjection()
      (1): Linear(in_features=256, out_features=512, bias=True)
      (2): Linear(in_features=512, out_features=512, bias=True)
      (3): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ResnetBlockBigGANpp(
        (GroupNorm_0): GroupNorm(32, 128, eps=1e-06, affine=True)
        (Conv_0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (Dense_0): Linear(in_features=512, out_features=128, bias=True)
        (GroupNorm_1): GroupNorm(32, 128, eps=1e-06, affine=True)
        (Dropout_0): Dropout(p=0.0, inplace=False)
        (Conv_1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act): SiLU()
      )
      (5): ResnetBlockBigGANpp(
        (GroupN

In [2]:

files = sorted(glob.glob(os.path.join(test_dir, "*.wav")))[: 1]

with torch.no_grad():
    for f in files:
        x_audio, _ = torchaudio.load(f)
        x = model._forward_transform(model._stft(x_audio)).unsqueeze(0).cuda()
        x = pad_spec(x)

        t = model.sample_time(x)
        mean, std = model.sde.marginal_prob(x, t, None)
        z = torch.randn_like(x)
        if std.ndim < x.ndim:
            std = std.view(*std.size(), *((1,)*(x.ndim - std.ndim)))
        y = mean + std * z

        tweedie = model(perturbed_data, t, score_conditioning=[])
        tweedie_audio = model.to_audio(tweedie, x_audio.size(-1))
        y_audio = model.to_audio(y, x_audio.size(-1))

        torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)}', x_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)
        torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)[: -4]}_tweedie.wav', tweedie_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)
        torchaudio.save(f'{enhanced_dir}/{os.path.basename(f)[: -4]}_noisy.wav', y_audio.type(torch.float32).cpu().squeeze().unsqueeze(0), 16000)

NameError: name 'perturbed_data' is not defined