In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.quantized as nnq
import torch.quantization as tq

In [3]:
def check_perf(y, qy, title='= Report =', show=True):
    r"""
    Unscaled Reports:
        - MAE: Mean Absolute Error
        - MaxAE: Maximum Absolute Error
        - MSE: Mean Square Error
        - MaxSE: Maximum Absolute Error
    Scaled reports are the same as above scaled by either (y.max() - y.min()) or by (y.max() - y.min())^2.
    Power reports:
        - SNR: Signal-to-Noise ratio, computed as (y^2).mean() / MSE
        - SNR(db): Signal-to-Noise raio, computed as 10 * log_10(SNR)
    """
    if qy.is_quantized:
        qy = qy.dequantize()
    diff = y - qy
    ret = {}
    ret['MAE'] = diff.abs().mean().item()
    ret['MaxAE'] = diff.abs().max().item()
    ret['MSE'] = diff.square().mean().item()
    ret['MaxSE'] = diff.square().max().item()
    
    y_range = y.max() - y.min().item()
    y_range2 = y_range ** 2

    for key in ['MAE', 'MaxAE']:
        ret[key + '/|y|'] = ret[key] / y_range
    for key in ['MSE', 'MaxSE']:
        ret[key + '/|y|^2'] = ret[key] / y_range2
        
    mse = ret['MSE']
    if mse == 0:
        mse = 1e-15
    ret['SNR'] = y.square().mean().item() / mse
    ret['SNR(db)'] = 10 * np.log10(ret['SNR'])
    
    if show:
        print(f'{title:^24}')
        print('{:^24}'.format('Un-scaled'))
        for key in ['MAE', 'MaxAE', 'MSE', 'MaxSE']:
            value = ret[key]
            print(f'{key:.<16}{value:.2e}')
        print('{:^24}'.format('Scaled'))
        for key in ['MAE/|y|', 'MaxAE/|y|', 'MSE/|y|^2', 'MaxSE/|y|^2']:
            value = ret[key]
            print(f'{key:.<16}{value:.2e}')
        print('{:^24}'.format('Power'))
        for key in ['SNR', 'SNR(db)']:
            value = ret[key]
            print(f'{key:.<16}{value:.2e}')
    return ret

def qparams_min_max(fmin, fmax, qtype):
    qinfo = torch.iinfo(qtype)
    qmin = qinfo.min
    qmax = qinfo.max
    
    scale = (fmax - fmin + 1) / (qmax - qmin)
    zero_point = int(round(qmin - fmin / scale))
    
    return scale, zero_point, qtype

def qparams(x, qtype):
    fmin = min(0, x.min().item())
    fmax = max(0, x.max().item())
    return qparams_min_max(fmin, fmax, qtype)

def quantize(x, qtype):
    s, z, qtype = qparams(x, qtype)
    return torch.quantize_per_tensor(x, s, z, qtype)

In [4]:
batch = 16
channels = 1
length = 1024

x = torch.randn((batch, channels, length))
s, zp, qtype = qparams(x, torch.qint8)
qx = quantize(x, torch.quint8)

In [5]:
def sinc(x):
    return torch.where(x == 0,
                       torch.tensor(1., device=x.device, dtype=x.dtype),
                       torch.sin(x) / x)

def symetric_hann(length: int) -> torch.Tensor:
    """
    torchscript doesn't support `hann_window`, simple re-implementation.
    """
    x = torch.linspace(-0.5, 0.5, length)
    return torch.cos(math.pi * x)**2

In [6]:
# Quantizable layers rewrite
import math

