## 1. Model Definition

In [None]:
from utils_ismir import *
from intermediate_layers import * 
from frameworks import *
import torch.nn as nn

# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Loader object handed to the model framework and the evaluation helpers below.
musdb_loader = MusdbLoaer(device=device)

def mk_tfc_tif(input_c, output_c, f, i):
    """Create the intermediate TFC_TIF block for level ``i``.

    Args:
        input_c: forwarded to ``TFC_TIF`` as ``in_channels``.
        output_c: forwarded to ``TFC_TIF`` as ``gr``
            (presumably the growth rate -- confirm in ``intermediate_layers``).
        f: forwarded to ``TFC_TIF`` as ``f``.
        i: level index; only used in the progress message.

    Returns:
        The constructed ``TFC_TIF`` instance.
    """
    print('intermediate\t at level', i, 'with TFC_TIF')
    block = TFC_TIF(
        in_channels=input_c,
        num_layers=5,
        gr=output_c,
        kt=3,
        kf=3,
        f=f,
        bn_factor=16,
        bias=True,
    )
    return block

# Down-sampling levels at which the first (T) axis is halved together with F;
# every other level only halves F. mk_tfc_tif_us mirrors this schedule.
t_scale = [0,2,4,6]


def mk_tfc_tif_ds(i, f, t_scale=t_scale, channels=24):
    """Create the down-sampling layer for level ``i``.

    Args:
        i: index of the down-sampling level.
        f: size of the F axis before down-sampling.
        t_scale: levels whose scale is (2, 2) instead of (1, 2).
        channels: number of input and output channels of the strided
            convolution and of the batch norm. Defaults to 24, the value
            that used to be hard-coded here.

    Returns:
        A ``(layer, f_out)`` tuple: an ``nn.Sequential`` of a strided
        ``Conv2d`` followed by ``BatchNorm2d``, and the F size after
        down-sampling (``f // scale[-1]``).
    """
    scale = (2, 2) if i in t_scale else (1, 2)
    f_out = f // scale[-1]
    print('downsampling\t at level', i, 'with scale(T, F): ', scale, ', F_scale: ', f, '->', f_out)
    ds = nn.Sequential(
        # kernel_size == stride, so the convolution windows do not overlap.
        nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=scale, stride=scale),
        nn.BatchNorm2d(channels),
    )
    return ds, f_out

def mk_tfc_tif_us(i, f, n, t_scale=t_scale, channels=24):
    """Create the up-sampling layer for level ``i``.

    Args:
        i: index of the up-sampling level.
        f: size of the F axis before up-sampling.
        n: total used to mirror the schedule, via ``n - 1 - s``
            (presumably the number of up-sampling levels -- confirm in
            ``U_Net_Framework``).
        t_scale: down-sampling levels that use scale (2, 2); their mirrored
            positions use (2, 2) here, all others use (1, 2).
        channels: number of input and output channels of the transposed
            convolution and of the batch norm. Defaults to 24, the value
            that used to be hard-coded here.

    Returns:
        A ``(layer, f_out)`` tuple: an ``nn.Sequential`` of a strided
        ``ConvTranspose2d`` followed by ``BatchNorm2d``, and the F size
        after up-sampling (``f * scale[-1]``).
    """
    # Mirror the schedule: level s (counted from the start) maps to
    # level n - 1 - s (counted from the end).
    mirrored_levels = [n - 1 - s for s in t_scale]
    scale = (2, 2) if i in mirrored_levels else (1, 2)
    f_out = f * scale[-1]

    print('upsampling\t at level', i, 'with scale(T, F): ', scale, ', F_scale: ', f, '->', f_out)
    us = nn.Sequential(
        nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=scale, stride=scale),
        nn.BatchNorm2d(channels),
    )
    return us, f_out


# U-Net built from the factory functions above. internal_channels (24) must
# match the channel count used by the down-/up-sampling factories.
# The model is moved to `device` rather than calling .cuda(), so the CPU
# fallback chosen when `device` was created also works on machines without CUDA.
model = U_Net_Framework(
    musdb_loader, est_mode='cac_mapping', internal_channels=24, num_blocks=17, 
    mk_block_f=mk_tfc_tif, mk_ds_f=mk_tfc_tif_ds, mk_us_f=mk_tfc_tif_us
).to(device)

## 2. Load Pretrained Model

In [None]:
# List the entries of ./pretrained whose names contain "cac_tfc_tif_17_vocals".
!ls pretrained | grep cac_tfc_tif_17_vocals

In [None]:
pretrained_path = 'pretrained/cac_tfc_tif_17_vocals_top2.pt'

# map_location remaps the stored tensors onto `device`; without it torch.load
# restores them onto the device they were saved from, which fails on a
# CPU-only machine when the checkpoint holds CUDA tensors.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
pretrained_params = torch.load(pretrained_path, map_location=device)
model.load_state_dict(pretrained_params)

## 3. Evaluation: Musdb18 Benchmark

### 3.1 SDR performance of the pretrained U-Net with 17 TFC_TIF blocks

In [None]:
def separator(mix):
    """Adapt ``separate`` to the one-argument callable passed to ``eval_testset``.

    Calls ``separate`` with the notebook-level ``musdb_loader`` and ``model``
    and a batch size of 16, and returns its result unchanged.
    """
    estimate = separate(musdb_loader, model, mix, batch_size=16)
    return estimate


# Evaluate the U-Net separator on the test set for the 'vocals' target.
sdrs = eval_testset(musdb_loader, separator, target_name='vocals')

### 3.2. SDR performance of the pretrained UMX

In [None]:
# Fetch the Open-Unmix (UMX) reference implementation.
!git clone https://github.com/sigsep/open-unmix-pytorch
# Work from inside the cloned repository from here on.
%cd open-unmix-pytorch
# Rename test.py to umx_test.py; the next cell imports it as `umx_test`.
%mv test.py umx_test.py
# norbert is installed for the UMX code (presumably a dependency of umx_test -- confirm).
!pip install norbert

In [None]:
import umx_test as umx
# Re-create the loader with an explicit dataset path given relative to the
# current working directory.
musdb_loader = MusdbLoaer(musdb_path='../data/musdb18/', device=device)


def separator(mix):
    """Separate ``mix`` with the pretrained 'umx' model and return the vocals.

    ``mix`` is transposed before being passed to ``umx.separate`` and the
    'vocals' estimate is transposed back before it is returned.
    """
    umx_targets = ['vocals', 'drums', 'bass', 'other']
    estimates = umx.separate(
        audio=mix.T,
        targets=umx_targets,
        model_name='umx',
        device=device,
    )
    vocals = estimates['vocals']
    return vocals.T


# Evaluate the UMX separator on the test set for the 'vocals' target.
umx_sdr = eval_testset(musdb_loader, separator, target_name='vocals')