In [171]:
# Convolution using mxnet  ### x w
from __future__ import print_function
import numpy as np
import itertools, time
import torch
print('PyTorch version:', torch.__version__)

import scipy
from scipy.signal import convolve, fftconvolve, tukey, deconvolve



def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    # First figure out what the size of the output should be
    N, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    assert (W + 2 * padding - field_height) % stride == 0
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1

    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    j0 = np.tile(np.arange(field_width), field_height * C)
    j1 = stride * np.tile(np.arange(out_width), out_height)
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k, i, j)


def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """ An implementation of im2col based on some fancy indexing """
    # Zero-pad the input
    p = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
                               stride)

    cols = x_padded[:, k, i, j]
    C = x.shape[1]
    cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
    return cols


def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
                   stride=1):
    """ An implementation of col2im based on fancy indexing and np.add.at """
    N, C, H, W = x_shape
    H_padded, W_padded = H + 2 * padding, W + 2 * padding
    x_padded = np.zeros((N, C, H_padded, W_padded), dtype=cols.dtype)
    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding,
                               stride)
    cols_reshaped = cols.reshape(C * field_height * field_width, -1, N)
    cols_reshaped = cols_reshaped.transpose(2, 0, 1)
    np.add.at(x_padded, (slice(None), k, i, j), cols_reshaped)
    if padding == 0:
        return x_padded
    return x_padded[:, :, padding:-padding, padding:-padding]

def conv_forward_naive(x, w, b, conv_param):
    """
    A naive implementation of the forward pass for a convolutional layer.
    The input consists of N data points, each with C channels, height H and width
    W. We convolve each input with F different filters, where each filter spans
    all C channels and has height HH and width HH.
    Input:
    - x: Input data of shape (N, C, H, W)
    - w: Filter weights of shape (F, C, HH, WW)
    - b: Biases, of shape (F,)
    - conv_param: A dictionary with the following keys:
    - 'stride': The number of pixels between adjacent receptive fields in the
      horizontal and vertical directions.
    - 'pad': The number of pixels that will be used to zero-pad the input.
    Returns a tuple of:
    - out: Output data, of shape (N, F, H', W') where H' and W' are given by
    H' = 1 + (H + 2 * pad - HH) / stride
    W' = 1 + (W + 2 * pad - WW) / stride
    - cache: (x, w, b, conv_param)
    """
    out = None
    N, C, H, W = x.shape
    F, C, HH, WW = w.shape
    pad, stride = conv_param['pad'], conv_param['stride']
    x_stretched = im2col_indices(x, HH, WW, padding=pad, stride=stride)
    w_stretched = im2col_indices(w, HH, WW, padding=0, stride=1)
    H_prime = 1 + (H + 2*pad - HH) // stride
    W_prime = 1 + (W + 2*pad - WW) // stride
    out_shape = (N, F, H_prime, W_prime)
    out = col2im_indices(w_stretched.T.dot(x_stretched) + b[:,np.newaxis], 
                       out_shape, field_height=1, field_width=1, padding=0, 
                       stride=1)
    cache = (x, w, b, conv_param)
    return out, cache

class Benchmark(): 
    def __init__(self, prefix=None):
        self.prefix = prefix + ' ' if prefix else ''

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, *args):
        print('%stime: %.4f sec' % (self.prefix, time.time() - self.start))



PyTorch version: 1.9.0


# 卷积定理