class Upsample2Layer(nn.Module):  
    def __init__(self, in_channels: int = 1, zeros: int = 56):
        super(Upsample2Layer, self).__init__()
        self.zeros = zeros
        kernel = self.make_kernel()
        self.conv = nn.Conv1d(in_channels, in_channels, zeros, bias=False, padding=zeros)
        self.conv.weight = nn.Parameter(kernel)
    
    @staticmethod
    def make_kernel(zeros: int = 56):
        win = symetric_hann(4 * zeros + 1)
        winodd = win[1::2]
        t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
        t *= math.pi
        kernel = (sinc(t) * winodd).view(1, 1, -1)
        return kernel
        
    def forward(self, x: torch.Tensor):
        b, c, l = x.shape
        y = x.view(-1, 1, l)
        y = self.conv(y)
        y = y[:, :, 1:]
        y = y.view(b, c, l)
        y = torch.stack([x, y], dim=-1)
        return y.view(b, c, -1)
    
    @classmethod
    def from_float(cls, mod):
        new_mod = cls()
        new_mod.zeros = mod.zeros
        new_mod.conv = nnq.Conv1d.from_float(mod.conv)
        return cls

In [7]:
upsample = Upsample2Layer(channels)
y = upsample(x)

upsample.qconfig = tq.default_qconfig
q_upsample = tq.prepare(upsample, inplace=False)
q_upsample(x)  # calibrate
tq.convert(q_upsample, inplace=True)
qy = q_upsample(qx)

perf = check_perf(y, qy, show=True, title="= Upsample 2x errors =")

 = Upsample 2x errors = 
       Un-scaled        
MAE.............1.59e-02
MaxAE...........1.08e-01
MSE.............4.43e-04
MaxSE...........1.16e-02
         Scaled         
MAE/|y|.........1.86e-03
MaxAE/|y|.......1.25e-02
MSE/|y|^2.......6.02e-06
MaxSE/|y|^2.....1.57e-04
         Power          
SNR.............2.26e+03
SNR(db).........3.35e+01


  reduce_range will be deprecated in a future release of PyTorch."


In [8]:
class Downsample2Layer(nn.Module):  
    def __init__(self, in_channels: int = 1, zeros: int = 56):
        super(Downsample2Layer, self).__init__()
        self.zeros = zeros
        
        kernel = self.make_kernel()
        self.conv = nn.Conv1d(in_channels, in_channels, zeros, bias=False, padding=zeros)
        self.conv.weight = nn.Parameter(kernel)
        
        self.pad = nn.ConstantPad1d((0, 1), 0)
        self.mul = nnq.FloatFunctional()
        self.add = nnq.FloatFunctional()
    
    @staticmethod
    def make_kernel(zeros: int = 56):
        win = symetric_hann(4 * zeros + 1)
        winodd = win[1::2]
        t = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
        t *= math.pi
        kernel = (sinc(t) * winodd).view(1, 1, -1)
        return kernel

    def forward(self, x: torch.Tensor):
        if x.shape[-1] % 2:
            x = self.pad(x)
        xeven = x[:, :, ::2]
        xodd = x[:, :, 1::2]
        b, c, l = xodd.shape
        xodd = xodd.reshape((-1, 1, l))
        y = self.conv(xodd)
        y = y[:, :, :-1]
        y = y.view(b, c, l)
        
        y = self.add.add(y, xeven)
        y = self.mul.mul_scalar(y, 0.5)
        return y
    
    @classmethod
    def from_float(cls, mod):
        new_mod = cls()
        new_mod.zeros = mod.zeros
        new_mod.conv = nnq.Conv1d.from_float(mod.conv)
        return cls

In [9]:
downsample = Downsample2Layer(channels)
y = downsample(x)

# # qy = quantize(downsample(qx.dequantize()), qx.dtype).dequantize()

downsample.qconfig = tq.default_qconfig
q_downsample = tq.prepare(downsample, inplace=False)
q_downsample(x)  # calibrate
tq.convert(q_downsample, inplace=True)
qy = q_downsample(qx)

perf = check_perf(y, qy, show=True, title="= Downsample 2x errors =")

= Downsample 2x errors =
       Un-scaled        
MAE.............1.51e-02
MaxAE...........5.69e-02
MSE.............3.49e-04
MaxSE...........3.24e-03
         Scaled         
