In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import scipy.signal
import scipy.stats
import scipy.integrate 
import json

# Root of the project area; every other location below is derived from it.
PATH = '/storage/home/adz6/group/project'

# Project sub-directories, each joined onto PATH in a single pass.
RESULTPATH, PLOTPATH, DATAPATH, SIMDATAPATH, TRAINPATH = (
    os.path.join(PATH, subdir)
    for subdir in (
        'results/damselfly',
        'plots/damselfly',
        'damselfly/data/datasets',
        'damselfly/data/sim_data',
        'damselfly/training',
    )
)

"""
Date: 6/25/2021
Description: template
"""

"""
config = {

    conv: {
    
        0: { # the template convolution layer
            'in': ,
            'out': ,
            'kernel': ,
            
        },

        1: {
            'in':[],
            'out':[],
            'kernel':[],
            'dilation':[],
            'maxpool_kernel': ,
        },
        
        2: {...},
        
        }, 
        
    linear: {...},
    
    conv_template_tensor: [...],
    
    frozen_template_tensor: [...],
    
    
}
"""


class DFHybridInit(torch.nn.Module):
    """Two-branch classifier: a matched-filter branch plus a 1-D CNN branch.

    Matched-filter branch: the input is convolved with fixed templates
    (``MFConv``), the magnitude is max-pooled (``MFMax``) and the result is
    batch-normalised.  CNN branch: the same input goes through
    ``df.models.ConvStack1D``, is flattened and batch-normalised.  The two
    feature vectors are concatenated and passed through
    ``df.models.StackLinear`` and a final ``torch.nn.Linear``.

    Parameters
    ----------
    config : dict
        Must contain the following keys:

        ``'conv_template_tensor'``
            Tensor used as the matched-filter kernels.  ``shape[0]`` is used
            as the number of templates (output channels), ``shape[1]`` as the
            number of input channels and ``shape[-1]`` as the kernel length.
        ``'conv'``
            Passed unchanged to ``df.models.ConvStack1D``.
        ``'linear'``
            Sequence of three items forwarded to ``df.models.StackLinear``.
            ``linear[0][0]`` must equal the total number of concatenated
            features (CNN features + number of templates) and
            ``linear[1][-1]`` is the width feeding the output layer.
        ``'nclass'``
            Number of output units of the final linear layer.

        NOTE(review): the key names for ``'conv'``, ``'linear'`` and
        ``'nclass'`` are taken from the config sketch and the example config
        elsewhere in this notebook -- confirm they match the intended schema.
        ``'frozen_template_tensor'`` from the sketch is not read here.

    Raises
    ------
    KeyError
        If any of the required keys is missing from ``config``.
    """

    _REQUIRED_KEYS = ('conv', 'linear', 'conv_template_tensor', 'nclass')

    def __init__(self, config):
        super(DFHybridInit, self).__init__()

        missing = [key for key in self._REQUIRED_KEYS if key not in config]
        if missing:
            raise KeyError('DFHybridInit config is missing keys: ' + ', '.join(missing))

        # Pull everything from ``config`` so the constructor no longer depends
        # on same-named globals happening to exist in the notebook.
        template_tensor = config['conv_template_tensor']
        conv_list = config['conv']
        linear_list = config['linear']
        nclass = config['nclass']

        n_template = template_tensor.shape[0]

        # Learned convolution branch.
        self.conv = df.models.ConvStack1D(conv_list)

        # Matched-filter branch: template convolution followed by a max-pool
        # whose window equals the template length.
        self.mf_conv = MFConv(template_tensor)
        self.mf_max = MFMax(template_tensor.shape[-1])

        # One normalisation per branch; the CNN branch owns whatever part of
        # the first linear layer's input is not taken by the templates.
        self.mf_norm = torch.nn.BatchNorm1d(n_template, affine=False)
        self.conv_norm = torch.nn.BatchNorm1d(linear_list[0][0] - n_template, affine=False)

        self.linear = df.models.StackLinear(linear_list[0], linear_list[1], linear_list[2])
        self.linear_out = torch.nn.Linear(linear_list[1][-1], nclass)

    def NumFlatFeatures(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, x):
        """Run both branches on ``x`` and return the output of ``linear_out``.

        ``x`` is fed to two 1-D convolution stacks, so it is assumed to be
        laid out as (batch, channels, samples) -- TODO confirm with callers.
        """
        # Matched-filter branch.
        y = self.mf_conv(x)
        y = self.mf_max(torch.abs(y))
        y = self.mf_norm(y)
        # Flatten everything after the batch axis.  A bare ``squeeze()`` also
        # removes the batch axis when the batch holds a single sample, which
        # makes the concatenation below fail.
        y = y.flatten(1)

        # CNN branch.
        x = self.conv(x)
        x = x.view(-1, self.NumFlatFeatures(x))
        x = self.conv_norm(x)

        z = torch.cat((x, y), dim=-1)

        z = self.linear(z)
        z = self.linear_out(z)

        return z
    
    
def MFConv(template_tensor, requires_grad=False):
    """Build a bias-free ``torch.nn.Conv1d`` whose kernels are ``template_tensor``.

    The layer geometry is read off the tensor: ``shape[1]`` input channels,
    ``shape[0]`` output channels and a kernel of length ``shape[-1]``.  The
    layer uses ``'same'`` padding in ``'circular'`` mode.

    Parameters
    ----------
    template_tensor : torch.Tensor
        Kernel values installed as the layer's weight.
    requires_grad : bool, optional
        Whether the installed weight is trainable.  Defaults to ``False``.

    Returns
    -------
    torch.nn.Conv1d
        The configured convolution layer.
    """
    shape = template_tensor.shape

    layer = torch.nn.Conv1d(
        in_channels=shape[1],
        out_channels=shape[0],
        kernel_size=shape[-1],
        padding='same',
        dilation=1,
        groups=1,
        bias=False,
        padding_mode='circular',
    )

    # Swap the randomly initialised kernel for the supplied templates.
    kernel = torch.nn.Parameter(template_tensor, requires_grad=requires_grad)
    layer.weight = kernel

    return layer
    
