In [15]:
import os, sys
os.chdir('/home/seigyo/Documents/pytorch/brain_decoder')
sys.path.append(os.pardir)
import numpy as np
from numpy.random import RandomState
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import mne
from mne.io import concatenate_raws
from mymodule.utils import data_loader, evaluator
from mymodule.layers import LSTM, Residual_block, Res_net, Wavelet_cnn, NlayersSeqConvLSTM
from mymodule.trainer import Trainer
from mymodule.optim import Eve, YFOptimizer
from sklearn.utils import shuffle
from tensorboardX import SummaryWriter
from load_data import get_data, get_data_multi, get_crops, get_crops_multi
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# NOTE(review): this cell redefines ``Wavelet_cnn``, shadowing the class of the
# same name imported from ``mymodule.layers`` in the setup cell.
class Wavelet_cnn(nn.Module):
    """Learnable wavelet-style filter bank built from causal convolutions.

    At every level the current approximation signal is filtered along the
    sequence axis by a causal low-pass kernel (``l_conv``) and a causal
    high-pass kernel (``h_conv``). Both results are then subsampled by a
    strided 1x1 convolution. The subsampled low-pass output becomes the input
    of the next level.

    Input:
        x: tensor of size (batch, 1, seq, electrode). The kernels have size
           ``(conv_length, 1)``, so each electrode column is filtered
           independently along ``seq``. ``seq`` does not have to be a power
           of two: odd lengths follow the usual strided-conv rounding
           (e.g. 5 -> 3 for stride 2).

    Returns:
        A flat list of ``2 * L`` tensors, where ``L`` is the number of levels
        actually computed. For level ``k`` (0-based):

        - ``coefs[2k]`` is the ``h_conv`` (detail) branch output.
        - ``coefs[2k + 1]`` is the ``l_conv`` (approximation) branch output.

        Each tensor is flattened to size (batch, seq_k * electrode). The
        decomposition stops after ``level`` levels, or earlier once the
        approximation's sequence length drops below ``stride``.

    Args:
        conv_length: kernel length of the low/high-pass filters.
        stride: subsampling factor applied after each filter.
        level: maximum number of decomposition levels.
    """

    def __init__(self, conv_length=32, stride=2, level=100):
        super(Wavelet_cnn, self).__init__()
        # ZeroPad2d takes (left, right, top, bottom). Padding only the start
        # of the sequence axis by conv_length - 1 makes each stride-1 conv
        # causal and length-preserving.
        self.pad = nn.ZeroPad2d((0, 0, conv_length - 1, 0))
        self.conv_length = conv_length
        self.stride = stride
        self.level = level
        # Low-pass / high-pass filters along the sequence axis.
        # (Creation order is kept so seeded default initialisation matches.)
        self.l_conv = nn.Conv2d(1, 1, (conv_length, 1),
                                stride=(1, 1), bias=False)
        self.h_conv = nn.Conv2d(1, 1, (conv_length, 1),
                                stride=(1, 1), bias=False)
        # Down-samplers: strided 1x1 convolutions.
        # NOTE(review): init_filter does not set these weights, so they keep
        # PyTorch's default random init and scale every level's output by a
        # learnable (initially random) factor. Confirm this is intended.
        self.l_downsample = nn.Conv2d(1, 1, (1, 1),
                                      stride=(stride, 1), bias=False)
        self.h_downsample = nn.Conv2d(1, 1, (1, 1),
                                      stride=(stride, 1), bias=False)
        self.init_filter()

    def init_filter(self):
        """Initialise l_conv as a moving average and h_conv as its signed twin.

        Both kernels are filled with ``1 / kernel_length``. h_conv then has
        every even-indexed tap negated, giving an alternating-sign (high-pass)
        kernel. The down-sampler weights are left untouched.
        """
        with torch.no_grad():
            self.l_conv.weight.fill_(1.0 / self.l_conv.weight.size(2))
            self.h_conv.weight.fill_(1.0 / self.h_conv.weight.size(2))
            self.h_conv.weight[:, :, 0::2, :] *= -1

    def forward(self, x):
        """Decompose ``x`` level by level; see the class docstring for shapes."""
        batch_size = x.size(0)
        approx = x
        coefs = []
        for _ in range(self.level):
            # Both branches consume the same causally padded signal.
            padded = self.pad(approx)
            detail = self.h_downsample(self.h_conv(padded))
            approx = self.l_downsample(self.l_conv(padded))
            coefs.append(detail.reshape(batch_size, -1))
            coefs.append(approx.reshape(batch_size, -1))
            # Once the approximation is shorter than the stride, subsampling
            # it again would collapse it to a single sample, so stop here.
            if approx.size(2) < self.stride:
                break
        return coefs

In [16]:
# Build the filter bank with its default hyper-parameters, then move its
# parameters to the current CUDA device (Module.cuda() returns the module).
model = Wavelet_cnn()
model = model.cuda()

In [17]:
# Random input of size (batch=32, 1, seq=160, electrode=64), as expected by
# Wavelet_cnn. Since PyTorch 0.4 plain tensors carry autograd state, so the
# deprecated ``Variable`` wrapper is unnecessary. Sampling on the CPU and then
# calling .cuda() keeps the same random values as before.
X = torch.randn(32, 1, 160, 64).cuda()
Y = model(X)

In [18]:
# Y alternates (detail, approximation) per level, so index 6 is the detail
# output of the 4th level; ``.shape`` is the same torch.Size as ``.size()``.
Y[6].shape

torch.Size([32, 640])

In [None]:
# NOTE(review): ``x1``/``x2`` are never defined in this notebook, so this cell
# raised NameError. Presumably they stood for two flattened coefficient
# tensors from the model output. The level-1 (detail, approximation) pair
# from ``Y`` is used here; confirm the intended operands.
torch.cat((Y[0], Y[1]), dim=1).size()