Ref: [Proofs of Parseval’s Theorem & the Convolution Theorem](http://wwwf.imperial.ac.uk/~jdg/eeft3.pdf)

若 $f_1(t) \leftrightarrow F_1(\omega), f_2(t)\leftrightarrow F_2(\omega)$，则有

$$
\begin{align}
F[f_1(t)\star f_2(t)]&=F_1(\omega)\cdot F_2(\omega)\\
F[f_1(t)\cdot f_2(t)]&=F_1(\omega)\star F_2(\omega)
\end{align}
$$

- 卷积定理

    - ANN out (scratch): 
    
        `x(t) * w(t)`
        
    - np.fft (scratch): 
    
        `abs( ifft[ fft[x(t), mod=2xsize(x)-1] · fft[w(-t), mod=2xsize(x)-1] ] )`
        
    - scipy.fft (scratch):
        
        `abs( ifft[ fft[x(t), mod=2xsize(x)-1] · fft[w(-t), mod=2xsize(x)-1] ] )`

    - np.convolve:
    
        `x(t) * w(-t)`
        
    - scipy.signal.convolve/fftconvolve:
    
        `x(t) * w(-t)`
        
    - mxnet.ndarray.Convolution:
    
        `x(t) * w(t)`
    - torch.nn.Conv1d/torch.nn.Conv2d:

        `x(t) * w(t)`

In [116]:
print('='*10+'\n卷积定理\n'+'='*10)

x = np.arange(10).reshape(1,1,1,10)
w = np.arange(10).reshape(1,1,1,10)
b = np.arange(1)
print('x:',x)
print('w:',w,'b:',b, '\n')

# Convolution using ANN by scratch ### x w
conv_param = {'pad': w.shape[-1]-1,
              'stride': 1}
out, cache = conv_forward_naive(x, w, b, conv_param) # (1, 1, 19, 19)
print('ANN out (scratch):',out.shape,'\n', out[0,0,w.shape[-1]-1])

# Convolution by scratch ### x w[::-1]
print('np.fft (scratch):\n',np.abs(np.fft.ifft( np.fft.fft(x[0,0,0],w.shape[-1]*2-1) * np.fft.fft(w[0,0,0][::-1],w.shape[-1]*2-1) )).astype(np.int32))
print('scipy.fft (scratch):\n',np.abs(scipy.fft.ifft( scipy.fft.fft(x[0,0,0],w.shape[-1]*2-1) * scipy.fft.fft(w[0,0,0][::-1],w.shape[-1]*2-1) )).astype(np.int32))

# Convolution using numpy  ### x w[::-1]
print('np.convolve:\n', np.convolve(x[0,0,0], w[0,0,0][::-1]))

# Convolution using scipy  ### x w[::-1]
print('scipy.signal.convolve:\n', convolve(x[0,0,0], w[0,0,0][::-1]))
print('scipy.signal.fftconvolve:\n', fftconvolve(x[0,0,0], w[0,0,0][::-1]).astype(np.int32))

# Mxnet
out = nd.Convolution(data=nd.array(x), weight=nd.array(w), bias=nd.array(b), kernel=w.shape[-2:],
                                  num_filter=w.shape[1], pad=(0,w.shape[-1]-1))  # 
print('mxnet out:', out.shape, '\n', out[0,0,0].asnumpy().astype(np.int32))

# FFT & iFFT etc. for torch vs numpy
assert np.allclose( torch.fft.fftfreq(4096).numpy() , np.fft.fftfreq(4096) )
assert np.allclose( torch.fft.fft(torch.tensor(w)[0,0,0]).numpy() , np.fft.fft(w[0,0,0]) )
assert np.allclose( torch.fft.ifft(torch.tensor(w)[0,0,0]).numpy() , np.fft.ifft(w[0,0,0]) )
assert np.allclose( torch.fft.rfft(torch.tensor(w)[0,0,0]).numpy() , np.fft.rfft(w[0,0,0]) )
assert np.allclose( torch.fft.irfft(torch.tensor(w)[0,0,0]).numpy() , np.fft.irfft(w[0,0,0]) , atol=1e-7)
assert np.allclose( np.fft.fft2(np.mgrid[:5, :5][0]) , torch.fft.fft2(torch.tensor(np.mgrid[:5, :5][0])) , atol=1e-6)
assert np.allclose( np.fft.ifft2(np.mgrid[:5, :5][0]) , torch.fft.ifft2(torch.tensor(np.mgrid[:5, :5][0])) )

conv = torch.nn.Conv1d(1,1,kernel_size=w.shape[-1], padding=w.shape[-1]-1)
conv.weight=torch.nn.Parameter(torch.from_numpy(w[0]).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x[0]).float())
print('pytorch out (1d):', out.shape, '\n', out[0,0].int().numpy())

conv = torch.nn.Conv2d(1,1,kernel_size=w.shape[-2:], padding=(0,w.shape[-1]-1))
conv.weight=torch.nn.Parameter(torch.from_numpy(w).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x).float())
print('pytorch out (2d):', out.shape, '\n', out[0,0,0].int().numpy())

卷积定理
x: [[[[0 1 2 3 4 5 6 7 8 9]]]]
w: [[[[0 1 2 3 4 5 6 7 8 9]]]] b: [0] 

ANN out (scratch): (1, 1, 19, 19) 
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
np.fft (scratch):
 [  0   8  25  49  79 114 153 196 240 285 240 196 154 114  79  50  26   8
   0]
scipy.fft (scratch):
 [  0   8  25  49  80 114 153 196 240 285 240 196 154 114  79  50  26   8
   0]
