In [1]:
import os

import torch
from pathlib import Path
from model import UNET, Separator

In [2]:
# Describe the checkpoint location once with pathlib and load from that same
# path, so the directory and file name are each defined in a single place.
model_dir = Path('./models')
model_name = 'my_checkpoint.pth'
model_file = model_dir / model_name

# weights_only=True keeps torch.load from unpickling arbitrary objects.
# map_location='cpu' lets a checkpoint written from a GPU be read on any
# machine; the weights are copied into the model by load_state_dict afterwards.
checkpoint = torch.load(model_file, map_location='cpu', weights_only=True)

In [3]:
# Build a fresh UNET and copy in the weights stored under the checkpoint's
# 'state_dict' key.
model = UNET()
model.load_state_dict(checkpoint['state_dict'])
# eval() puts the module in evaluation mode and returns the module itself, so
# leaving it as the last expression also echoes the architecture as cell output.
model.eval()

UNET(
  (ups): ModuleList(
    (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (1): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
    )
    (2): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (3): DoubleConv(
      (conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (3): BatchNorm2d(256, eps=1e-

In [4]:
import musdb

# Open the 'test' subset of the MUSDB copy rooted at ./musdb. Downloading is
# switched off, so presumably the dataset must already be on disk.
musdb_testdata = musdb.DB(
    './musdb',
    subsets='test',
    download=False,
)

In [5]:
# Use the first track of the test subset as the mixture to work with.
track = musdb_testdata[0]
# Transpose the track's audio array and convert it to a float32 tensor. The
# recorded output reports shape (2, 300032) -- presumably (channels, samples);
# confirm against the musdb documentation.
mix_audio = torch.as_tensor(track.audio.T, dtype=torch.float32)
mix_audio.shape

torch.Size([2, 300032])

In [6]:
# Wrap the restored network in the project's Separator, using n_fft=1024 and
# hop_length=512 and targeting the GPU.
sep = Separator(
    model=model,
    n_fft=1024,
    hop_length=512,
    device='cuda',
)

# Indexing with None prepends a dimension of size 1 to the mixture (presumably
# the batch axis Separator expects -- confirm in model.py).
# NOTE(review): the output recorded for `prediction[0]` is entirely NaN; the
# cause is not visible from this notebook.
prediction = sep(mix_audio[None])

In [7]:
# Display the first entry along the leading dimension of the prediction.
prediction[0]

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')

In [8]:
import numpy as np

# Load an array saved earlier to disk -- judging by the path, presumably a batch
# of predicted spectrograms. The recorded output reports shape (4, 2, 513, 585).
pred_batch = np.load('./saved_spectrograms/predbatch_0.npy')
pred_batch.shape

(4, 2, 513, 585)

In [14]:
# Take the first entry of the loaded batch and turn it into a tensor on the GPU;
# the bare name on the last line displays it.
pred = torch.as_tensor(pred_batch[0], device='cuda')
pred

tensor([[[ 5.5052e+00,  4.7443e-01,  2.3468e+00,  ...,  7.8862e+00,
           1.0306e+01,  2.2304e+01],
         [ 4.9328e+00,  2.9134e+00,  2.8284e+00,  ..., -2.8186e+00,
           8.2736e+01,  1.1040e+02],
         [ 5.4619e+00,  8.7864e+00,  3.0880e+00,  ...,  5.1480e+01,
           2.2413e+02,  1.7257e+02],
         ...,
         [-8.7637e-02, -1.0132e-01, -9.4307e-02,  ..., -9.4227e-02,
          -1.0068e-01, -1.1208e-01],
         [-9.1752e-02, -1.0762e-01, -9.4337e-02,  ..., -9.0155e-02,
          -9.4531e-02, -1.1438e-01],
         [-9.0652e-02, -1.0613e-01, -9.7481e-02,  ..., -1.0600e-01,
          -1.0426e-01, -1.0304e-01]],

        [[-8.3599e+00,  1.8399e-01, -4.6741e+00,  ...,  5.5093e+00,
           9.2390e-01, -5.3253e+00],
         [-4.9956e+00, -3.5299e+00, -1.9806e+00,  ...,  1.1111e+01,
           6.5718e+01,  1.0056e+02],
         [-1.6940e+00,  2.1643e+00,  3.1759e+00,  ...,  3.3509e+01,
           1.6663e+02,  1.5374e+02],
         ...,
         [-2.8378e-02, -2

In [12]:
from transforms import make_filterbanks

# Create the forward/inverse transform pair on the GPU, with n_fft=1024,
# hop_length=512, centering disabled and a 44100.0 sample rate.
stft, istft = make_filterbanks(
    device='cuda',
    sample_rate=44100.0,
    n_fft=1024,
    hop_length=512,
    center=False,
)

In [16]:
# Run the inverse transform on the loaded prediction. The recorded output
# reports shape (2, 299008) for the result.
pred_audio = istft(pred)
pred_audio.shape

torch.Size([2, 299008])

In [18]:
from IPython.display import Audio

# Bring the tensor back from the GPU and hand it to the player as a NumPy array
# with a rate of 44100.
# NOTE(review): the output recorded under this cell is a fragment pointing at
# the int16 conversion (`scaled.astype("<h")`) -- presumably a warning raised
# there. Check pred_audio for non-finite samples before trusting the playback.
Audio(pred_audio.cpu().numpy(), rate=44100)

  return scaled.astype("<h").tobytes(), nchan
