In [None]:
import os
import numpy as np
import mne
import osl
import scipy
import yaml
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, MaxAbsScaler, RobustScaler, StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import MiniBatchKMeans
import seaborn as sns
import numpy as np
import pickle
import matplotlib.pyplot as plt

In [None]:
# Default count used by `clip`: each row is clipped symmetrically at its
# clip_num-th largest value.
clip_num: int = 50

In [None]:
def encode(x, pca, robust, maxabs, robust2, mu=255, verbose=True):
    """Compress a 2-D signal array into a flat array of mu-law codes.

    Pipeline:
    1. Robust-scale the features.
    2. Project with PCA.
    3. Clip each component with `clip`.
    4. Robust-scale, then max-abs-scale the components.
    5. Mu-law compand and quantise into integers.

    Every transformer is (re)fitted on `x` here via `fit_transform`. The
    fitted objects must therefore be kept and handed to `decode`.

    Parameters
    ----------
    x : ndarray, 2-D
        Rows are features (e.g. MEG channels), columns are samples. It is
        transposed before fitting, so the scalers and PCA see samples as rows.
    pca, robust, maxabs, robust2 :
        Unfitted sklearn-style transformers, fitted in place.
    mu : int
        Mu-law parameter. Must match the value later given to `decode`.
    verbose : bool
        Print per-component mean/std after PCA and after scaling. This is the
        debug output the function always emitted; the default keeps it.

    Returns
    -------
    ndarray of int
        The (components, samples) array flattened component-major.
        Codes lie in 0..mu-1, plus mu where a scaled value is exactly +1.
    """
    z = robust.fit_transform(x.T)
    z = pca.fit_transform(z)
    if verbose:
        print(np.mean(z, axis=0))
        print(np.std(z, axis=0))

    # `clip` works row-wise and modifies its argument in place, so clip the
    # components as rows of the transposed view.
    z = clip(z.T).T
    z = robust2.fit_transform(z)
    z = maxabs.fit_transform(z).T  # now (components, samples), values in [-1, 1]
    if verbose:
        print('-------------------')
        print(np.mean(z, axis=1))
        print(np.std(z, axis=1))

    flat = z.reshape(-1)
    # mu-law companding: keeps sign, compresses magnitude into [0, 1]
    companded = np.sign(flat) * np.log(1 + mu * np.abs(flat)) / np.log(1 + mu)

    # mu+1 edges spaced 2/mu apart; digitize is 1-based, hence the -1
    bins = np.linspace(-1, 1, mu + 1)
    return np.digitize(companded, bins) - 1

In [None]:
def decode(x, pca, robust, maxabs, robust2, mu=255, shape=(306, 2857000)):
    """Invert `encode`: map flat mu-law codes back to the signal space.

    Parameters
    ----------
    x : array-like of int
        Flat code array as returned by `encode`.
    pca, robust, maxabs, robust2 :
        The transformers `encode` fitted. Only their `inverse_transform`
        methods are used, applied in the reverse order of the forward pass.
    mu : int
        Mu-law parameter. Must match the value used in `encode`.
    shape : tuple
        The (components, samples) layout that `encode` flattened. The default
        is hard-coded for one recording; pass the real shape for other data.

    Returns
    -------
    ndarray
        Reconstruction transposed back to the orientation `encode` received
        (rows are features, columns are samples).
    """
    codes = np.asarray(x)
    # `encode` quantises with np.digitize on mu+1 edges spaced 2/mu apart, so
    # code k covers [-1 + 2k/mu, -1 + 2(k+1)/mu). Reconstruct at that bin's
    # centre. The extra top code (k == mu, produced only for an input of
    # exactly +1) would land past 1, so clip back onto the companded range.
    y = np.clip(-1 + (2 * codes + 1) / mu, -1, 1)
    # inverse mu-law expansion
    y = np.sign(y) * ((mu + 1) ** np.abs(y) - 1) / mu

    y = y.reshape(shape)
    y = maxabs.inverse_transform(y.T)
    y = robust2.inverse_transform(y)
    y = pca.inverse_transform(y)
    y = robust.inverse_transform(y)

    return y.T

In [None]:
def clip(x, num=None):
    """Clip each row of a 2-D array symmetrically at its num-th largest value.

    For every row i, the threshold is |t_i|, where t_i is the num-th largest
    value in that row. The row is then clipped to [-|t_i|, |t_i|].

    Taking the absolute value matters for rows whose num-th largest value is
    negative. Passing a negative t_i straight to np.clip gives a_min > a_max,
    which overwrites the whole row with t_i.

    The array is modified IN PLACE and also returned. A caller that needs the
    unclipped data must pass a copy.

    Parameters
    ----------
    x : ndarray, 2-D
        Array to clip row-wise. Views (e.g. a transpose) are fine; the
        underlying data is updated.
    num : int, optional
        Which largest value to use (1 = maximum). Defaults to the module-level
        `clip_num`.

    Returns
    -------
    ndarray
        `x` itself, after clipping.

    Raises
    ------
    ValueError
        If `num` is not in [1, number of columns].
    """
    if num is None:
        num = clip_num
    n_cols = x.shape[1]
    if not 1 <= num <= n_cols:
        raise ValueError(f"clip count must be in [1, {n_cols}], got {num}")

    k = n_cols - num  # index of the num-th largest value in ascending order
    for i in range(x.shape[0]):
        row = x[i, :]  # view into x, so out=row writes back into x
        # np.partition puts the value sorted order would have at index k into
        # position k, in linear time. Only one row is copied at a time rather
        # than the whole array.
        thresh = abs(np.partition(row, k)[k])
        np.clip(row, -thresh, thresh, out=row)

    return x

