In [1]:
from model import BSRNN
import torch
from utils import torch_mdct, torch_imdct, calc_loss
from dataset import CLIPSDXDataset
import os

In [2]:
# Physical GPU indices this notebook is allowed to see.
gpus = [0]
# Order devices by PCI bus id and expose only the GPUs selected above.
os.environ.update({
    'CUDA_DEVICE_ORDER': "PCI_BUS_ID",
    'CUDA_VISIBLE_DEVICES': ','.join(str(gpu_id) for gpu_id in gpus),
})

In [3]:
# Dataset root and the stem files passed to the dataset as interferers.
data_dir = "/mnt/BigChonk/LinuxDownloads/moisesdb23_bleeding_v1.0_prepared"
interferer_files = ['bass.wav', 'other.wav', 'drums.wav']

# Separation model configured for the "vocals" target stem.
model = BSRNN(
    target_stem="vocals",
    sample_rate=44100,
    n_fft=2048,
    hop_length=512,
    channels=2,
    fc_dim=128,
    num_band_seq_module=12,
    num_mixtures=2,
)

# Training split: random chunks with seq_duration=1; validation uses the
# dataset defaults for the "valid" split.
train_dataset = CLIPSDXDataset(
    root=data_dir,
    interferer_files=interferer_files,
    seq_duration=1,
    random_chunks=True,
)
valid_dataset = CLIPSDXDataset(
    root=data_dir,
    split="valid",
    interferer_files=interferer_files,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=16,
    drop_last=True,
)
# Validation loader is currently disabled:
# valid_loader = torch.utils.data.DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=False,
#     num_workers=16,
# )

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

196it [00:00, 2041.19it/s]
7it [00:00, 11379.89it/s]


In [4]:
# Use the first visible GPU when CUDA is usable; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# Wrap the model in DataParallel exactly once. Without the guard, re-running
# this cell would nest DataParallel inside DataParallel and add another
# "module." prefix to every state_dict key.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
model.to(device)
# Sanity check: report where the parameters actually live.
print(next(model.parameters()).device)

cuda:0


In [6]:
# Pull a single batch from the training loader and move each of its three
# tensors onto the selected device.
train_batch = next(iter(train_loader))
mixture, sources, target_stems = (tensor.to(device) for tensor in train_batch)

In [7]:
from einops import rearrange
def prepare_input_wav(wav):
    """Convert a batch of waveforms into a channel-last MDCT spectrogram.

    Args:
        wav: Tensor indexed as (batch, channel, time); the rearrange pattern
            below requires exactly three dimensions.

    Returns:
        Tensor laid out as (batch, freq, time, channel), obtained by running
        ``torch_mdct`` (window_length=2048, KBD window) on every channel
        independently.
    """
    # Take the channel count from the input instead of assuming stereo; a
    # hard-coded 2 would silently mix up batch and channel entries whenever
    # the input is not two-channel.
    num_channels = wav.shape[1]
    # Fold channels into the batch axis so each channel is transformed alone.
    flat_wav = rearrange(wav, 'b c t -> (b c) t')
    spec = torch_mdct(flat_wav, window_length=2048, window_type='kbd')
    # Unfold the channels again and move them to the trailing axis.
    return rearrange(spec, '(b c) f t -> b f t c', c=num_channels)

def extract_output(spec,mask,mix_args):
    """Apply each mask to the mixture spectrogram and resynthesize waveforms.

    Args:
        spec: Mixture spectrogram laid out as (batch, freq, time, channel).
        mask: Tensor iterated along its first dimension; every ``mask[i]`` is
            multiplied with the spectrogram reshaped to
            ((batch channel), freq, time), so it must broadcast against it.
        mix_args: Tuple ``(b, c, t)`` giving the batch size, channel count and
            sample length used to reshape and trim the reconstructed audio.

    Returns:
        Tuple ``(est_specs, est_wavs)``, each stacked along a new leading
        dimension of size ``mask.shape[0]``. Entries of ``est_specs`` are
        (b, c, freq, time); entries of ``est_wavs`` are (b, c, t).
    """
    b, c, t = mix_args
    # Use the channel count from mix_args (not a hard-coded 2) so this reshape
    # agrees with the ones applied to the outputs below.
    flat_spec = rearrange(spec, 'b f t c -> (b c) f t', c=c)
    est_specs = []
    est_wavs = []
    for source_mask in mask:
        est_spec = source_mask * flat_spec
        est_wav = torch_imdct(est_spec, sample_length=t, window_length=2048,
                              window_type='kbd')
        est_wav = rearrange(est_wav, '(b c) t -> b c t', b=b, c=c)
        # Trim to the requested sample length.
        est_wav = est_wav[:, :, :t]
        est_specs.append(rearrange(est_spec, '(b c) f t -> b c f t', c=c))
        est_wavs.append(est_wav)
    return torch.stack(est_specs), torch.stack(est_wavs)

def calculate_gt_masks_mixOfMix(sources, target_stems):
    """Compute ground-truth ratio masks for the two mixtures of a batch.

    For each index ``j`` in ``{0, 1}`` along dimension 1, the waveforms
    ``sources[:, j]`` and ``target_stems[:, j]`` are converted with
    ``prepare_input_wav`` and divided element-wise (target / mixture).

    Args:
        sources: Tensor whose slice ``sources[:, j]`` is a (batch, channel,
            time) waveform accepted by ``prepare_input_wav``.
        target_stems: Tensor indexed the same way as ``sources``.

    Returns:
        Tensor with a leading dimension of size 2; each entry is laid out as
        ((batch channel), freq, time).
    """
    gt_masks = []
    for j in range(2):
        mix_spec = rearrange(prepare_input_wav(sources[:, j]), 'b f t c -> (b c) f t')
        target_spec = rearrange(prepare_input_wav(target_stems[:, j]), 'b f t c -> (b c) f t')
        ratio = target_spec / mix_spec
        # Where the mixture coefficient is exactly zero (e.g. digital silence)
        # the ratio is inf or NaN; define the mask as 0 there so downstream
        # computations are not poisoned by non-finite values.
        gt_mask = torch.where(mix_spec != 0, ratio, torch.zeros_like(ratio))
        gt_masks.append(gt_mask)
    gt_masks = torch.stack(gt_masks)
    return gt_masks

In [8]:
# Forward pass: spectrogram of the sampled mixture batch -> model output.
mixture_spec = prepare_input_wav(mixture)
pred_mask = model(mixture_spec)

Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Batch size: 4
Outputs before : torch.Size([32, 1023, 45])
Outputs after : torch.Size([4, 8, 1023, 45])


In [9]:
# Inspect the dimensions of the model output.
print(pred_mask.size())

torch.Size([4, 8, 1023, 45])
