In [None]:
from __future__ import print_function
import numpy as np
import torch
import time
import matplotlib.pyplot as plt

import os
# GPU selection through CUDA environment variables, set before any CUDA work is requested below.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # ask CUDA to number devices by PCI bus id
os.environ["CUDA_VISIBLE_DEVICES"]="1"  # expose only the device with index 1 to this process

from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.metrics import mean_squared_error

# For vanilla VQ, swap the two imports below: uncomment the model_v2 import and comment out the model_v2_EWA import.
# from model.model_v2 import spk_vq_vae_resnet
from model.model_v2_EWA import spk_vq_vae_resnet
from model.utils import Helper

# Device used for the model and every batch: first visible CUDA device, or the CPU when CUDA is unavailable.
gpu = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# %% parameters passed to data loader
import collections  # no longer used in this cell; kept in case other cells reference it
from types import SimpleNamespace

# Plain attribute container handed to Helper. A fresh SimpleNamespace is used
# instead of attaching attributes to the shared `collections.namedtuple`
# factory function, which would leak these settings onto a stdlib object.
args = SimpleNamespace()

# args.train_data_path = './data/wave_clus_noise01_c4.mat'
# args.test_data_path = './data/wave_clus_noise01_c4.mat'

# args.train_data_path = './data/hc1.mat'
# args.test_data_path = './data/hc1.mat'

# args.train_data_path = './data/C_drift.mat'
# args.test_data_path = './data/C_drift.mat'

# args.train_data_path = './data/pac_scream70dB_spks_3d.mat'
# args.test_data_path = './data/pac_scream70dB_spks_3d.mat'

args.train_data_path = '../DeepVAE_data/C_difficult1_spks_c4.mat'
args.test_data_path = '../DeepVAE_data/C_difficult1_spks_c4.mat'

# args.train_data_path = './data/C_easy1_spks_c4.mat'
# args.test_data_path = './data/C_easy1_spks_c4.mat'

# args.train_data_path = '../DeepVAE_data/Neuropixels_spks_c4.mat'
# args.test_data_path = '../DeepVAE_data/Neuropixels_spks_c4.mat'

args.train_ratio = .6
args.test_ratio = .4
args.seed = 1
args.batch_size = 48
args.test_batch_size = 128
args.randperm = True

# val_norm is True only when train and test data come from different files.
# Compare the paths by value: `is not` tests object identity, whose result for
# equal strings depends on interpreter string interning.
args.val_norm = args.train_data_path != args.test_data_path

# %% global parameters
spk_ch = 4
args.spk_ch = spk_ch
spk_dim = 64
log_interval = 50
beta = 0.15
vq_num = 128

# Layout of the parameter list passed to the model (as recorded by the
# original author; indices 3, 10 and 11 are not described here -- see the
# model source for their meaning):
#   org_dim     = param[0]
#   conv1_ch    = param[1]
#   conv2_ch    = param[2]
#   conv1_ker   = param[4]
#   conv2_ker   = param[5]
#   self.vq_dim = param[6]
#   self.vq_num = param[7]
#   cardinality = param[8]
#   vq on chan  = param[9]
param_resnet_v2 = [spk_ch, 256, 16, 1, 3, 1, int(spk_dim/4), vq_num, 16, 4, False, 0.2]

# %% train/test splitting and normalization
helper = Helper(args)
train_loader, test_loader = helper.create_data_loaders()

In [None]:
# %%
# Build the network from the parameter list, then move it to the selected device.
model = spk_vq_vae_resnet(param_resnet_v2)
model = model.to(gpu)

In [None]:
# %%
def loss_function(recon_x, x, commit_loss):
    """Combine the summed reconstruction error with the weighted commitment loss.

    Args:
        recon_x: Reconstructed batch produced by the model.
        x: Original batch; must be broadcast-compatible with ``recon_x``.
        commit_loss: Commitment loss term returned by the model.

    Returns:
        A tuple ``(total_loss, recon_loss)`` where ``recon_loss`` is the squared
        error summed over every element of the batch and ``total_loss`` is
        ``recon_loss + beta * commit_loss`` (``beta`` is the notebook-level weight).
    """
    # reduction='sum' is the supported spelling of the deprecated size_average=False.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    return recon_loss + beta * commit_loss, recon_loss