def MFMax(filter_size):
    """Return a ``torch.nn.MaxPool1d`` layer with a window of ``filter_size``."""
    # Earlier idea kept for reference:
    #conv_max = torch.amax(mf_max(torch.abs(convolution)), 1, keepdim=True)
    pool = torch.nn.MaxPool1d(kernel_size=filter_size)
    return pool
    
    
    

In [None]:
# Random tensor of shape (10, 2, 8192); the next cell passes it to MFConv as the templates.
x = torch.rand(10, 2, 8192)

In [None]:
# Build the template convolution layer, using `x` as its kernels.
mf_conv = MFConv(template_tensor=x)


In [None]:
# Display the repr of the layer returned by MFConv in the previous cell.
mf_conv

In [None]:
# Stand-in matched-filter templates: a random (10, 2, 8192) tensor.
x = torch.rand((10, 2, 8192))


nch = 2
nslice = 1
input_shape_1d = 8192
# One entry per convolution block.  Going by the config sketch near the top of
# the notebook the five items are: in channels, out channels, kernel sizes,
# dilations and the max-pool kernel -- the authoritative format is whatever
# df.models.ConvStack1D expects.
conv_list_1d = [
        [
            [nch * nslice, 4, 4],
            [4, 4,4],
            [16, 16, 16],
            [1, 1, 1], # dilation
            16
        ],
        [
            [4, 6, 6],
            [6, 6, 6],
            [8, 8, 8],
            [1, 1, 1],
            8
        ],
        [
            [6, 8, 8],
            [8, 8, 8],
            [4, 4, 4],
            [1, 1, 1],
            4
        ],
    ]

model_config_1d_cnn = {
    'nclass': 2,
    'nch': 2,
    'conv': conv_list_1d
    }

# The first linear layer sees the flattened CNN features plus one value per
# template, so the template count is read from `x` instead of being hard-coded.
linear_list_1d = [
        [df.models.GetConv1DOutputSize(model_config_1d_cnn['conv'], model_config_1d_cnn['nch'], input_shape_1d) + x.shape[0], 100],
        [100, 20],
        [0.0, 0.0]
    ]

# DFHybridInit takes a single config mapping, so bundle everything it needs.
# NOTE(review): key names follow the config sketch and the dict above --
# confirm they match what DFHybridInit reads.
model_config_hybrid = dict(
    model_config_1d_cnn,
    linear=linear_list_1d,
    conv_template_tensor=x,
)

model = DFHybridInit(model_config_hybrid)

In [None]:
# Smoke test: push a random batch of shape (5, 2, 8192) through the model
# and print the shape of what comes back.
y = torch.rand((5, 2, 8192))

print(model(y).size())

In [None]:
# MFConv returns a Conv1d module, which has no `.shape`; apply the layer to a
# random batch so that `y` is the filtered output tensor.
# NOTE(review): the (5, 2, 8192) batch mirrors the one used for the model
# smoke test -- swap in real data if that was the intent.
y = MFConv(x)(torch.rand(5, 2, 8192))
print(y.shape)

In [None]:
n_pc = 256

# Three 2-D convolution blocks.  Each entry lists three per-layer sequences
# (channel counts twice, then (n_pc, 4) kernel tuples) followed by a (1, 4)
# tuple -- presumably the max-pool kernel; verify against damselfly.
conv_list = [
    [
        [2, 20, 20],
        [20, 20, 20],
        [(n_pc, 4), (n_pc, 4), (n_pc, 4)],
        (1, 4)
    ],
    [
        [20, 40, 40],
        [40, 40, 40],
        [(n_pc, 4), (n_pc, 4), (n_pc, 4)],
        (1, 4)
    ],
    [
        [40, 80, 80],
        [80, 80, 80],
        [(n_pc, 4), (n_pc, 4), (n_pc, 4)],
        (1, 4)
    ],
]

in_ch = 2
input_shape = (256, 128)

# The bare name GetConv2DOutputSize is never imported in this notebook, so it
# is qualified the same way as df.models.GetConv1DOutputSize.
# NOTE(review): assumes damselfly exposes the 2-D helper in df.models -- TODO confirm.
df.models.GetConv2DOutputSize(conv_list, in_ch, input_shape)



In [None]:
# Scratch objects for checking 2-D layer output shapes.
# Note: this rebinds `x` to a 4-D tensor of shape (1, 2, 60, 8192).
x = torch.rand(1, 2, 60, 8192)

# 2 -> 20 channels with a (60, 12) kernel; 'same' padding.
layer2d = torch.nn.Conv2d(in_channels=2, out_channels=20, kernel_size=(60, 12), padding='same')

# Max-pool with a (1, 16) window.
maxlayer = torch.nn.MaxPool2d(kernel_size=(1, 16))

In [None]:
# Output shape of the 'same'-padded Conv2d applied to `x`; expected (1, 20, 60, 8192).
layer2d(x).shape

In [None]:
# Shape after additionally max-pooling with the (1, 16) window; expected (1, 20, 60, 512).
maxlayer(layer2d(x)).shape