np.convolve:
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
scipy.signal.convolve:
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
scipy.signal.fftconvolve:
 [  0   9  25  50  80 115 154 196 240 285 240 196 154 114  80  50  26   9
   0]
pytorch out (1d): torch.Size([1, 1, 19]) 
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
pytorch out (2d): torch.Size([1, 1, 1, 19]) 
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]


>以上是卷积定理在不同的计算逻辑下，实现的相同结果。
>
>同时，还考察了 `torch` 的 `fft` 模块也 `numpy` 的数值精度对应性。

# 卷积 vs 相关

> Ref: 
> 1. [二维kernel信息上的关系](https://www.mathworks.com/help/images/what-is-image-filtering-in-the-spatial-domain.html)
> 2. [Wiki: 各种定义和性质](https://en.wikipedia.org/wiki/Cross-correlation)
> 3. [知名的 Note](http://www.cs.umd.edu/~djacobs/CMSC426/Convolution.pdf)
> 4. [关联的证明](https://math.stackexchange.com/questions/1090974/relation-between-correlation-and-convolution)

- 卷积 + 相关：

  $$
  \begin{align}
  f(t) * g(t) &= \int^{+\infty}_{-\infty}f(\tau)\cdot g(t-\tau)d\tau = \int^{+\infty}_{-\infty}f(t-\tau)\cdot g(\tau)d\tau \\
  f(t) \star g(t) &= \int^{+\infty}_{-\infty}f^*(\tau)\cdot g(t+\tau)d\tau = \int^{+\infty}_{-\infty}f^*(\tau-t)\cdot g(\tau)d\tau
  \end{align}
  $$
  
  定义关联：
  $$
  \begin{align}
  f(t)\star g(t)=f^*(-t) * g(t)
  \end{align}
  $$
  卷积定理关联：
  $$
  \begin{align}
  F\{f*g\}&=F\{f\}\cdot F\{g\}\\
  F\{f\star g\}&=(F\{f\})^*\cdot F\{g\}\\
  \end{align}
  $$



- 卷积 vs 相关

    - np.fft (scratch) (LIGO):
    
    `abs( ifft[ fft[x(t), mod=10] · fft[w(t), mod=10].conjugate() ] )`
    
    - np.correlate + mod-10:
    
    `Mod[ x(t)`$\star$`w(t) , 10](-t)`
    
    - np.convolue + mod-10:
    
    `Mod[ x(t) * w(-t) , 10](-t)`
    
    - mxnet.ndarray.Convolution + mod-10:
    
    `Mod[ x(t) * w(t) , 10](-t)`

    - torch.nn.Con1d/torch.nn.Con2d + mod-10:
    
    `Mod[ x(t) * w(t) , 10](t)`

In [199]:
def mod(out, mod):
#     if type(out) == type(nd.array(1)):
#         return nd.concatenate([out, nd.zeros(out.shape[:-1]+(mod - out.shape[-1]%mod, ), ctx=ctx)], axis=len(out.shape)-1).reshape(0,0,-1,mod).sum(axis=-2).expand_dims(2)[:,:,:,::-1]
    if type(out) == type(torch.tensor(1)):
        return torch.cat((out, torch.zeros(out.shape[:-1]+(mod - out.shape[-1]%mod,))), dim=len(out.shape)-1).reshape(-1,mod).sum(axis=-2)
    elif type(out) == type(np.array(1)):
        return np.concatenate([out, np.zeros(out.shape[:-1]+(mod - out.shape[-1]%mod, ) )], axis=len(out.shape)-1).reshape(-1,mod).sum(axis=-2)[::-1]

In [206]:
print('='*10+'\n卷积 vs 相关\n'+'='*10) ### x w

print('np.fft (scratch):\n',np.abs(np.fft.ifft( np.fft.fft(x[0,0,0]) * np.fft.fft(w[0,0,0]).conjugate() )).astype(np.int32) )
print('np.fft (scratch | full):\n',np.abs(np.fft.ifft( np.fft.fft(x[0,0,0], 9*2-1) * np.fft.fft(w[0,0,0], 9*2-1).conjugate() )).astype(np.int32) )
print('np.correlate:\n', np.correlate(x[0,0,0], w[0,0,0], mode='full') )
print('np.correlate (mod=10):\n', mod(np.correlate(x[0,0,0], w[0,0,0], mode='full') , mod=10).astype(np.int32) )
print('np.convolve:\n', np.convolve(x[0,0,0], w[0,0,0][::-1], mode='full'))
print('np.convolve (mod=10):\n', mod( np.convolve(x[0,0,0], w[0,0,0][::-1], mode='full'), mod=10).astype(np.int32) )

# Mxnet
# out = nd.Convolution(data=nd.array(x), weight=nd.array(w), bias=nd.array(b), kernel=w.shape[-2:],stride=(1,1),
#                                   num_filter=w.shape[1], pad=(0,w.size-1))  # 
# print('mxnet out:',out.shape,'\n',out[0,0,0].asnumpy().astype(np.int32))
# print('mxnet out (mod=10):',out.shape,'\n',mod( out, mod=10)[0,0,0].asnumpy().astype(np.int32) )

# out = nd.Convolution(data=nd.array(x), weight=nd.array(w), bias=nd.array(b), kernel=w.shape[-2:],stride=(1,1),
#                                   num_filter=w.shape[1], pad=(0,w.size))  # 
# print('mxnet out:',out.shape,'\n',out[0,0,0,:-1].reshape(2,-1).sum(axis=0).asnumpy().astype(np.int32))

conv = torch.nn.Conv1d(1,1,kernel_size=w.shape[-1], padding=w.shape[-1])
conv.weight=torch.nn.Parameter(torch.from_numpy(w[0]).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x[0]).float())
print('pytorch out (1d):', out.shape, '\n', out[:,:,:-1].reshape(2,-1).sum(axis=0).int().numpy())

print('pytorch out (1d) with mod func:', out.shape, '\n', mod(out, mod=10).int().numpy())

conv = torch.nn.Conv2d(1,1,kernel_size=w.shape[-2:], padding=(0, w.shape[-1]),)
conv.weight=torch.nn.Parameter(torch.from_numpy(w).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x).float())
print('pytorch out (2d):', out.shape, '\n', out[:,:,:,:-1].reshape(2,-1).sum(axis=0).int().numpy())

print('pytorch out (2d) with mod func:', out.shape, '\n', mod(out, mod=10).int().numpy())


卷积 vs 相关
np.fft (scratch):
 [285 240 205 180 165 160 165 180 205 240]
np.fft (scratch | full):
 [285 240 196 154 115  79  49  25   8   8  25  49  79 115 154 196 240]
np.correlate:
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
np.correlate (mod=10):
 [285 240 205 180 165 160 165 180 205 240]
np.convolve:
 [  0   9  26  50  80 115 154 196 240 285 240 196 154 115  80  50  26   9
   0]
np.convolve (mod=10):
 [285 240 205 180 165 160 165 180 205 240]
pytorch out (1d): torch.Size([1, 1, 21]) 
 [285 240 205 180 165 160 165 180 205 240]
pytorch out (1d) with mod func: torch.Size([1, 1, 21]) 
 [285 240 205 180 165 160 165 180 205 240]
pytorch out (2d): torch.Size([1, 1, 1, 21]) 
 [285 240 205 180 165 160 165 180 205 240]
pytorch out (2d) with mod func: torch.Size([1, 1, 1, 21]) 
 [285 240 205 180 165 160 165 180 205 240]


## 模 vs 自相关 (omited)

- 模 vs 自相关 (omited)

In [209]:
# print('模 vs 自相关\n'+'='*10) ### x w

# print('np.correlate:\n', abs(np.fft.fft(np.correlate(x[0,0,0], x[0,0,0], mode='full'))) )
# (abs(np.fft.fft(x[0,0,0], 9*2-1)**2))

# 内积(分母) vs 自相关

- 内积(分母) vs 自相关

    - $<x|x>$
    
    `sum[ fft[ x(t), mod=10] · fft[ x(t), mod=10 ].conjugate() ] x df`
    
    - np.correlate + mod-10:
    
    `Mod[ x(t)`$\star$`w(t) , 10](0) x fs`
    
    - mxnet.ndarray.Convolution + mod-10:
    
    `Mod[ x(t) * w(t) , 10](0) x fs`
    
    - torch.nn.Conv1d/torch.nn.Conv2d + mod-10:
    
    `Mod[ x(t) * w(t) , 10](0) x fs`    

In [222]:
fs = np.random.randint(1, 4096)
datafreq = np.fft.fftfreq(x[0,0,0].size)*fs
df = np.abs(datafreq[1] - datafreq[0])
print('内积(分母) vs 自相关'+'  df: %s\n'% df +'='*20) ### x w

print('分母: \n', ( np.fft.fft(x[0,0,0]) * np.fft.fft(x[0,0,0]).conjugate() ).sum()* df)

print('np.correlate (mod=10):\n', mod( np.correlate(x[0,0,0], x[0,0,0], mode='full') ,mod=10)[0] * fs)

# out = nd.Convolution(data=nd.array(x), weight=nd.array(x), bias=nd.array(b), kernel=w.shape[-2:],stride=(1,1),
#                                   num_filter=w.shape[1], pad=(0,w.size-1))  # 
# print('mxnet out (mod=10):',out.shape,'\n',mod( out, mod=10)[0,0,0,0].asnumpy().astype(np.int32) * fs)

conv = torch.nn.Conv1d(1,1,kernel_size=x.shape[-1], padding=x.shape[-1])
conv.weight=torch.nn.Parameter(torch.from_numpy(x[0]).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x[0]).float())
print('pytorch out (1d) (mod=10):',out.shape,'\n',mod( out, mod=10)[0].int().numpy() * fs)

conv = torch.nn.Conv2d(1,1,kernel_size=x.shape[-2:], padding=(0,x.shape[-1]))
conv.weight=torch.nn.Parameter(torch.from_numpy(x).float(), requires_grad=False)
conv.bias=torch.nn.Parameter(torch.from_numpy(b).float(), requires_grad=False)

out = conv(torch.from_numpy(x).float())
print('pytorch out (2d) (mod=10):',out.shape,'\n',mod( out, mod=10)[0].int().numpy() * fs)

内积(分母) vs 自相关  df: 255.9
分母: 
 (729315+0j)
np.correlate (mod=10):
 729315.0
pytorch out (1d) (mod=10): torch.Size([1, 1, 21]) 
 729315
pytorch out (2d) (mod=10): torch.Size([1, 1, 1, 21]) 
 729315


# MFCNN with Pytorch

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# from torch.nn import init

In [2]:
torch.__version__, torch.cuda.is_available()

('1.0.0.dev20190402', True)

In [3]:
def try_all_gpus():
    """Return all available GPUs, or [cpu(),] if no GPU exists."""
    devices = [torch.device(f'cuda:{i}')
             for i in range(torch.cuda.device_count())]
    return devices if devices else [torch.device('cpu')]     
devices = try_all_gpus()
devices

[device(type='cuda', index=0),
 device(type='cuda', index=1),
 device(type='cuda', index=2),
 device(type='cuda', index=3)]

In [4]:
train_data = np.load('train_data_z3_random_AE_with_gb_16384.npy') # (6000, 2, 16384) [17384 * 15]sec
train_label = np.zeros((6000,)) 
train_label[3000:] = 1

train_data = torch.tensor(train_data, dtype=torch.float32)
# train_label = torch.tensor(pd.get_dummies(train_label).values, dtype=torch.float32)
train_label = torch.tensor(train_label, dtype=torch.long)

In [5]:
batch_size = 128
dataset = TensorDataset(train_data, train_label)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, 
                          shuffle=True, #pin_memory=True,
#                           num_workers=16,
                          worker_init_fn=lambda _: np.random.seed(
                              int(torch.initial_seed()) % (2**32-1)))

