In [None]:
from __future__ import print_function

import collections
import math
import time
import types

import line_profiler
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from sklearn.metrics import mean_squared_error
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, BatchSampler

from model.model_v2 import spk_vq_vae_resnet
from model.utils import SpikeDataset

# Device used for the model and every batch: first CUDA device if available, otherwise the CPU.
gpu = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## Parameter Configuration

In [None]:
# %% global parameters
spk_ch = 4
spk_dim = 64  # for Wave_Clus
# spk_dim = 48 # for HC1 and Neuropixels
log_interval = 10       # training progress is printed every `log_interval` batches
beta = 0.15             # weight of the commitment loss inside loss_function
vq_num = 128            # number of VQ codewords
batch_size = 48         # mini-batch size used by train()
test_batch_size = 1000  # mini-batch size used by test()

"""
org_dim     = param[0]
conv1_ch    = param[1]
conv2_ch    = param[2]
conv0_ker   = param[3]
conv1_ker   = param[4]
conv2_ker   = param[5]
self.vq_dim = param[6]
self.vq_num = param[7]
cardinality = param[8]
dropRate    = param[9]
"""
# Positional configuration consumed by spk_vq_vae_resnet; the slot order is the one listed above.
param_resnet_v2 = [
    spk_ch,        # param[0]
    256,           # param[1]
    16,            # param[2]
    1,             # param[3]
    3,             # param[4]
    1,             # param[5]
    spk_dim // 4,  # param[6]
    vq_num,        # param[7]
    32,            # param[8]
    0.2,           # param[9]
]

## Preparing data loaders

In [None]:
noise_file = './data/noisy_spks.mat'
clean_file = './data/clean_spks.mat'


def make_spike_dataset(data_path, train_mode, train_portion=.5):
    """Build a SpikeDataset from its own, freshly created argument namespace.

    Args:
        data_path: path of the .mat file, forwarded as ``args.data_path``.
        train_mode: forwarded as ``args.train_mode``
            (True for the training portion, False for the test portion).
        train_portion: forwarded as ``args.train_portion``.

    Returns:
        The SpikeDataset constructed from those arguments.
    """
    # One namespace per dataset: nothing is shared between datasets, and no
    # attributes are attached to the stdlib `collections.namedtuple` function.
    args = types.SimpleNamespace(
        data_path=data_path,
        train_portion=train_portion,
        train_mode=train_mode,
    )
    return SpikeDataset(args)


# training set purposely distorted to train denoising autoencoder
train_noise = make_spike_dataset(noise_file, train_mode=True)

# clean dataset for training
train_clean = make_spike_dataset(clean_file, train_mode=True)

# noisy dataset for testing
test_noise = make_spike_dataset(noise_file, train_mode=False)

# clean dataset for testing
test_clean = make_spike_dataset(clean_file, train_mode=False)

# number of mini-batches per training epoch (the last one may be partial)
batch_cnt = int(math.ceil(len(train_noise) / batch_size))

# normalization: statistics come from the clean training set only and are
# then applied to all four datasets
d_mean, d_std = train_clean.get_normalizer()

for dataset in (train_clean, train_noise, test_clean, test_noise):
    dataset.apply_norm(d_mean, d_std)

## Model definition

In [None]:
# %% create model
# Build the network from the positional configuration, then move it to the selected device.
model = spk_vq_vae_resnet(param_resnet_v2)
model = model.to(gpu)

# %% loss and optimization function
def loss_function(recon_x, x, commit_loss, vq_loss):
    """Combine the reconstruction, commitment and VQ terms into one objective.

    Args:
        recon_x: reconstruction produced by the model.
        x: target the reconstruction is compared against.
        commit_loss: commitment term returned by the model (scaled by the global ``beta``).
        vq_loss: vector-quantisation term returned by the model.

    Returns:
        A pair ``(total_loss, recon_loss)`` where ``recon_loss`` is the summed
        squared error between ``recon_x`` and ``x`` and
        ``total_loss = recon_loss + beta * commit_loss + vq_loss``.
    """
    squared_error = F.mse_loss(recon_x, x, reduction='sum')
    total = squared_error + beta * commit_loss + vq_loss
    return total, squared_error