In [None]:
# %%
# Adam with the AMSGrad variant and L2 weight decay over every model parameter.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    amsgrad=True,
)
#optimizer = optim.SGD(model.parameters(), lr=1e-6, weight_decay=1e-5, momentum=0.9)
#optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-6)
#optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=.9)

# decay_embed, decay_rest = [], []
# for name, param in model.named_parameters():
#     if 'embed' in name:
#         decay_embed.append(param)
#     else:
#         decay_rest.append(param)

# optimizer = optim.SGD([
#     {'params': decay_rest, 'weight_decay':1e-4},
#     {'params': decay_embed, 'weight_decay':1e-5, 'momentum':0.9}
#     ], lr=1e-9, momentum=0.9)

# scheduler = MultiStepLR(optimizer, milestones=[250,350,600], gamma=0.1)

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and report the mean loss.

    Args:
        epoch: Epoch number; used only in the progress messages.

    Returns:
        The accumulated reconstruction loss (each batch's summed loss divided by
        ``spk_dim * spk_ch``) divided by ``len(train_loader.dataset)``.
    """
    model.train()
    per_spike_size = spk_dim * spk_ch
    running_recon = 0
    for step, batch in enumerate(train_loader):
        batch = batch.to(gpu)

        optimizer.zero_grad()
        reconstruction, commit_loss = model(batch)
        total_loss, recon_loss = loss_function(reconstruction, batch, commit_loss)
        # The graph is retained past backward(); model.bwd() runs immediately
        # afterwards (presumably it still needs the graph -- confirm in the model).
        total_loss.backward(retain_graph=True)
        model.bwd()
        optimizer.step()

        running_recon += recon_loss.item() / per_spike_size

        # Only report every `log_interval` batches.
        if step % log_interval != 0:
            continue
        samples_seen = step * len(batch)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
            epoch, samples_seen, len(train_loader.dataset),
            percent_done, recon_loss.item()))

    average_train_loss = running_recon / len(train_loader.dataset)
    print('====> Epoch: {} Average train loss: {:.5f}'.format(
          epoch, average_train_loss))
    return average_train_loss

In [None]:
# %% save model
# Best validation loss seen so far and the training loss recorded with it.
best_val_loss = 10  # v3
cur_train_loss = 1
# best_val_loss = .3965 # v1


def save_model(val_loss, train_loss):
    """Checkpoint the model when ``val_loss`` beats the best value seen so far.

    On improvement the module-level ``best_val_loss`` / ``cur_train_loss`` are
    updated and the model's state dict is written to ``./spk_vq_vae_temp.pt``.
    Otherwise nothing happens.

    Args:
        val_loss: Validation loss of the current epoch.
        train_loss: Training loss of the current epoch.
    """
    global best_val_loss, cur_train_loss
    # Written as `not (a < b)` so a NaN loss is rejected exactly like before.
    if not (val_loss < best_val_loss):
        return
    best_val_loss, cur_train_loss = val_loss, train_loss
    torch.save(model.state_dict(), './spk_vq_vae_temp.pt')

In [None]:
# %%
def test(epoch):
    """Evaluate the model on ``test_loader`` and collect the reconstructions.

    The codebook usage statistics are reset via ``model.embed_reset()`` before
    the pass. Every 10th epoch a bar chart of the normalised
    ``model.embed_freq`` is shown.

    Args:
        epoch: Epoch number; used for the progress message and to decide
            whether the usage histogram is drawn (``epoch % 10 == 0``).

    Returns:
        A tuple ``(average_test_loss, recon_sig, org_sig)`` where
        ``average_test_loss`` is the accumulated reconstruction loss divided by
        ``len(test_loader.dataset)``, and ``recon_sig`` / ``org_sig`` are CPU
        tensors holding all reconstructed and original batches concatenated
        along dimension 0.
    """
    model.eval()
    model.embed_reset()
    test_loss = 0
    # Collect per-batch results and concatenate once at the end. Concatenating
    # inside the loop re-copies everything gathered so far on every batch, and
    # the former random placeholder row needlessly consumed global RNG state.
    recon_batches, org_batches = [], []
    with torch.no_grad():
        for data in test_loader:
            data = data.to(gpu)

            recon_batch, commit_loss = model(data)
            _, recon_loss = loss_function(recon_batch, data, commit_loss)

            recon_batches.append(recon_batch.detach().cpu())
            org_batches.append(data.detach().cpu())

            test_loss += recon_loss.item() / (spk_dim * spk_ch)

    average_test_loss = test_loss / len(test_loader.dataset)
    print('====> Epoch: {} Average test loss: {:.5f}'.format(
          epoch, average_test_loss))

    if epoch % 10 == 0:
        # Relative usage frequency of each codebook entry.
        plt.figure(figsize=(15,5))
        plt.bar(np.arange(vq_num), model.embed_freq / model.embed_freq.sum())
        plt.show()

    if recon_batches:
        recon_sig = torch.cat(recon_batches, dim=0)
        org_sig = torch.cat(org_batches, dim=0)
    else:
        # No batches were produced: return empty tensors with the per-spike layout.
        recon_sig = torch.empty(0, spk_ch, spk_dim)
        org_sig = torch.empty(0, spk_ch, spk_dim)

    return average_test_loss, recon_sig, org_sig

In [None]:
# %% training and validating
# Per-epoch average losses, appended to by the training loop cell.
train_loss_history, test_loss_history = [], []

In [None]:
epochs = 500
start_time = time.time()

for epoch in range(1, epochs + 1):
    #scheduler.step()

    train_loss = train(epoch)
    test_loss, _, _ = test(epoch)
    #save_model(test_loss, train_loss)

    train_loss_history.append(train_loss)
    test_loss_history.append(test_loss)

print("--- %s seconds ---" % (time.time() - start_time))

# Report the best epoch from the recorded histories. The module-level
# best_val_loss / cur_train_loss are only updated by save_model(), which is
# disabled above, so printing them would show their initial values.
if test_loss_history:
    best_epoch_idx = test_loss_history.index(min(test_loss_history))
    print('Minimal train/testing losses are {:.4f} and {:.4f} with index {}\n'
        .format(train_loss_history[best_epoch_idx],
                test_loss_history[best_epoch_idx],
                best_epoch_idx))

# plot train and test loss against epochs
plt.figure(1)
epoch_axis = range(1, len(train_loss_history) + 1)
plt.plot(epoch_axis, train_loss_history, 'bo', label='train')
plt.plot(epoch_axis, test_loss_history, 'b+', label='test')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# plt.figure()
# n_row, n_col = 8, 8
# f, axarr = plt.subplots(n_row, n_col, figsize=(n_col*1.5, n_row*1.5))
# cur_ker = model.embed.weight.data.cpu().numpy()
# for i in range(n_row):
#     for j in range(n_col):
#         axarr[i, j].plot(cur_ker[i*n_row+j], 'r')
#         axarr[i, j].axis('off')
# plt.show()

#model.load_state_dict(torch.load('./spk_vq_vae_temp.pt'))
#torch.save(model.state_dict(), './cae_models/spk_vq_vae_hc1_vq{}_N{}.pt'.format(vq_num, param_resnet_v2[2]))
#torch.save(model.state_dict(), './cae_models/spk_vq_vae_neuropixels_c15r_vq{}_N{}.pt'.format(vq_num, param_resnet_v2[2]))
#torch.save(model.state_dict(), './cae_models/spk_vq_vae_waveclus_vq{}_N{}.pt'.format(vq_num, param_resnet_v2[2]))

# Order the codebook entries by usage count. np.argsort is ascending, so the
# first entries are the least-used ones.
# NOTE(review): if the most-used codewords were intended, reverse embed_idx.
embed_idx = np.argsort(model.embed_freq)
embed_sort = model.dict_val.detach().cpu().numpy()[embed_idx]

# plt.subplots creates its own figure; a preceding plt.figure() would only
# leave an extra empty figure in the output.
n_row, n_col = 1, 8
f, axarr = plt.subplots(n_row, n_col, figsize=(n_col*2, n_row*2))
for ax in axarr:
    ax.axis('off')
# zip stops at whichever is shorter, so the loop follows n_col and tolerates
# a codebook with fewer than n_col entries.
for ax, codeword in zip(axarr, embed_sort):
    ax.plot(codeword, 'r')
plt.show()

In [None]:
# %% spike recon
# Fetch the normalisation statistics from the helper and convert them to
# tensors so they can be applied to the tensors returned by test().
train_mean, train_std, _ = helper.param_for_recon()
train_mean, train_std = torch.from_numpy(train_mean), torch.from_numpy(train_std)
# test() returns (average loss, reconstructed spikes, original spikes), so
# `val_spks` holds the reconstructions and `test_spks` the originals.
# Passing 9 (not a multiple of 10) skips the usage histogram inside test().
_, val_spks, test_spks = test(9)

# val_spks = val_spks.numpy()
# test_spks = test_spks.numpy()

# calculate compression ratio
# test() calls model.embed_reset() before its pass, so embed_freq presumably
# reflects only that evaluation pass -- TODO confirm in the model source.
vq_freq = model.embed_freq / sum(model.embed_freq)
# Drop unused codewords so log2 is never evaluated at zero.
vq_freq = vq_freq[vq_freq != 0]
vq_log2 = np.log2(vq_freq)
# Entropy (in bits) of the empirical codeword distribution: -sum(p * log2(p)).
bits = -sum(np.multiply(vq_freq, vq_log2))
# NOTE(review): 16 looks like the bit width of a raw sample and
# param_resnet_v2[2] like the number of code indices per spike -- confirm.
cr = spk_ch * spk_dim * 16 / (param_resnet_v2[2] * bits)
print('compression ratio is {:.2f} with {:.2f}-bit.'.format(cr, bits))

In [None]:
# Undo the normalisation (scale by train_std, shift by train_mean) and flatten
# both sets into rows of length spk_dim.
recon_spks = val_spks.mul(train_std).add(train_mean).view(-1, spk_dim)
test_spks_v2 = test_spks.mul(train_std).add(train_mean).view(-1, spk_dim)

# Per-row relative error: ||recon - original||_2 / ||original||_2.
recon_err = (recon_spks - test_spks_v2).norm(p=2, dim=1) / test_spks_v2.norm(p=2, dim=1)

print('mean of recon_err is {:.4f}'.format(recon_err.mean()))
print('std of recon_err is {:.4f}'.format(recon_err.std()))

In [None]:
# %% spike visualization for 3-d inputs
# Convert the de-normalised tensors to NumPy arrays for the NumPy-based
# metric and the plots that follow.
recon_spks_new = recon_spks.numpy()
test_spks_new = test_spks_v2.numpy()

def cal_sndr(org_data, recon_data):
    """Compute the signal-to-noise-and-distortion ratio (dB) per row.

    For every row the ratio is ``20 * log10(||org|| / ||org - recon||)`` using
    the Euclidean norm along axis 1.

    Args:
        org_data: Array of original signals, one signal per row.
        recon_data: Array of reconstructed signals, same shape as ``org_data``.

    Returns:
        A tuple ``(mean, std)`` of the per-row values in dB.
    """
    org_norm = np.linalg.norm(org_data, axis=1)
    err_norm = np.linalg.norm(org_data - recon_data, axis=1)
    # Evaluate the per-row ratio once and reuse it for both statistics.
    sndr_db = 20 * np.log10(org_norm / err_norm)
    return np.mean(sndr_db), np.std(sndr_db)

# Evaluate the metric over all rows: originals first, reconstructions second.
cur_sndr, sndr_std = cal_sndr(test_spks_new, recon_spks_new)
print('SNDR is {:.4f} with std {:.4f}'.format(cur_sndr, sndr_std))

In [None]:
# Random ordering of the spike rows; the first n_row*n_col of them are drawn.
rand_val_idx = np.random.permutation(len(recon_spks_new))

n_row, n_col = 3, 8
shown_idx = rand_val_idx[:n_row*n_col]
spks_to_show = test_spks_new[shown_idx]
# Shared y-range taken from the original spikes that are displayed.
ymax, ymin = np.amax(spks_to_show), np.amin(spks_to_show)
# plt.subplots creates its own figure; a preceding plt.figure() would only
# leave an extra empty figure in the output.
f, axarr = plt.subplots(n_row, n_col, figsize=(n_col*3, n_row*3))
for ax in axarr.flat:
    ax.axis('off')
# axarr.flat walks the grid row by row, matching the i*n_col+j ordering; zip
# stops early when fewer than n_row*n_col spikes are available.
for ax, spk_idx in zip(axarr.flat, shown_idx):
    ax.plot(recon_spks_new[spk_idx], 'r')  # reconstruction
    ax.plot(test_spks_new[spk_idx], 'b')   # original
    ax.set_ylim([ymin*1.1, ymax*1.1])
#plt.show()
#plt.savefig('waveclus_spks.eps', format='eps', dpi=600)