In [6]:
S_t_m12 = np.load('S_t.npy')[::-1]
S_t_m12 = torch.tensor(abs(S_t_m12), dtype=torch.float32).reshape(1, 1, 16384)
# S_t_m12 = torch.cat((S_t_m12, S_t_m12), 1)  # [2, 1, 16384]
print(S_t_m12.shape)

template = np.load('template_St_matrix_z3_AE_4096_50.npy')
template = torch.tensor(template, dtype=torch.float32).reshape(-1, 2, 4096) #(50, 2, 4096) [4096 x 15]sec
print(template.shape)

torch.Size([1, 1, 16384])
torch.Size([50, 2, 4096])


In [7]:
hh_sqrt = np.load('hh_sqrt_4096_50.npy') # (50, 2) 
hh_sqrt = torch.tensor(hh_sqrt, dtype=torch.float32)
hh_sqrt.shape

torch.Size([50, 2])

In [8]:
class MFLayer(nn.Module):
    def __init__(self,template, hh_sqrt, S_t_m12):
        super(MFLayer, self).__init__() 
        self.data_size = S_t_m12.shape[-1] # 16384
        self.temp_size = template.shape[-1] # 4096
        self.params = nn.ParameterDict({
                'template': nn.Parameter(template, requires_grad=False),
                'hh_sqrt': nn.Parameter(hh_sqrt.unsqueeze(0).unsqueeze(-1), requires_grad=False), 
                'S_t_m12': nn.Parameter(S_t_m12, requires_grad=False),
        })

    def _mod(self, X, mod):
        return F.pad(X, pad=(0, (-X.shape[-1]) % mod)).unsqueeze(-2).reshape(X.shape[0],-1, abs((-X.shape[-1]) // mod), mod).sum(-2)

    def forward(self, X):
        # split A & E
        xa = X[:,:1]
        xe = X[:,1:]
        
        # d / sqrt(S)
        d_SA = self._mod(F.conv1d(xa, self.params['S_t_m12'], padding=self.data_size-1, groups=1), mod=self.data_size)
        d_SE = self._mod(F.conv1d(xe, self.params['S_t_m12'], padding=self.data_size-1, groups=1), mod=self.data_size)
        # [num_batch, 1, self.data_size]
        
        h_SA = self.params['template'][:,:1]
        h_SE = self.params['template'][:,1:]
        # [num_temp, 1, self.temp_size]
        
        # <d|h>   
        dh_A = self._mod(F.conv1d(d_SA, h_SA, padding=self.temp_size-1, groups=1), mod=self.data_size)
        dh_E = self._mod(F.conv1d(d_SE, h_SE, padding=self.temp_size-1, groups=1), mod=self.data_size)
        
        # [num_batch, num_temp, 2, self.data_size]
        return torch.stack((dh_A, dh_E), -2) / self.params['hh_sqrt']

class CutHybridLayer(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, X):
        return torch.max(torch.abs(X), -1).values.permute(0,2,1)

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, X):
        return torch.flatten(X, start_dim=1)

net = nn.Sequential(
    MFLayer(template, hh_sqrt, S_t_m12), 
    CutHybridLayer(),
    nn.Conv1d(in_channels=2, out_channels=16, kernel_size=3, stride=1),
    nn.ReLU(), 
    nn.MaxPool1d(kernel_size=4, stride=2),
    nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1),
    nn.ReLU(), 
    nn.MaxPool1d(kernel_size=4, stride=2),
#     nn.Flatten(),
    Flatten(),
    nn.Linear(288, 32),
    nn.ReLU(), 
    nn.Dropout(p=0.5),
    nn.Linear(32, 2),
    nn.Sigmoid()
)