# Adam (AMSGrad variant) with weight decay over all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    amsgrad=True,
)

In [None]:
def train(epoch):
    """Run one training epoch over randomly ordered mini-batches.

    Each batch feeds ``train_noise`` samples to the model and scores the
    reconstruction against the ``train_clean`` samples at the same indices.

    Args:
        epoch: epoch number; only used in the progress messages.

    Returns:
        The accumulated reconstruction loss, divided by ``spk_dim * spk_ch``
        per batch and by ``len(train_noise)`` at the end.
    """
    model.train()
    train_loss = 0
    batch_sampler = BatchSampler(RandomSampler(range(len(train_noise))), batch_size=batch_size, drop_last=False)
    # Derived from the sampler itself so the progress percentage cannot go
    # stale when batch_size is changed after the data cell has run.
    num_batches = len(batch_sampler)
    for batch_idx, ind in enumerate(batch_sampler):
        # The same index list addresses both datasets.
        # NOTE(review): assumes train_noise and train_clean are index-aligned - confirm.
        in_data = train_noise[ind].to(gpu)
        out_data = train_clean[ind].to(gpu)

        optimizer.zero_grad()
        recon_batch, commit_loss, vq_loss = model(in_data)
        loss, recon_loss = loss_function(recon_batch, out_data, commit_loss, vq_loss)
        # NOTE(review): retain_graph=True is presumably needed by model.bwd() below - confirm.
        loss.backward(retain_graph=True)
        model.bwd()
        optimizer.step()

        train_loss += recon_loss.item() / (spk_dim * spk_ch)
        if batch_idx % log_interval == 0:
            # batch_idx * batch_size is the number of samples consumed before
            # this batch; len(in_data) would under-count on a short final batch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch, batch_idx * batch_size, len(train_noise),
                100. * batch_idx / num_batches, recon_loss.item()))

    average_train_loss = train_loss / len(train_noise)
    print('====> Epoch: {} Average train loss: {:.5f}'.format(
          epoch, average_train_loss))
    return average_train_loss

In [None]:
# model logging: keep the weights of the epoch with the lowest validation loss
# Start from +inf so the first finite validation loss always produces a
# checkpoint, whatever the scale of the loss.
best_val_loss = float('inf')
# Training loss recorded at the best epoch; NaN until a checkpoint is written.
cur_train_loss = float('nan')
def save_model(val_loss, train_loss):
    """Checkpoint the model when ``val_loss`` improves on the best seen so far.

    Args:
        val_loss: validation loss of the current epoch.
        train_loss: training loss of the current epoch, remembered alongside
            the best validation loss.

    Side effects:
        On improvement, updates the globals ``best_val_loss`` and
        ``cur_train_loss`` and writes ``model.state_dict()`` to
        ``./spk_vq_vae_temp.pt``.
    """
    global best_val_loss, cur_train_loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        cur_train_loss = train_loss
        torch.save(model.state_dict(), './spk_vq_vae_temp.pt')

