In [None]:
# Reload edited local modules (utils, vca, model.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from scipy.optimize import nnls
from sklearn.decomposition import FastICA
from matplotlib.colors import ListedColormap
from sklearn.decomposition import NMF
%matplotlib inline
from utils import *
from vca import *
from model.nmf import NMFGD
from model.nnica import *

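## Simulation Training Data

Wavelength sets used for the simulations (nm): a 6-wavelength set and a 10-wavelength set. The cell below uses the 10-wavelength set.

```python
wave_list = [750, 760, 800, 850, 900, 925]                      # 6 wavelengths
wave_list = [750, 760, 800, 850, 900, 910, 920, 930, 940, 950]  # 10 wavelengths
```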

In [None]:
wave_list = [750, 760, 800, 850, 900, 910, 920, 930, 940, 950]
legend = ["HbO2", "Hb", "Cholesterol", "Background"]
wave_abs = np.load('./data/hbo2hbchpr_57.npy')

f = loadmat("./data/3DPlot.mat")
X, Y = f['x'], f['y']

# Map each wavelength on the 700..980 nm / 5 nm grid (57 points) to (row index, spectra row of wave_abs).
abs_coeff = {wave: (idx, wave_abs[idx]) for idx, wave in enumerate(np.arange(700, 981, 5))}
# One row per selected wavelength, one column per chromophore in `legend`.
coeffs = np.vstack([abs_coeff[wave][1] for wave in wave_list])
# Each chromophore column divided by its own maximum over the selected wavelengths.
normcoeffs = coeffs / np.array([max(column) for column in coeffs.T])
weights_plot(array = coeffs[:, 0:], wave_list = wave_list, legend = legend, figsize = (9, 6), xticks = wave_list)

In [None]:
# Absorption spectra of the first three chromophores (HbO2, Hb, Cholesterol) over the full 700-980 nm grid.
x = np.arange(700, 981, 5)
y = np.vstack([abs_coeff[wave][1] for wave in x])[:, 0:3]
plt.figure(figsize = (10, 6))
plt.plot(x, y)
plt.ylim([1e-4, 10])
plt.yscale('log')
plt.xlabel('Wavelength (nm)')
plt.ylabel('Absorption coefficient')
# Label the three curves; an empty-string argument would produce no legend entries at all.
plt.legend(legend[:3])
# One tick every 20 nm — a tick per nanometre (281 ticks) overlaps into an unreadable axis.
plt.xticks(np.arange(700, 981, 20))
plt.show()

## Linear Unmixing

In [None]:
# Depth-25 simulated photoacoustic images, one per wavelength, stacked as (wavelengths, H, W).
depth = 25
hbhbo2fat = np.copy(coeffs)[:, 0:3]   # reference spectra for HbO2, Hb, Cholesterol
sim_data = np.array([
    np.array(loadmat(f"./data/hb_hbo2_fat_29_{depth}/PA_Image_{wave}.mat")['Image_PA'])
    for wave in wave_list
])

In [None]:
# Linear unmixing of the normalized simulation cube against the HbO2 / Hb / Cholesterol spectra.
unmixed = run_linear_unmixing(
    normalize(sim_data.copy()),
    hbhbo2fat,
)
plot_comps_2d(
    unmixed, wave_list, hbhbo2fat,
    clim = [None] * 3,
    xticks = wave_list,
)

In [None]:
# 3-D view of the linear-unmixing maps; grid coordinates from 3DPlot.mat scaled by 1000 (presumably m -> mm).
plot_3d_multiple(
    Y * 1000, X * 1000, unmixed,
    title = legend[:3], cmap = 'jet',
    clim = [None] * 3, order = [0, 1, 2],
)

In [None]:
# Raw simulated image at every wavelength on a two-row grid with a shared [0, 0.005] colour range.
n_waves = sim_data.shape[0]
plt.figure(figsize = (20, 12))
for i in range(n_waves):
    # Ceil division so an odd wavelength count still fits in two rows.
    plt.subplot(2, (n_waves + 1) // 2, i + 1)
    # Set the colour limits on the image itself so the colorbar is built with them.
    plt.imshow(sim_data[i], cmap = "jet", vmin = 0, vmax = 0.005)
    plt.title(f'{wave_list[i]} nm')
    plt.colorbar()
plt.show()

## NMF

In [None]:
# sklearn NMF with 3 components on a (pixels x wavelengths) matrix built from the normalized cube.
nmf_model = NMF(n_components = 3)
nmf_input = normalize(sim_data.copy()).reshape(len(wave_list), -1).T
nmf_model.fit(nmf_input)

In [None]:
# Per-pixel NMF activations reshaped back onto the image grid.
# fit_transform refits the model, so this cell also works without the previous cell;
# with sklearn's default (non-random) init for this size it should reproduce that fit.
nmf_input = normalize(sim_data.copy()).reshape(len(wave_list), -1).T
# Image shape taken from the data rather than hard-coded; the last axis is the component count.
nmf_comps = nmf_model.fit_transform(nmf_input).reshape((*sim_data.shape[1:], -1))
# NMF: X ~ W @ H with H = components_ (components x wavelengths), so the spectra are components_.T.
plot_comps_2d(nmf_comps, wave_list, nmf_model.components_.T, order = [0, 1, 2], clim = [None]*3)

In [None]:
# Shared [0, max] colour range across all three NMF maps.
nmf_clim = [0, np.max(nmf_comps)]
plot_3d_multiple(
    Y * 1000, X * 1000, nmf_comps,
    title = legend[:3], cmap = 'jet',
    clim = [nmf_clim] * 3, order = [0, 1, 2],
)

In [None]:
# Recovered NMF spectra (one column per component) against wavelength.
weights_plot(
    array = nmf_model.components_.T,
    wave_list = wave_list,
    legend = legend,
    figsize = (9, 6),
    xticks = wave_list,
)

### NMF Gradient Descent

In [None]:
# Gradient-descent NMF on a (wavelengths x pixels) matrix of the normalized cube.
# NOTE(review): randominit=True — results depend on whatever RNG NMFGD uses; confirm it is seeded.
nmf_model_test = NMFGD(n_components = 3, randominit = True)
nmfgd_input = normalize(sim_data.copy()).reshape((len(wave_list), -1))
nmf_model_test.fit(nmfgd_input, maxiter = 1500)

In [None]:
# NMFGD activations (H transposed to pixels x components) mapped back onto the simulation image grid,
# using the cube's own H x W instead of a hard-coded (396, 101).
plot_comps_2d(nmf_model_test.H.T.reshape((*sim_data.shape[1:], -1)), wave_list, nmf_model_test.W, order = [2, 0, 1], clim = [None]*3)

## VCA

In [None]:
# VCA with 3 endmembers on the raw (not normalized) simulation cube, flattened to (wavelengths, pixels).
# NOTE(review): if vca() draws random projections, seed NumPy first — the hard-coded display
# order [2, 0, 1] only matches the endmembers of one particular run.
ae, ind, yp = vca(sim_data.copy().reshape(len(wave_list), -1), 3)
# Abundances via the pseudo-inverse of the endmember matrix, mapped back to (H, W, endmembers)
# using the cube's own image shape.
vca_comps = np.matmul(np.linalg.pinv(ae), yp).reshape((-1, *sim_data.shape[1:])).transpose((1, 2, 0))
plot_comps_2d(vca_comps, wave_list, ae, order = [2, 0, 1], clim = [None, None, None])

In [None]:
# Display order aligning the VCA endmembers with the HbO2 / Hb / Cholesterol labels.
vca_order = [2, 0, 1]
plot_3d_multiple(
    Y * 1000, X * 1000, vca_comps,
    title = legend[:3], cmap = 'jet',
    clim = [None] * 3, order = vca_order,
)

In [None]:
# VCA endmember spectra, reordered to match the legend.
weights_plot(
    array = ae[:, vca_order],
    wave_list = wave_list,
    legend = legend,
    figsize = (9, 6),
    xticks = wave_list,
)

## FastICA

```python
from scipy import linalg
XW = sim_data.reshape((10, 396*101))
X_mean = XW.mean(axis = -1)
XW -= X_mean[:, np.newaxis]
U, D = linalg.svd(XW, full_matrices = False, check_finite = False)[:2]
U *= np.sign(U[0])
K = (U / D).T[:3]
XW = np.dot(K, XW)
XW *= np.sqrt(396*101)
XW = XW.reshape((3, 396, 101)).transpose((1, 2, 0))
```

In [None]:
# Deflation FastICA with random states 0..19, to inspect how stable the components are across seeds.
for i in range(20):
    print(f"Random State: {i}")
    maps, wts, _ = run_ica(
        sim_data.copy(), wave_list, 3, i,
        algorithm = 'deflation',
    )
    plot_comps_2d(
        maps, wave_list, wts,
        figsize = (10, 3), order = [0, 1, 2],
    )

In [None]:
# Parallel FastICA with random state 0; `model` is reused by later projection cells.
order = [2, 1, 0]
maps, wts, model = run_ica(sim_data.copy(), wave_list, 3, 0, algorithm = 'parallel')
# components_ maps data to sources; its pseudo-inverse gives per-wavelength mixing spectra
# (presumably equal to model.mixing_ if this is an sklearn FastICA).
plot_comps_2d(
    maps, wave_list, np.linalg.pinv(model.components_),
    xticks = wave_list, clim = [None] * 3, order = order,
)

In [None]:
# Shared [0, max] colour range across the three ICA maps.
ica_clim = [0, np.max(maps)]
plot_3d_multiple(
    Y * 1000, X * 1000, maps,
    title = legend[:3], cmap = 'jet',
    clim = [ica_clim] * 3, order = order,
)

### Non-Negative ICA

In [None]:
# Pixels-by-wavelengths matrix (one row per pixel); the wavelength count comes from wave_list
# instead of a hard-coded 10.
train_data = sim_data.copy().reshape((len(wave_list), -1)).T
nnmdl = NNICA(n_components = 3)
nnmdl.fit(train_data)

In [None]:
depth = 40
hbhbo2fat = np.copy(coeffs)[:, 0:3]
# Held-out depth-40 simulation, read from the same ./data tree as the depth-25 training cube.
sim_data_test = np.array([np.array(loadmat(f"./data/hb_hbo2_fat_29_{depth}/PA_Image_{wave}.mat")['Image_PA']) for wave in wave_list])
n_test_waves, test_h, test_w = sim_data_test.shape
# (wavelengths, H, W) -> (pixels, wavelengths), then subtract the mean stored on the ICA model.
sim_data_test = sim_data_test.reshape((n_test_waves, test_h * test_w)).T
sim_data_test -= model.mean_
# NOTE(review): this section is titled Non-Negative ICA but projects with the FastICA `model`
# fitted earlier, not `nnmdl` — confirm which model is intended.
test_comps = np.matmul(model.components_, sim_data_test.T).reshape((-1, test_h, test_w)).transpose((1, 2, 0))
plot_comps_2d(test_comps, wave_list, np.linalg.pinv(model.components_), order = order, clim = [None]*3)

In [None]:
# 3-D view of the held-out-depth ICA projections.
plot_3d_multiple(
    Y * 1000, X * 1000, test_comps,
    title = legend[:3], cmap = 'jet',
    clim = [None] * 3, order = order,
)

## 10 Wavelengths Experimental Results

```python
f = loadmat("./expdata/All Animal Results to Date/CONTROL MICE/03.31 Exp 2 (10 WV)/ALL_FRAMES_MOTIONREMOVED.mat")

f = loadmat("./expdata/All Animal Results to Date/CONTROL MICE/03.31 Exp 7 (10 WV)/SUMMARY_DATA.mat")

f = loadmat("./expdata/All Animal Results to Date/DKO MICE/01.20 Exp 19 (10 WV)/SUMMARY_FRAMES.mat")

f = loadmat("./expdata/All Animal Results to Date/EX VIVO SKIN SAMPLES/01.28 Exp 23/SUMMARY_FRAMES.mat")
```

## 6 Wavelengths Experimental Results

```python
f = loadmat("./expdata/All Animal Results to Date/EX VIVO SKIN SAMPLES/01.28 Exp 22/SUMMARY_FRAMES.mat")
```

In [None]:
# Experiments available for the APOE mice.
apoe_dir = './expdata/All Animal Results to Date/APOE MICE'
os.listdir(apoe_dir)

In [None]:
f = loadmat("./expdata/All Animal Results to Date/APOE MICE/01.28 Exp 16/SUMMARY_FRAMES.mat")
# Per-wavelength arrays are stored under keys starting with 'all' whose last three characters
# are the wavelength in nm. Sorted so channel order is ascending, like wave_list and the
# reference-spectra rows it is paired with later (.mat key order is not guaranteed).
exp_wave_list = sorted(int(key[-3:]) for key in f.keys() if key.startswith('all'))
# Mean over axis 2 of each array, stacked to (wavelengths, rows, cols), then cropped to the region of interest.
exp_img = np.array([np.mean(f[f'all{wave}'], axis = 2) for wave in exp_wave_list])[:, 55:125, 50:150]

exp_img_plot = normalize(exp_img.copy())
# Two-row grid sized from the experimental wavelength count (not the simulation wave_list),
# using ceil division so 6- or 10-wavelength data both lay out correctly.
n_exp_cols = (len(exp_wave_list) + 1) // 2
plt.figure(figsize = (24, 7))
for i, wave in enumerate(exp_wave_list):
    plt.subplot(2, n_exp_cols, i + 1)
    plt.imshow(exp_img_plot[i], cmap = "hot")
    plt.title(f"{wave} nm")
    plt.colorbar()
plt.show()
del exp_img_plot

### Linear Unmixing

In [None]:
# Linear unmixing of the experimental cube with the simulation reference spectra.
# NOTE(review): assumes the experimental wavelengths equal wave_list, in the same order — confirm.
exp_unmixed = run_linear_unmixing(normalize(exp_img.copy()), hbhbo2fat)
plot_comps_2d(
    exp_unmixed, wave_list, hbhbo2fat, "Linear Unmixing", (18, 4),
    order = [0, 1, 2], xticks = wave_list, clim = [None] * 3,
)

### FastICA

In [None]:
# Project the normalized experimental pixels with the simulation-trained ICA `model`.
# Pixels are flattened by the experimental channel count, so a mismatch with the model fails
# loudly at the mean subtraction instead of silently scrambling pixels.
exp_test_data = normalize(exp_img.copy()).transpose((1, 2, 0)).reshape((-1, exp_img.shape[0]))
exp_test_data -= model.mean_
# Last axis = number of ICA components, taken from the model rather than hard-coded.
exp_test_comps = np.matmul(model.components_, exp_test_data.T).T.reshape((exp_img.shape[1], exp_img.shape[2], -1))
exp_test_comps[exp_test_comps < 0] = 0   # keep only non-negative activations
plot_comps_2d(exp_test_comps, wave_list, np.linalg.pinv(model.components_), "ICA", (18, 4), order = order, clim = [None, None, None], xticks = wave_list)

### NMF

In [None]:
# Fresh 3-component NMF fitted directly on the normalized experimental pixels (pixels x wavelengths).
nmf_model_exp = NMF(n_components = 3)
exp_test_data = normalize(exp_img.copy()).transpose((1, 2, 0)).reshape((-1, exp_img.shape[0]))
nmf_test_comps = nmf_model_exp.fit_transform(exp_test_data).reshape((exp_img.shape[1], exp_img.shape[2], -1))
# NMF: X ~ W @ H with H = components_, so the spectra are components_.T — matching the
# simulation NMF cells. (pinv is only appropriate for the ICA unmixing matrix.)
plot_comps_2d(nmf_test_comps, wave_list, nmf_model_exp.components_.T, order = [0, 1, 2], clim = [None]*3, figsize = (18, 4))

### NMF GD

In [None]:
# Gradient-descent NMF on the experimental cube flattened to (wavelengths, pixels); the row count
# comes from the experimental data instead of the simulation wave_list.
nmf_model_test_exp = NMFGD(n_components = 3, randominit = True)
nmf_model_test_exp.fit(normalize(exp_img.copy()).reshape((exp_img.shape[0], -1)), maxiter = 1200)

In [None]:
# NMFGD activations mapped back onto the cropped experimental image grid.
plot_comps_2d(
    nmf_model_test_exp.H.T.reshape((exp_img.shape[1], exp_img.shape[2], 3)),
    wave_list, nmf_model_test_exp.W,
    order = [0, 1, 2], clim = [None] * 3,
)

### VCA

In [None]:
# Normalized experimental pixels (pixels x wavelengths), rebuilt here so this cell does not depend
# on which earlier cell last assigned `exp_test_data` (the ICA cell leaves it mean-centred).
vca_exp_pixels = normalize(exp_img.copy()).transpose((1, 2, 0)).reshape((-1, exp_img.shape[0]))
# Abundances from the simulation VCA endmembers `ae`, mapped back to (rows, cols, endmembers).
vca_test_comps = np.matmul(np.linalg.pinv(ae), vca_exp_pixels.T).reshape((-1, exp_img.shape[1], exp_img.shape[2])).transpose((1, 2, 0))
vca_test_comps[vca_test_comps < 0] = 0   # keep only non-negative abundances
plot_comps_2d(vca_test_comps.copy(), wave_list, ae, order = [2, 0, 1], clim = [None, None, None], figsize = (18, 4))

## Constrained AE

In [None]:
from utils import *
from datasets import ZCA
from unmix_constrained import *
from datasets import SingleCholesterolDataset

In [None]:
# Training configuration for the tied-weight autoencoder.
BATCH_SIZE = 396*101
EPOCHS = 150
LRATE = 7e-3
NCOMP = 3
SEED = 9
negexp = 1.5   # scale `a` in the contrast G(u) = -exp(-a * u**2 / 2) used by the negentropy term
beta = 1       # loss weighting; see the `loss` line for how beta == 1 is treated
torch.manual_seed(seed = SEED)
data = SingleCholesterolDataset(root = './data/hb_hbo2_fat_11', wavelist = 'EXP10', depth = [20], whiten = 'zca', normalize = True)
dataloader = DataLoader(data, batch_size = BATCH_SIZE, shuffle = False, num_workers = 16)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoEncoder(len(data.wavelist), NCOMP, activation = 'tsigmoid', tied = True).to(device = device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LRATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.5, patience = 10, verbose = True)
losses = []
model.train()
for epoch in (t := trange(EPOCHS)):
    # Collects the loss of every batch in this epoch; reset once per epoch, not per batch.
    epochloss = []
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Call the module (not .forward) so any registered hooks run.
        encoded, decoded = model(batch)
        mse = criterion(decoded, batch)
        # |E[G(ref)] - E[G(encoded)]| with G(u) = -exp(-negexp * u**2 / 2) and ref ~ N(0, 1) drawn each step.
        # NOTE(review): this term is minimized, which pushes the codes toward Gaussian-like
        # statistics; ICA-style objectives usually maximize negentropy — confirm the sign is intended.
        negentropy = torch.abs(-(torch.mean(-torch.exp(- negexp * (encoded ** 2) / 2)) - torch.mean(-torch.exp(- negexp * (torch.randn_like(encoded)) ** 2 / 2))))
        # beta != 1: beta * MSE + (1 - beta) * negentropy; beta == 1: plain sum MSE + negentropy.
        loss = ((beta * mse) + ((1 - beta) * negentropy)) if beta != 1 else (mse + negentropy)
        loss.backward()
        optimizer.step()
        epochloss.append(loss.item())
        # trange already advances the bar once per epoch; only refresh the labels here.
        t.set_description_str(f'EPOCH: [{epoch + 1}/{EPOCHS}]')
        t.set_postfix_str(f'MSELOSS: {mse.item():.3f} NEGENTROPY: {negentropy.item():.3f}')
    epochlossmean = sum(epochloss) / len(epochloss)
    scheduler.step(epochlossmean)
    losses.append(epochlossmean)
    # Early stop once the mean epoch loss drops below 0.02.
    if epochlossmean < 0.02:
        break

In [None]:
data = SingleCholesterolDataset(root = './data/hb_hbo2_fat_11', wavelist = 'EXP10', depth = [20], whiten = 'zca', normalize = True)
# First sample, transposed so channels lead; the channel count is read from the sample
# instead of being hard-coded to 10 (the image grid is assumed to be 396 x 101).
imgdata = data[0].clone().numpy().T
n_channels = imgdata.shape[0]
imgdata = imgdata.reshape((n_channels, 396, 101))
plt.figure(figsize = (20, 4))
for idx in range(n_channels):
    plt.subplot(1, n_channels, idx + 1)
    plt.imshow(imgdata[idx], cmap = 'hot')
    plt.colorbar()
plt.show()

In [None]:
# First sample of the last training batch's reconstruction (`decoded` from the training loop),
# laid out channel by channel; the channel count comes from the tensor rather than a hard-coded 10.
imgdata = decoded[0].detach().cpu().numpy().T
n_channels = imgdata.shape[0]
imgdata = imgdata.reshape((n_channels, 396, 101))
plt.figure(figsize = (20, 4))
for idx in range(n_channels):
    plt.subplot(1, n_channels, idx + 1)
    plt.imshow(imgdata[idx], cmap = 'hot')
    plt.colorbar()
plt.show()

In [None]:
# Per-epoch loss values recorded in `losses` by the training loop.
plt.figure(figsize = (8, 5))
plt.plot(range(len(losses)), losses)
plt.show()

In [None]:
# Depth-20 simulation for the AE wavelengths. Kept under its own names so the notebook-wide
# `sim_data` NumPy cube used by earlier cells is not replaced by a torch tensor.
ae_sim_cube = np.array([np.array(loadmat(f'./data/hb_hbo2_fat_11_20/PA_Image_{wave}.mat')['Image_PA']) for wave in data.wavelist])
c, h, w = ae_sim_cube.shape
# (wavelengths, H, W) -> (pixels, wavelengths), ZCA-whitened, then given a leading batch axis.
ae_sim_pixels = ae_sim_cube.transpose((1, 2, 0)).reshape((h*w, c))
zca = ZCA()
ae_sim_white = zca.fit_transform(ae_sim_pixels)
ae_sim_tensor = torch.tensor(np.expand_dims(ae_sim_white, axis = 0), dtype = torch.float32)
# Inference only: skip building an autograd graph.
# NOTE(review): the model is still in train() mode from the training cell; call model.eval()
# first if AutoEncoder contains dropout/batch-norm layers.
with torch.no_grad():
    preds = np.array(model.encode(ae_sim_tensor.to(device)).cpu())[0].reshape((h, w, NCOMP))

plot_comps_2d(preds, data.wavelist, model.encw.detach().cpu().numpy(), order = [0, 1, 2], xticks = data.wavelist, title = 'CONSTRAINED AE', save = False, chrom = list(range(3)))