MAE/|y|.........2.76e-03
MaxAE/|y|.......1.04e-02
MSE/|y|^2.......1.16e-05
MaxSE/|y|^2.....1.08e-04
         Power          
SNR.............1.44e+03
SNR(db).........3.16e+01


In [10]:
# Implemented in https://github.com/pytorch/pytorch/pull/42443
class GLU(nn.Module):
    def __init__(self, dim=-1):
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, qx):
        if qx.is_quantized:
            return self._qforward(qx)
        else:
            return self._forward(qx)
    
    def _forward(self, x):
        return F.glu(x, self.dim)
    
    def _qforward(self, qx):
        x = qx.dequantize()
        y = F.glu(x, self.dim)
        qy = torch.quantize_per_tensor(y, qx.q_scale(), qx.q_zero_point(), qx.dtype)
        return qy

glu = GLU(-1)
y = glu(x)

# Prep/Convert sohuld have no effect
glu.qconfig = tq.default_qconfig
q_glu = tq.prepare(glu, inplace=False)
q_glu(x)
tq.convert(q_glu, inplace=True)
qy = q_glu(qx)

perf = check_perf(y, qy, show=True, title="= GLU errors =")

     = GLU errors =     
       Un-scaled        
MAE.............1.05e-02
MaxAE...........3.77e-02
MSE.............1.57e-04
MaxSE...........1.42e-03
         Scaled         
MAE/|y|.........1.99e-03
MaxAE/|y|.......7.18e-03
MSE/|y|^2.......5.72e-06
MaxSE/|y|^2.....5.16e-05
         Power          
SNR.............1.86e+03
SNR(db).........3.27e+01


In [11]:
# Implemented in https://github.com/pytorch/pytorch/pull/40371
class ConvTranspose1d(nn.ConvTranspose1d):
    def __init__(self, *args, **kwargs):
        super(ConvTranspose1d, self).__init__(*args, **kwargs)
        self.quant_stub = tq.QuantStub()
        self.dequant_stub = tq.DeQuantStub()
        
    def forward(self, x):
        x = self.dequant_stub(x)
        x = super(ConvTranspose1d, self).forward(x)
        x = self.quant_stub(x)
        return x        

In [12]:
conv_transpose = ConvTranspose1d(channels, channels, kernel_size=3, stride=2)
y = conv_transpose(x)

# Prep/Convert sohuld have no effect
conv_transpose.qconfig = tq.default_qconfig
q_conv_transpose = tq.prepare(conv_transpose, inplace=False)
q_conv_transpose(x)
tq.convert(q_conv_transpose, inplace=True)
qy = q_conv_transpose(qx)

perf = check_perf(y, qy, show=True, title="= ConvTranspose1d errors =")

= ConvTranspose1d errors =
       Un-scaled        
MAE.............5.90e-03
MaxAE...........1.83e-02
MSE.............4.87e-05
MaxSE...........3.34e-04
         Scaled         
MAE/|y|.........2.10e-03
MaxAE/|y|.......6.48e-03
MSE/|y|^2.......6.13e-06
MaxSE/|y|^2.....4.20e-05
         Power          
SNR.............2.14e+03
SNR(db).........3.33e+01


# Reuse the facebookreesearch/demucs + modify it

In [13]:
import os
import sys

DEMUCS_PATH = os.path.join('~', 'Git', 'demucs')
DEMUCS_PATH = os.path.expanduser(DEMUCS_PATH)
if DEMUCS_PATH not in sys.path:
    sys.path.append(DEMUCS_PATH)

In [101]:
from torchsummary import summary
from demucs.model import Demucs

model = Demucs(audio_channels=1, depth=5, growth=1.5)
model

Demucs(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(1, 64, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (1): Sequential(
      (0): Conv1d(64, 96, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (2): Sequential(
      (0): Conv1d(96, 144, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(144, 288, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (3): Sequential(
      (0): Conv1d(144, 216, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(216, 432, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (4): Sequential(
      (0): Conv1d(216, 324, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(324, 648, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
  )
  (decoder): ModuleList(
    (0): Sequential(
