In [27]:
import numpy as np
import scipy
import scipy.signal
import torch
from torch import nn

In [28]:
def frequency_mask(num_bands, up_factor, down_factor):
    """Build a (num_bands, num_bands) band-to-band spreading matrix.

    Two triangular matrices are filled and multiplied:

    * a lower-triangular one with entries ``up_factor ** (row - col)``
      for ``col <= row``;
    * an upper-triangular one with entries ``down_factor ** (col - row)``
      for ``col >= row``.

    Args:
        num_bands: Number of bands (side length of the square result).
        up_factor: Base of the geometric decay in the lower-triangular factor.
        down_factor: Base of the geometric decay in the upper-triangular factor.

    Returns:
        ``np.ndarray`` of shape ``(num_bands, num_bands)`` equal to the
        upper-triangular factor times the lower-triangular factor.
    """
    shape = (num_bands, num_bands)
    spread_up, spread_down = np.zeros(shape), np.zeros(shape)

    for band in range(num_bands):
        # Exponents count down band, band-1, ..., 0 so the diagonal entry is 1.
        spread_up[band, :band + 1] = np.power(up_factor, np.arange(band, -1, -1))
        # Exponents count up 0, 1, ... starting on the diagonal.
        spread_down[band, band:] = np.power(down_factor, np.arange(num_bands - band))

    return np.matmul(spread_down, spread_up)


def rect_fb(band_limits, num_bins=None):
    """Build a rectangular (0/1) filterbank matrix.

    Row ``i`` is 1 on the half-open bin range
    ``[band_limits[i], band_limits[i + 1])`` and 0 elsewhere.

    Args:
        band_limits: Sequence of band edges given as bin indices; it defines
            ``len(band_limits) - 1`` bands.
        num_bins: Number of columns of the result. Defaults to the last band
            edge when omitted.

    Returns:
        ``np.ndarray`` of shape ``(len(band_limits) - 1, num_bins)``.
    """
    num_bands = len(band_limits) - 1
    total_bins = band_limits[-1] if num_bins is None else num_bins

    bank = np.zeros((num_bands, total_bins))
    # Walk consecutive (lower edge, upper edge) pairs, one pair per band.
    for row, (start, stop) in enumerate(zip(band_limits[:-1], band_limits[1:])):
        bank[row, start:stop] = 1

    return bank

In [50]:
class MOCLoss(nn.Module):
    """Masked spectral-distortion loss between a reference and a degraded signal.

    Both signals are converted to STFT power spectra. A masking curve is
    derived from the reference only: its mean band energies are spread across
    bands with ``f_mask`` and then carried forward across frames. The mask is
    added to both spectra before their ratio is turned into a distortion
    value, so identical inputs give a loss of exactly 0.

    The filterbank, spreading matrix and analysis window are stored as
    non-trainable parameters. They are created from float64 NumPy arrays, so
    the loss is computed in float64.
    """

    def __init__(self):
        super().__init__()

        # Band edges as STFT bin indices; band i covers bins
        # [band_limits[i], band_limits[i + 1]).
        self.band_limits = [0, 2, 4, 6, 7, 9, 11, 13, 15, 18, 22, 26, 31, 36, 43, 51, 60, 75]
        self.num_bands = len(self.band_limits) - 1
        self.window_size = 160
        self.nfft = 160
        self.hop_size = 80
        self.mask_gain = 0.1    # weight of the mask when added to the power spectra
        self.tmask_gain = 0.5   # share of the previous frame's mask carried into the next frame
        # 81 = nfft // 2 + 1 bins of the one-sided STFT; bins past the last
        # band edge belong to no band.
        self.fb = nn.Parameter(
            torch.from_numpy(rect_fb(self.band_limits, num_bins=81)),
            requires_grad=False
        )
        self.f_mask = nn.Parameter(
            torch.from_numpy(frequency_mask(self.num_bands, 0.1, 0.03)),
            requires_grad=False
        )
        self.window = nn.Parameter(
            torch.from_numpy(scipy.signal.get_window('hamming', self.window_size)),
            requires_grad=False
        )

    def _power_spectrum(self, x):
        """Return the floored STFT power spectrum of ``x`` as (batch, frames, bins)."""
        spec = torch.stft(
            x * 2**15,  # presumably rescales [-1, 1] float audio to 16-bit range -- TODO confirm
            self.nfft,
            self.hop_size,
            window=self.window,
            center=False,
            onesided=True,
            return_complex=True,
        )
        # The constant floor keeps every bin strictly positive, so the ratio
        # and logarithm taken in forward() stay finite.
        return torch.abs(spec.permute(0, 2, 1)) ** 2 + 100000

    def forward(self, x_ref, x_deg):
        """Compute the loss for a batch of signal pairs.

        Args:
            x_ref: Reference signals, shape (batch, samples). A batch dimension
                is required because the STFT output is permuted as a 3-D tensor.
            x_deg: Degraded signals, same shape as ``x_ref``.

        Returns:
            Scalar tensor: the per-item error averaged over the batch.
        """
        psd_ref = self._power_spectrum(x_ref)
        psd_deg = self._power_spectrum(x_deg)

        band_width = self.fb.sum(dim=1)  # number of bins in each band

        # frequency masking: mean reference energy per band, spread across bands
        be_ref = (psd_ref @ self.fb.T) / band_width.unsqueeze(0)
        mask = be_ref @ self.f_mask.T

        # temporal masking: each frame inherits tmask_gain times the already
        # smoothed mask of the PREVIOUS frame. Built out-of-place so autograd
        # never sees an in-place write.
        smoothed = [mask[:, 0, :]]
        for i in range(1, mask.size(1)):
            smoothed.append(mask[:, i, :] + self.tmask_gain * smoothed[-1])
        mask = torch.stack(smoothed, dim=1)

        # apply mask: expand per-band values back onto the bins of each band
        lifted_mask = mask @ self.fb
        masked_psd_ref = psd_ref + self.mask_gain * lifted_mask
        masked_psd_deg = psd_deg + self.mask_gain * lifted_mask

        # 2-frame average (computed as a sum; the common factor 1/2 cancels
        # in the ratio below)
        masked_psd_ref = masked_psd_ref[:, 1:] + masked_psd_ref[:, :-1]
        masked_psd_deg = masked_psd_deg[:, 1:] + masked_psd_deg[:, :-1]

        # calculate distortion: r - log(r) - 1 is zero at r == 1 and positive elsewhere
        re = masked_psd_ref / masked_psd_deg
        im = re - torch.log(re) - 1
        Eb = (im @ self.fb.T) / band_width
        Ef = torch.mean(Eb ** 2, dim=2)
        err = torch.mean(Ef ** 4, dim=1) ** (1/16)
        loss = torch.mean(err)

        return loss
        
        
        

In [51]:
# Seed the global RNG so the random test signals are identical on every
# Restart & Run All.
torch.manual_seed(0)

# Two independent batches of Gaussian noise, each 32 signals x 16000 samples.
x = torch.randn(32, 16000)
y = torch.randn(32, 16000)

In [52]:
# Instantiate the loss module; MOCLoss takes no constructor arguments.
loss = MOCLoss()

In [54]:
# Sanity check: a batch compared with itself gives a spectral ratio of 1
# everywhere, so the loss should be exactly zero.
loss(x, x)

tensor(0., dtype=torch.float64)