In [None]:
# Directory holding the preprocessed recording for subject 2, session 4.
outdir = os.path.join('rich_data', 'subj2', 'sess4', 'oslpy_deb')

In [None]:
# Load the preprocessed recording fully into memory (preload=True).
# The filename suggests tSSS-processed data -- confirm against the pipeline.
path = os.path.join(outdir, 'task_part1_4_tsss_preproc_raw.fif')
raw = mne.io.read_raw_fif(path, preload=True)

In [None]:
# MEG channels only; MNE returns an (n_channels, n_times) array.
raw_data = raw.get_data(picks='meg')

In [None]:
# NOTE(review): `clip` modifies its argument in place and returns it, so after
# this cell `data` and `raw_data` are the SAME array. Every later use of
# `raw_data` (including the encode cell) sees the clipped values. Use
# `clip(raw_data.copy())` if the unclipped signal is needed.
data = clip(raw_data)

In [None]:
# Transformers that `encode` fits and `decode` inverts.
# PCA(306) keeps as many components as there are input features, presumably
# the 306 MEG channels -- TODO confirm against raw_data.shape[0].
pca = PCA(306)
robust_scaler = RobustScaler()
robust_scaler2 = RobustScaler()
maxabs = MaxAbsScaler()

In [None]:
%matplotlib widget
plt.hist(pca_data.reshape(-1), bins=256)
#plt.xlim(-1, 1)
plt.show()

In [None]:
plt.hist(raw_data[100], bins=256)
#plt.xlim(-1, 1)
plt.show()

In [None]:
encoded = encode(raw_data, pca, robust_scaler, maxabs, robust_scaler2)

In [None]:
%matplotlib widget
plt.plot(encoded[50000:55000])

In [None]:
encoded = encoded.reshape((306, 2857000))

In [None]:
%matplotlib widget
plt.hist(encoded[1], bins=256)
plt.show()

In [None]:
%matplotlib widget
plt.hist(encoded[-1], bins=256)
plt.show()

In [None]:
decoded = decode(encoded.reshape(-1), pca, robust_scaler, maxabs, robust_scaler2)

In [None]:
plt.hist(decoded, bins=256)
plt.show()

In [None]:
errors = []
for i in range(306):
    data_mean = np.mean(np.abs(raw_data[i]))
    diff = np.mean(np.abs(raw_data[i] - decoded[i]))
    
    errors.append(diff/data_mean)

In [None]:
%matplotlib widget
plt.hist(errors, bins=306)
plt.show()

In [None]:
np.mean(np.abs(data - decoded))

In [None]:
np.mean(np.abs(data))

In [None]:
diff = data - decoded

In [None]:
%matplotlib widget
plt.plot(data[10, 50000:52000], 'r', decoded[10, 50000:52000], 'b', linewidth=0.2)

In [None]:
%matplotlib widget
plt.hist(diff.reshape(-1), bins=10000)
plt.xlim((-2e-12, 2e-12))

In [None]:
errors = np.arange(256) / 256

In [None]:
np.mean(errors**2)

In [None]:
path = os.path.join('..', 'results', 'cichy_epoched', 'subj1', 'cont_quantized', 'wavenetfulltest', 'preds.npy')
preds = np.load(path)

path = os.path.join('..', 'results', 'cichy_epoched', 'subj1', 'cont_quantized', 'wavenetfulltest', 'targets.npy')
targets = np.load(path)

In [None]:
preds.shape

In [None]:
preds = preds.transpose(1, 0, 2).reshape(preds.shape[1], -1)
targets = targets.transpose(1, 0, 2).reshape(targets.shape[1], -1)

In [None]:
np.mean((targets[:, :, 1:].reshape(-1) - targets[:, :, :-1].reshape(-1))**2)

In [None]:
np.mean((preds[:, :, 1:].reshape(-1) - targets[:, :, :-1].reshape(-1))**2)

In [None]:
mse

In [None]:
%matplotlib widget
plt.plot(preds[0, 0, 1:60], linewidth=0.5)
plt.plot(targets[0, 0, :60], linewidth=0.5)

In [1]:
np.min(np.abs(preds.reshape(-1)))

NameError: name 'np' is not defined

In [None]:
sum(np.abs(targets.reshape(-1))<0.075)/np.prod(targets.shape)

In [None]:
%matplotlib widget
plt.scatter(np.arange(len(np.unique(preds))), np.unique(preds), s=1)
#plt.plot(np.unique(targets))

In [None]:
%matplotlib widget
plt.hist(targets.reshape(-1), bins=1000)
plt.show()

In [None]:
path = os.path.join('..', 'results', 'cichy_epoched', 'subj1', 'cont_quantized', 'dumb_conv', 'inputs.npy')
inputs = np.load(path)

path = os.path.join('..', 'results', 'cichy_epoched', 'subj1', 'cont_quantized', 'dumb_conv', 'targets_full.npy')
targets = np.load(path)

In [None]:
inputs = inputs.transpose(1, 0, 2).reshape(inputs.shape[1], -1)
targets = targets.transpose(1, 0, 2).reshape(targets.shape[1], -1)

In [None]:
%matplotlib widget
plt.plot(inputs[0, 1:10001], linewidth=0.5)
plt.plot(targets[0, :10000], linewidth=0.5)

In [None]:
%matplotlib widget
plt.scatter(np.arange(len(np.unique(targets))), np.unique(targets), s=1)

In [None]:
%matplotlib widget
plt.hist(inputs.reshape(-1), bins=400)
plt.show()