In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from dataloader_cage import dataformat_for_rnn
import torch.optim as optim
import time
import numpy as np

class causal_conv(nn.Module):
    def __init__(self, cut):
        super(causal_conv, self).__init__()
        self.cut = cut

    def forward(self, x):
        return x[:, :, :-self.cut].contiguous()


class conv_block(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0):
        super(conv_block, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.causal_conv1 = causal_conv(padding)
        self.activate1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.causal_conv2 = causal_conv(padding)
        self.activate2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.causal_conv1, self.activate1, self.dropout1,
                                 self.conv2, self.causal_conv2, self.activate2, self.dropout2)
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        if self.downsample is None:
            res = x
        else:
            res = self.downsample(x)
        return self.relu(out + res)


class temp_conv_net(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=4, dropout=0):
        super(temp_conv_net, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            if i == 0:
                in_channels = num_inputs
            else:
                in_channels = num_channels[i-1]
            out_channels = num_channels[i]
            layers += [conv_block(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size-1) * dilation_size, dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
    
class pack_temp_conv_net(nn.Module):
    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super(pack_temp_conv_net, self).__init__()
        self.tcn = temp_conv_net(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, inputs):
        y1 = self.tcn(inputs)
        o = self.linear(y1[:, :, -1])
        return torch.squeeze(o)

In [65]:
def train_conv_decoder(train_x, train_y, lr, n_channels, filter_size, dropout, epoch, batchSize, checkPoint, use_cuda = False):
    lossList = []
    # build up data loader, train_x and train_y should be the format for the network
    dataset = dataformat_for_rnn(train_x, train_y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, shuffle=True, sampler=None,
                                             batch_sampler=None)
    net = None
    D_input, D_output = np.size(train_x, 1), np.size(train_y, 1)
    
    net = pack_temp_conv_net(D_input, D_output, n_channels, filter_size, dropout)
    
    if use_cuda:
        net = net.cuda()
    net = net.train()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.MSELoss()

    t1 = time.time()
    lossSum = 0

    print("Data loader num:", len(dataloader))

    for i in range(epoch):
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.type('torch.FloatTensor'), y.type('torch.FloatTensor')
            if use_cuda:
                x = x.cuda()
                y = y.cuda()

            optimizer.zero_grad()
            pred = net.forward(x)
            loss = criterion(pred, y)
            lossSum += loss.item()

            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                if batch_idx % checkPoint == 0 and batch_idx != 0:
                   print("batch: %d , loss is:%f" % (batch_idx, lossSum / checkPoint))
                   lossList.append(lossSum / checkPoint)
                   lossSum = 0

        if i % 10 == 0:
            print("%d epoch is finished!" % (i+1))

    t2 = time.time()
    print("train time:", t2-t1)
    return net


def test_conv_decoder(net, test_x, use_cuda=False):
    net = net.eval()
    with torch.no_grad():
        test_x = torch.from_numpy(test_x).type('torch.FloatTensor')
        if use_cuda:
            test_x = test_x.cuda()
        pred = net(test_x)
        if use_cuda:
            pred = pred.cpu()
    return pred.data.numpy()    

In [61]:
import numpy as np
import fnmatch, os
from xds import lab_data, list_to_nparray, smooth_binned_spikes
from IPython.display import clear_output

base_path = '../lab_data/Greyson_WM_2D/'
file_list = fnmatch.filter(os.listdir(base_path), "*.mat")
file_list = np.sort(file_list)
print(file_list)
#%%
from util import fix_bad_array, find_EMG_idx
bin_size = 0.05
bad_chs = []
bad_chs = [9, 19, 29, 39, 49, 59, 69, 1, 11, 21, 31, 41, 61, 71, 2, 12, 22, 32, 3, 13, 4, 14, 24, 26, 20, 40]
EMG_list = ['EMG_FCR', 'EMG_FDS1', 'EMG_ECR', 'EMG_ECU']

file_name = '20191218_Greyson_WM_002.mat'
dataset = lab_data(base_path, file_name)
dataset.update_bin_data(bin_size)
idx_s, idx_e = fix_bad_array(dataset, bad_chs), find_EMG_idx(dataset, EMG_list)
train_spike, train_emg = dataset.spike_counts[:, idx_s], dataset.EMG#[:, idx_e]

#file_name = "20190815_Greyson_Key_002.mat"
#file_name = "Jango_20140725_IsoHandleHoriz_Utah10ImpEMGs_SN_001.mat"
file_name = '20191218_Greyson_WM_003.mat'
dataset = lab_data(base_path, file_name)
dataset.update_bin_data(0.05)
idx_s, idx_e = fix_bad_array(dataset, bad_chs), find_EMG_idx(dataset, EMG_list)
test_spike, test_emg = dataset.spike_counts[:, idx_s], dataset.EMG#[:, idx_e]
clear_output()

In [62]:
"""
Linear decoder
"""
from wiener_filter import dataset_for_WF_multifile
from wiener_filter import wiener_cascade_train, wiener_cascade_test, wiener_only_train,w_filter_test
from util import vaf

n_lags = 10
train_x_wiener, train_y_wiener = dataset_for_WF_multifile(train_spike, train_emg, n_lags)
print(np.size(train_x_wiener, 0))
test_x_wiener, test_y_wiener = dataset_for_WF_multifile(test_spike, test_emg, n_lags)
print(np.size(test_x_wiener, 0))

H_reg, res_lsq = wiener_cascade_train(train_x_wiener, train_y_wiener, l2 = 0)
test_y_wiener_pred = wiener_cascade_test(test_x_wiener, H_reg, res_lsq)
print('The vaf of Wiener cascade decoder is: %.3f' % vaf(test_y_wiener, test_y_wiener_pred))

H_reg = wiener_only_train(train_x_wiener, train_y_wiener, l2 = 0)
test_y_wiener_pred = w_filter_test(test_x_wiener, H_reg)
print('The vaf of Wiener decoder (linear) is: %.3f' % vaf(test_y_wiener, test_y_wiener_pred))

17984
17984
The vaf of Wiener cascade decoder is: 0.592
The vaf of Wiener decoder (linear) is: 0.554


In [74]:
from dataloader_cage import create_samples_xy_rnn_list
n_lags = 18
train_x, train_y = create_samples_xy_rnn_list(train_spike, train_emg, n_lags, 1)
test_x, test_y = create_samples_xy_rnn_list(test_spike, test_emg, n_lags, 1)
n_channels = [np.size(train_x,1), np.size(train_x,1), np.size(train_x,1), np.size(train_x,1)]
filter_size = 4
dropout = 0.05
TCN_decoder = train_conv_decoder(train_x, train_y, 0.001, n_channels, filter_size, dropout, 
                                 epoch = 40, batchSize = 128, checkPoint = 50, use_cuda = True)
pred_y = test_conv_decoder(TCN_decoder, test_x, True)
print('The vaf of CNN based decoder is: %.3f' % vaf(test_y, pred_y))

17994
17994
Data loader num: 141
batch: 50 , loss is:7.421503
batch: 100 , loss is:4.873514
1 epoch is finished!
batch: 50 , loss is:73.489983
batch: 100 , loss is:2.151832
11 epoch is finished!
batch: 50 , loss is:52.859498
batch: 100 , loss is:1.747653
21 epoch is finished!
batch: 50 , loss is:42.880444
batch: 100 , loss is:1.523186
31 epoch is finished!
train time: 65.67585873603821
The vaf of CNN based decoder is: 0.673


In [78]:
from decoder_rnn import train_RNN_decoder, predict_RNN_decoder
from dataloader_cage import create_samples_xy_rnn_list
n_lags = 18
train_x, train_y = create_samples_xy_rnn_list(train_spike, train_emg, n_lags, 0)
test_x, test_y = create_samples_xy_rnn_list(test_spike, test_emg, n_lags, 0)
RNN_decoder = train_RNN_decoder(train_x, train_y, n_lags, 0.001, 'LSTM', hidden_num = 100, n_layer = 1, 
                      epoch = 40, batchSize = 128, checkPoint = 50, use_cuda = True)
test_y_rnn = predict_RNN_decoder(RNN_decoder, test_x, True)
print(vaf(test_y, test_y_rnn))

17994
17994
LSTM(70, 100, batch_first=True)
Data loader num: 141
batch: 50 , loss is:10.441832
batch: 100 , loss is:7.039847
1 epoch is finished!
batch: 50 , loss is:76.908660
batch: 100 , loss is:2.040569
11 epoch is finished!
batch: 50 , loss is:46.637114
batch: 100 , loss is:1.550000
21 epoch is finished!
batch: 50 , loss is:37.623992
batch: 100 , loss is:1.272443
31 epoch is finished!
train time: 18.412959337234497
0.6386242883360738