In [None]:
def test(epoch, test_mode=True):
    """Evaluate the model on the test datasets and collect its reconstructions.

    Args:
        epoch: epoch number; used in the printed message, and the codeword
            usage histogram is drawn when it is a multiple of 10.
        test_mode: when True the model is switched to eval mode first.

    Returns:
        A tuple ``(average_test_loss, recon_sig, org_sig)``:
        the accumulated reconstruction loss divided by ``spk_dim * spk_ch``
        per batch and by ``len(test_noise)``, the model reconstructions, and
        the matching ``test_clean`` samples. Both tensors are on the CPU and
        row-aligned, in the (random) order the sampler produced.
    """
    if test_mode:
        model.eval()
    model.embed_reset()
    test_loss = 0
    # Collect per-batch results and concatenate once at the end; growing a
    # tensor with torch.cat on every iteration copies all previous rows again.
    recon_batches = []
    org_batches = []
    with torch.no_grad():
        batch_sampler = BatchSampler(RandomSampler(range(len(test_noise))), batch_size=test_batch_size, drop_last=False)
        for ind in batch_sampler:
            in_data = test_noise[ind].to(gpu)
            out_data = test_clean[ind].to(gpu)

            recon_batch, commit_loss, vq_loss = model(in_data)
            _, recon_loss = loss_function(recon_batch, out_data, commit_loss, vq_loss)

            recon_batches.append(recon_batch.detach().cpu())
            org_batches.append(out_data.detach().cpu())

            test_loss += recon_loss.item() / (spk_dim * spk_ch)

        average_test_loss = test_loss / len(test_noise)
        print('====> Epoch: {} Average test loss: {:.5f}'.format(
              epoch, average_test_loss))

    if epoch % 10 == 0:
        # Histogram of how often each codeword was selected during this pass.
        plt.figure(figsize=(7,5))
        plt.bar(np.arange(vq_num), model.embed_freq / model.embed_freq.sum())
        plt.ylabel('Probability of Activation', fontsize=16)
        plt.xlabel('Index of codewords', fontsize=16)
        plt.show()

    recon_sig = torch.cat(recon_batches, dim=0)
    org_sig = torch.cat(org_batches, dim=0)
    return average_test_loss, recon_sig, org_sig

## Training

In [None]:
# Per-epoch loss curves, filled in by the loop below.
train_loss_history, test_loss_history = [], []

epochs = 500
start_time = time.time()

for epoch in range(1, epochs + 1):
    # One optimisation pass, one evaluation pass, then checkpoint if the
    # evaluation loss improved.
    train_loss = train(epoch)
    test_loss, _, _ = test(epoch)
    save_model(test_loss, train_loss)

    train_loss_history.append(train_loss)
    test_loss_history.append(test_loss)

elapsed_seconds = time.time() - start_time
print("--- %s seconds ---" % elapsed_seconds)

# 0-based position of the lowest test loss in the history
best_epoch_index = test_loss_history.index(min(test_loss_history))
print('Minimal train/testing losses are {:.4f} and {:.4f} with index {}\n'
    .format(cur_train_loss, best_val_loss, best_epoch_index))

# plot train and test loss history over epochs
plt.figure(1)
epoch_axis = range(1, len(train_loss_history) + 1)
for history, marker in ((train_loss_history, 'bo'), (test_loss_history, 'b+')):
    plt.plot(epoch_axis, history, marker)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

## Result evaluation

### a. Visualization of the most frequently used VQ vectors

In [None]:
# select the best performing model
model.load_state_dict(torch.load('./spk_vq_vae_temp.pt', map_location=gpu))

# np.argsort sorts ascending, so reverse it to put the most frequently
# activated codewords first.
# NOTE(review): model.embed_freq presumably still holds the counts from the
# last test() call rather than from the reloaded weights - confirm.
embed_idx = np.argsort(np.asarray(model.embed_freq))[::-1]
embed_sort = model.embed.weight.data.cpu().numpy()[embed_idx]

# Visualizing activation pattern of VQ codes on testing dataset (the 8 most activated)
n_row, n_col = 1, 8
f, axarr = plt.subplots(n_row, n_col, figsize=(n_col*2, n_row*2))
for ax, codeword in zip(axarr, embed_sort):
    ax.plot(codeword, 'r')
    ax.axis('off')
plt.show()

### b. Compression ratio

In [None]:
# %% spike recon
# Normalizer statistics as tensors, for de-normalizing the spikes later on.
train_mean = torch.from_numpy(d_mean)
train_std = torch.from_numpy(d_std)
# test() returns (loss, reconstructions, clean targets); epoch=10 also
# triggers its codeword-usage histogram.
_, val_spks, test_spks = test(10)