print(net)

Sequential(
  (0): MFLayer(
    (params): ParameterDict(
        (S_t_m12): Parameter containing: [torch.FloatTensor of size 1x1x16384]
        (hh_sqrt): Parameter containing: [torch.FloatTensor of size 1x50x2x1]
        (template): Parameter containing: [torch.FloatTensor of size 50x2x4096]
    )
  )
  (1): CutHybridLayer()
  (2): Conv1d(2, 16, kernel_size=(3,), stride=(1,))
  (3): ReLU()
  (4): MaxPool1d(kernel_size=4, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv1d(16, 32, kernel_size=(3,), stride=(1,))
  (6): ReLU()
  (7): MaxPool1d(kernel_size=4, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Flatten()
  (9): Linear(in_features=288, out_features=32, bias=True)
  (10): ReLU()
  (11): Dropout(p=0.5)
  (12): Linear(in_features=32, out_features=2, bias=True)
  (13): Sigmoid()
)


In [16]:
# output = net(train_data[:10])
# output.shape

In [13]:
# net = nn.DataParallel(net, device_ids=devices)#.to(devices[0])
net = net.cuda()

In [14]:
lr, wd = 1e-3, 1e-4
lr_period, lr_decay = 2, 0.9
num_epochs = 20

trainer = torch.optim.Adam((param for param in net.parameters() if param.requires_grad), 
                           lr=lr, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
num_batches = len(train_loader)
loss = nn.CrossEntropyLoss(reduction="none")

In [15]:
for epoch in range(num_epochs):
    for i, (features, labels) in enumerate(train_loader):
        features, labels = features.cuda(), labels.cuda()
        trainer.zero_grad()
        output = net(features)
        l = loss(output, labels).mean()
        l.backward()
        trainer.step()
        print(l)
    scheduler.step()

tensor(0.6684, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6072, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6251, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5946, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6126, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5958, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5971, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5924, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5706, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.6000, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5685, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5689, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5963, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5734, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5721, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5849, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5532, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.5893, device='cuda:0',

KeyboardInterrupt: 