In [1]:
import os
os.chdir('../')

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from IConNet.signal import nextpow2



In [143]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import math
from einops import rearrange, reduce
import opt_einsum as oe
import numpy as np

class LoremNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, downsample_factor=2,
                 stride=1, padding=0, dilation=1, bias=False, groups=1,
                 window_func='learnable', window_k: int=2):

        super().__init__()

        self.in_channels = in_channels
        if kernel_size % 2 == 0: # Forcing the filters to be odd (i.e, perfectly symmetrics)
            kernel_size += 1
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.learnable_params = nn.Parameter(torch.rand(window_k))
        self.fir_filters = nn.Parameter(torch.rand(out_channels, in_channels, kernel_size))
        self.window = nn.Parameter(torch.rand(self.kernel_size), requires_grad=False)
        self.stride = stride
        self.downsample_factor = downsample_factor

    def forward(self, waveforms):
        self.fir_filters = self.fir_filters.to(waveforms.device)
        L = waveforms.shape[-1] // self.downsample_factor
        return F.conv1d(waveforms, self.fir_filters, stride=self.stride)[..., :L]
        

class FIRWinFilters(nn.Module):
    """FIR filter design using the window method. (Ref: scipy.signal.firwin2)
    Forward steps:
        First, linearly interpolate the desired response on a uniform mesh `x`.
        Then adjust the phases of the coefficients so that the first `ntaps` of the
        inverse FFT are the desired filter coefficients.
    """

    @staticmethod
    def generate_firwin_mesh(kernel_size, fs=2):
        """Frequency-domain mesh"""
        nyq = fs/2
        nfreqs = 1 + 2 ** nextpow2(kernel_size)
        mesh_freq = torch.linspace(0.0, nyq, nfreqs) # (out_channels, in_channels, mesh_length) or (H C M)
        shift_freq = torch.exp(-(kernel_size - 1) / 2. * 1.j * torch.pi * mesh_freq / nyq)
        return mesh_freq, shift_freq
    
    def __init__(self, out_channels, in_channels, kernel_size, fs=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.fs = fs
        mesh_freq, shift_freq = self.generate_firwin_mesh(kernel_size, fs)
        self.register_buffer("mesh_freq", mesh_freq)
        self.register_buffer("shift_freq", shift_freq)
        
        lowcut_bands = torch.rand(out_channels, in_channels)
        bandwidths = torch.rand(out_channels, in_channels)
        self.lowcut_bands = nn.Parameter(lowcut_bands)
        self.bandwidths = nn.Parameter(bandwidths)
    
    def forward(self, W):
        """
        Args:
            W: (out_channels, in_channels, kernel_size) or (H C K). Time-domain windows.
            bandwidths, lowcut_bands:  (out_channels, in_channels) or (H C)
        Returns:
            W: FIR filters
        """        
        mesh1 = self.mesh_freq - self.lowcut_bands.abs()[..., None]
        mesh2 = mesh1 - self.bandwidths.abs()[..., None]
        x_freq = torch.logical_and(mesh1 >=0, mesh2 <= 0).float() # (H C M)
        firwin_freq = oe.contract('hcm,m->hcm', x_freq, self.shift_freq) 
        firwin_time = torch.fft.irfft(firwin_freq)[..., :self.kernel_size] # (H C K)
        W = firwin_time * W
        return W

    def right_inverse(self, W):
        return W

In [144]:
m = LoremNet(2, 6, 16, 2, stride=1)

In [145]:
parametrize.register_parametrization(m, "fir_filters", FIRWinFilters(
    out_channels=m.out_channels, in_channels=m.in_channels, kernel_size=m.kernel_size, fs=2
))

ParametrizedLoremNet(
  (parametrizations): ModuleDict(
    (fir_filters): ParametrizationList(
      (0): FIRWinFilters()
    )
  )
)

In [146]:
m.fir_filters

tensor([[[ 2.7017e-03, -3.5282e-02, -5.0047e-03,  8.5210e-03, -1.5822e-02,
          -2.8166e-03,  1.7565e-02, -6.3762e-02,  1.9569e-01, -2.2612e-01,
           1.3910e-01, -4.6428e-04, -1.4783e-02,  3.9140e-02, -3.8468e-03,
          -3.1967e-02,  3.5957e-02],
         [-8.0601e-03,  6.0330e-02, -2.5068e-02, -1.8187e-03, -2.1962e-02,
           1.4407e-01, -3.4229e-02, -1.0448e-01,  1.0766e-01, -1.7890e-01,
          -1.3118e-02,  2.1232e-02, -4.2757e-02, -2.1469e-05, -3.4763e-02,
           8.9455e-02, -3.2677e-02]],

        [[-1.1157e-02,  1.7915e-02, -5.9321e-03,  2.3713e-02, -2.8974e-03,
          -2.6829e-03,  6.7615e-03, -1.5826e-02,  1.5835e-02, -7.8370e-03,
           1.4101e-02, -2.7006e-03, -8.5758e-03,  9.8342e-03, -5.6956e-03,
           1.3177e-02, -1.4380e-02],
         [-3.7101e-02,  7.6319e-03,  7.7628e-03,  6.2964e-04, -7.7856e-02,
           2.5357e-02,  2.2011e-02, -1.1627e-01,  4.8395e-02, -7.1270e-02,
           7.3935e-03,  1.1059e-01, -8.7958e-02,  1.6055e-03, 

In [147]:
params = [p for p in m.parameters()]
params

[Parameter containing:
 tensor([0.4004, 0.0141], requires_grad=True),
 Parameter containing:
 tensor([0.2773, 0.5896, 0.6565, 0.6562, 0.6065, 0.4650, 0.6236, 0.4431, 0.6235,
         0.0043, 0.5421, 0.5090, 0.4531, 0.9472, 0.0734, 0.8139, 0.9858]),
 Parameter containing:
 tensor([[[0.0716, 0.9471, 0.9486, 0.1466, 0.2376, 0.5390, 0.1249, 0.2334,
           0.5964, 0.8279, 0.9894, 0.0889, 0.2220, 0.6733, 0.7291, 0.8581,
           0.9532],
          [0.1511, 0.6737, 0.6607, 0.9928, 0.2671, 0.8146, 0.5751, 0.4341,
           0.2650, 0.7433, 0.2204, 0.1201, 0.5200, 0.0117, 0.9162, 0.9990,
           0.6125]],
 
         [[0.5049, 0.5991, 0.1935, 0.9816, 0.2423, 0.8759, 0.3895, 0.5742,
           0.5067, 0.2844, 0.8122, 0.8817, 0.7171, 0.4071, 0.1858, 0.4406,
           0.6508],
          [0.4918, 0.2314, 0.6491, 0.0293, 0.7008, 0.2038, 0.7624, 0.4435,
           0.1291, 0.2718, 0.2561, 0.8888, 0.7918, 0.0746, 0.8551, 0.9672,
           0.5225]],
 
         [[0.5981, 0.8690, 0.7636, 0.1851,

In [153]:
data = torch.rand(1, 2, 200)
prediction = m(data)
labels = torch.rand(1, 6, 100)
criterion = nn.MSELoss()
loss = criterion(prediction, labels)
loss.backward()

In [154]:
prediction.shape

torch.Size([1, 6, 100])

In [155]:
optim = torch.optim.Adamax(m.parameters(), lr=0.1)
optim.step() 

In [159]:
data = torch.rand(50000, 2, 200)
labels = torch.rand(50000, 6, 100)
prediction = m(data)
loss = criterion(prediction, labels)
loss.backward()
optim.step() 

In [160]:
loss

tensor(0.3044, grad_fn=<MseLossBackward0>)

In [161]:
for p in m.parameters():
    print(p)

Parameter containing:
tensor([0.4004, 0.0141], requires_grad=True)
Parameter containing:
tensor([0.2773, 0.5896, 0.6565, 0.6562, 0.6065, 0.4650, 0.6236, 0.4431, 0.6235,
        0.0043, 0.5421, 0.5090, 0.4531, 0.9472, 0.0734, 0.8139, 0.9858])
Parameter containing:
tensor([[[ 0.0834, -0.0772, -0.0797,  0.0855, -0.0751, -0.0797,  0.0948,
          -0.0636,  0.1253, -0.0642,  0.0949, -0.0795, -0.0749,  0.0858,
          -0.0799, -0.0773,  0.0832],
         [-0.0762,  0.0894, -0.0774, -0.0804, -0.0741,  0.0995, -0.0759,
          -0.0654,  0.1430, -0.0654, -0.0758,  0.0995, -0.0737, -0.0804,
          -0.0774,  0.0892, -0.0760]],

        [[-0.0791,  0.0835, -0.0782,  0.0830, -0.0799, -0.0806,  0.0823,
          -0.0787,  0.0835, -0.0787,  0.0826, -0.0808, -0.0795,  0.0828,
          -0.0781,  0.0835, -0.0794],
         [-0.0752,  0.0840,  0.0823,  0.0830, -0.0724,  0.0940,  0.0838,
          -0.0647,  0.1348, -0.0645,  0.0834,  0.0936, -0.0725,  0.0834,
           0.0826,  0.0842, -0.0753]

In [211]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fb0b8ffbe80>

In [540]:
from einops import repeat 
from IConNet.fftconv import fft_conv_complex2 as fft_conv

class LoremNet2(nn.Module):
    @staticmethod
    def generate_firwin_mesh(kernel_size, fs=2):
        """
        Returns:
            mesh_freq: (out_channels, in_channels, mesh_length) or (H C M).
                Frequency-domain mesh.
            shift_freq: (H C M). To adjust the phases of the coefficients so that the first
                window coefficient of the inverse FFT are the desired filter coefficients.
        """
        nyq = fs/2
        nfreqs = 1 + 2 ** nextpow2(kernel_size)
        mesh_freq = torch.linspace(0.0, nyq, nfreqs)
        shift_freq = torch.exp(-(kernel_size - 1) / 2. * 1.j * torch.pi * mesh_freq / nyq)
        return mesh_freq, shift_freq
        
    def __init__(self, in_channels, out_channels, kernel_size, 
                 fs=2, stride=1, padding=0, dilation=1, bias=False, groups=1,
                 window_func='learnable', window_k: int=2):

        super().__init__()

        self.in_channels = in_channels
        if kernel_size % 2 == 0: # Forcing the filters to be odd (i.e, perfectly symmetrics)
            kernel_size += 1
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.window_k = window_k
        self.window_params = nn.Parameter(torch.rand(window_k))
        self.window = torch.rand(self.kernel_size, requires_grad=True)
        self.stride = stride

        self.fir_filters = torch.rand(out_channels, in_channels, kernel_size, 
                                      requires_grad=True)
        self.fs = fs
        self.mesh_freq, self.shift_freq = self.generate_firwin_mesh(kernel_size, fs)
        
        self.lowcut_bands = nn.Parameter(torch.rand(out_channels, in_channels))
        self.bandwidths = nn.Parameter(torch.rand(out_channels, in_channels))

    def forward(self, waveforms, trainable=True):
        device = waveforms.device
        
        self.window_params = self.window_params.to(device)
        self.lowcut_bands = self.lowcut_bands.to(device)
        self.bandwidths = self.bandwidths.to(device)

        # generate general cosine window from win_params
        k = torch.linspace(0, 2*math.pi, self.kernel_size, 
                           requires_grad=trainable, device=device)
        i = torch.arange(self.window_k, dtype=torch.float, device=device)[..., None]
        self.window = (self.window_params[..., None] * (-1)**i * torch.cos(i * k)).sum(0).to(device)

        # interpolate the desired filter coefficients in freq domain into the freq mesh
        # example: mesh [0. .25 .5 .75 1.], low1=.1 low2=.6 => [0. 1. 1. 0. 0.]
        self.fir_filters = repeat(self.window, 'k -> h c k', 
                                  h=self.out_channels, c=self.in_channels)
        m = self.mesh_freq.shape[-1]
        self.mesh1 = repeat(self.lowcut_bands, 'h c -> h c m', m=m)
        self.mesh1 = self.mesh_freq - self.mesh1.abs()
        self.mesh2 = self.mesh1 - self.bandwidths.abs()[..., None]

        self.mesh1 = torch.clamp(torch.exp(self.mesh1), min=0., max=1.) # torch.where(mesh1 >= 0., 1., 0.)
        self.mesh2 = torch.clamp(torch.exp(-self.mesh2), min=0., max=1.) # torch.where(mesh2 <= 0., 1., 0.)
        self.x_freq = self.mesh1 * self.mesh2 #  torch.logical_and(mesh1, mesh2).float()
        self.firwin_freq = oe.contract('hcm,m->hcm', self.x_freq, self.shift_freq) 
        
        # bring the firwin to time domain & multiply with the time-domain window 
        self.firwin_time = torch.fft.irfft(self.firwin_freq)[..., :self.kernel_size] 
        self.fir_filters = oe.contract('hck,hck->hck', self.fir_filters, self.firwin_time)

        # stride is downsampling factor 
        L = waveforms.shape[-1] // self.stride
        p = self.stride - waveforms.shape[-1] % self.stride
        padding = (0,p)
        X = F.pad(waveforms, padding)
        X = fft_conv(X, self.fir_filters, stride=self.stride)[..., :L]
        return X
        

In [541]:
aaa = torch.tensor([1., 0., 1.], requires_grad=True)
bbb = torch.tensor([1., 0., 1.], requires_grad=True)
ccc = torch.logical_and(aaa, bbb, out=torch.empty(3, dtype=torch.float, requires_grad=True))
ccc.requires_grad

True

In [542]:
m2 = LoremNet2(2, 6, 16, 2, stride=2)
criterion = nn.MSELoss()
optim = torch.optim.Adamax(m2.parameters(), lr=0.1)
optim.zero_grad()

In [543]:
for p in m2.parameters():
    print(p)

Parameter containing:
tensor([0.6944, 0.4354], requires_grad=True)
Parameter containing:
tensor([[0.0331, 0.1273],
        [0.2398, 0.7785],
        [0.8730, 0.1917],
        [0.1795, 0.0370],
        [0.4132, 0.6983],
        [0.8793, 0.6502]], requires_grad=True)
Parameter containing:
tensor([[0.8280, 0.9202],
        [0.9723, 0.5889],
        [0.4802, 0.7456],
        [0.0258, 0.9927],
        [0.2299, 0.1502],
        [0.4180, 0.1721]], requires_grad=True)


In [544]:
data = torch.rand(1, 2, 200)
labels = torch.rand(1, 6, 100)
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)
optim.step() 

In [545]:
for p in m2.parameters():
    print(p.grad)

tensor([3.7855, 4.0719])
tensor([[-0.2930, -0.5516],
        [-0.3282, -0.1903],
        [-0.1880, -0.3239],
        [-0.5459, -0.2780],
        [-0.2405, -0.1816],
        [-0.1251, -0.1470]])
tensor([[-0.0004,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0038],
        [ 0.0232,  0.0000],
        [ 0.0091,  0.0005],
        [ 0.0000,  0.0005]])


In [546]:
m2.firwin_time.grad

  m2.firwin_time.grad


In [549]:
m2.window_params.grad

tensor([3.7855, 4.0719])

In [550]:
m2.bandwidths.grad

tensor([[-0.0004,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0038],
        [ 0.0232,  0.0000],
        [ 0.0091,  0.0005],
        [ 0.0000,  0.0005]])

In [551]:
data = torch.rand(10, 2, 200)
labels = torch.rand(10, 6, 100)
optim.zero_grad()
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)
optim.step() 

In [552]:
data = torch.rand(50000, 2, 200)
labels = torch.rand(50000, 6, 100)
optim.zero_grad()
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)
optim.step() 

In [553]:
for p in m2.parameters():
    print(p.grad)

tensor([1.3584, 1.4856])
tensor([[-0.1640, -0.1588],
        [-0.0712, -0.0446],
        [-0.0411, -0.0742],
        [-0.1408, -0.1572],
        [-0.0503, -0.0416],
        [-0.0209, -0.0256]])
tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0014],
        [-0.0088,  0.0000],
        [ 0.0036, -0.0009],
        [ 0.0000, -0.0010]])


In [554]:
data = torch.rand(50000, 2, 200)
labels = torch.rand(50000, 6, 100)
optim.zero_grad()
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)

In [555]:
for p in m2.parameters():
    print(p.grad)

tensor([0.7697, 0.8349])
tensor([[-0.0833, -0.0822],
        [-0.0294, -0.0199],
        [-0.0176, -0.0298],
        [-0.0702, -0.0800],
        [-0.0184, -0.0166],
        [-0.0052, -0.0063]])
tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0008],
        [-0.0025,  0.0000],
        [-0.0026, -0.0002],
        [ 0.0000, -0.0005]])


In [556]:
for p in m2.parameters():
    print(p)

Parameter containing:
tensor([0.4521, 0.1922], requires_grad=True)
Parameter containing:
tensor([[0.3120, 0.3629],
        [0.4639, 1.0074],
        [1.0975, 0.4201],
        [0.4091, 0.3104],
        [0.6367, 0.9284],
        [1.0974, 0.8710]], requires_grad=True)
Parameter containing:
tensor([[ 1.0054,  0.9202],
        [ 0.9723,  0.5889],
        [ 0.4802,  0.4959],
        [-0.1246,  0.9927],
        [-0.0145, -0.0414],
        [ 0.4180, -0.0177]], requires_grad=True)


In [557]:
data = torch.rand(50000, 2, 200)
labels = torch.rand(50000, 6, 100)
optim.zero_grad()
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)
optim.step()
for p in m2.parameters():
    print(p)

Parameter containing:
tensor([0.4013, 0.1410], requires_grad=True)
Parameter containing:
tensor([[0.3775, 0.4091],
        [0.5042, 1.0497],
        [1.1379, 0.4617],
        [0.4526, 0.3736],
        [0.6763, 0.9706],
        [1.1337, 0.9083]], requires_grad=True)
Parameter containing:
tensor([[ 1.0266,  0.9202],
        [ 0.9723,  0.5889],
        [ 0.4802,  0.4426],
        [-0.1293,  0.9927],
        [-0.0521, -0.0411],
        [ 0.4180, -0.0176]], requires_grad=True)


In [558]:
data = torch.rand(50000, 2, 200)
labels = torch.rand(50000, 6, 100)
optim.zero_grad()
prediction = m2(data)
loss = criterion(prediction, labels)
loss.backward(retain_graph=True)
optim.step()
for p in m2.parameters():
    print(p)

Parameter containing:
tensor([0.3604, 0.0998], requires_grad=True)
Parameter containing:
tensor([[0.4303, 0.4458],
        [0.5353, 1.0825],
        [1.1692, 0.4939],
        [0.4870, 0.4245],
        [0.7065, 1.0030],
        [1.1608, 0.9361]], requires_grad=True)
Parameter containing:
tensor([[ 1.0427,  0.9202],
        [ 0.9723,  0.5889],
        [ 0.4802,  0.3997],
        [-0.1315,  0.9927],
        [-0.0767, -0.0408],
        [ 0.4180, -0.0183]], requires_grad=True)