# calculate compression ratio
# Empirical usage probability of each codeword; unused ones are dropped so
# log2 never sees a zero.
usage_prob = model.embed_freq / sum(model.embed_freq)
vq_freq = usage_prob[usage_prob != 0]
vq_log2 = np.log2(vq_freq)
# Entropy of the codeword usage distribution, in bits.
bits = -sum(np.multiply(vq_freq, vq_log2))
# NOTE(review): 16 is presumably the bit width of one raw sample, and
# param_resnet_v2[2] the number of codeword indices per spike - confirm.
raw_bits = spk_ch * spk_dim * 16
cr = raw_bits / (param_resnet_v2[2] * bits)
print('compression ratio is {:.2f} with {:.2f}-bit.'.format(cr, bits))

### c. Relative reconstruction error

In [None]:
# Rescale with the training statistics (presumably the inverse of apply_norm)
# and flatten so that each row holds spk_dim samples.
recon_spks = (val_spks * train_std + train_mean).view(-1, spk_dim)
test_spks_v2 = (test_spks * train_std + train_mean).view(-1, spk_dim)

# Per-row relative error: ||recon - clean||_2 / ||clean||_2
residual_norm = torch.norm(recon_spks - test_spks_v2, p=2, dim=1)
reference_norm = torch.norm(test_spks_v2, p=2, dim=1)
recon_err = residual_norm / reference_norm

print('mean of recon_err is {:.4f}'.format(recon_err.mean()))
print('std of recon_err is {:.4f}'.format(recon_err.std()))

### d. SNDR of reconstructed spikes

In [None]:
# NumPy copies of the de-normalized reconstructions and clean spikes for the metrics below.
recon_spks_new, test_spks_new = recon_spks.numpy(), test_spks_v2.numpy()

def cal_sndr(org_data, recon_data):
    """Return the mean and standard deviation of the per-row SNDR in dB.

    For every row the ratio is ``20 * log10(||org|| / ||org - recon||)``,
    with the norms taken along axis 1.

    Args:
        org_data: 2-D array of reference signals, one per row.
        recon_data: 2-D array of reconstructions with the same shape.

    Returns:
        ``(mean_db, std_db)`` over all rows.
    """
    signal_norm = np.linalg.norm(org_data, axis=1)
    distortion_norm = np.linalg.norm(org_data - recon_data, axis=1)
    sndr_db = 20 * np.log10(signal_norm / distortion_norm)
    return np.mean(sndr_db), np.std(sndr_db)

# The clean spikes are the reference; the reconstructions are what is being scored.
sndr_stats = cal_sndr(test_spks_new, recon_spks_new)
cur_sndr, sndr_std = sndr_stats
print('SNDR is {:.4f} with std {:.4f}'.format(*sndr_stats))

### e. Visualization of reconstructed spikes chosen at random

In [None]:
rand_val_idx = np.random.permutation(len(recon_spks_new))

n_row, n_col = 3, 8
# Rows to display: the first n_row*n_col entries of the random permutation
# (fewer when there are not enough rows).
shown_idx = rand_val_idx[:n_row*n_col]
spks_to_show = test_spks_new[shown_idx]
# Shared y-range, taken from the clean spikes that are displayed.
ymax, ymin = np.amax(spks_to_show), np.amin(spks_to_show)
# plt.subplots creates its own figure; a separate plt.figure() call would
# only leave an empty extra figure in the output.
f, axarr = plt.subplots(n_row, n_col, figsize=(n_col*3, n_row*3))
# axarr.flat walks the grid row by row, i.e. in the same order as i*n_col + j.
for ax, spk_idx in zip(axarr.flat, shown_idx):
    ax.plot(recon_spks_new[spk_idx], 'r')  # reconstruction
    ax.plot(test_spks_new[spk_idx], 'b')   # clean reference
    ax.set_ylim([ymin*1.1, ymax*1.1])
for ax in axarr.flat:
    ax.axis('off')
plt.show()