# Robust-NTF Applied to Missing Data

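Here, we generate a synthetic low-rank 3-dimensional tensor from known signals. Some of the data is removed (set to NaN), mirroring the kind of missing data found in hyperspectral autofluorescence image cubes. The data is then factorized with robust non-negative tensor factorization (rNTF) at two error tolerances ($10^{-3}$ and $10^{-2}$) to show how an unsuitable tolerance causes early stopping.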

## Setup

In [None]:
from decimal import Decimal
import torch
import numpy as np
import tensorly as tl
import matplotlib.pyplot as plt
import sys
from torch.nn.functional import normalize
from scipy import signal
from scipy.stats import gamma
from tensorly.kruskal_tensor import kruskal_to_tensor
from tensorly.decomposition.candecomp_parafac import non_negative_parafac
from tensorly.tenalg.outer_product import outer

sys.path.append("..")
from robust_ntf.robust_ntf import RntfConfig, RobustNTF, RntfStats

# Numerical configuration shared by every cell below.
# New tensors default to float64 on the CUDA device, so a GPU is required.
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
# Route TensorLy's tensor operations through PyTorch.
tl.set_backend('pytorch')

# Seed the PyTorch RNG and then the NumPy RNG with the same value.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(33)

# Small positive offset added to the ground-truth factors so that no entry
# is exactly zero.
eps = 1e-6

def fexp(number):
    """Return the base-10 exponent ``e`` of ``number`` in scientific notation.

    For a nonzero finite ``number`` this is the ``e`` with
    ``number == m * 10**e`` and ``1 <= |m| < 10``; ``fexp(0)`` is 0.
    Floats are converted to ``Decimal`` exactly (not via ``repr``), so a float
    whose binary value lies just below a power of ten reports the lower
    exponent (e.g. ``1e23`` -> 22).
    """
    parts = Decimal(number).as_tuple()
    # The number of significant digits, shifted by the Decimal exponent,
    # locates the leading digit relative to the decimal point.
    return parts.exponent + len(parts.digits) - 1

def fman(number):
    """Return the mantissa of ``number`` as a normalized ``Decimal``.

    The value is rescaled by ``10**-fexp(number)`` and trailing zeros are
    stripped, e.g. ``fman(2500) == Decimal("2.5")``.
    """
    value = Decimal(number)
    return value.scaleb(-fexp(number)).normalize()

## Part 1: Generate synthetic tensor

### Generate ground truth factors:

Here, we generate the ground-truth factor matrices from which a rank-3 synthetic tensor is built. They are:

* Mode 1: a Gaussian-modulated sinusoid. Its real part, imaginary (quadrature) part and envelope, each rectified and doubled, are the three factors.
* Mode 2: three different chirp signals, each shifted so its minimum is zero.
* Mode 3: three different Gamma PDFs.

In [None]:
# Ground-truth factors for a rank-3 synthetic tensor: three non-negative
# curves per mode, with mode lengths 50, 96 and 20.

# --- Mode 1: Gaussian-modulated sinusoid (centre frequency 3) -------------
# 50 samples on [-1, 1). gausspulse returns the in-phase (real) part, the
# quadrature part and the envelope; each one is rectified and doubled.
mode1_support = np.linspace(-1, 1, 2*25, endpoint=False)
x1, x2, x3 = (
    2 * np.abs(component)
    for component in signal.gausspulse(mode1_support, fc=3,
                                       retquad=True, retenv=True)
)

# --- Mode 2: three chirps, each shifted so its minimum is zero -----------
mode2_support = np.linspace(-1, 1, 96, endpoint=False)
y1, y2, y3 = (
    chirp - chirp.min()
    for chirp in (
        signal.chirp(mode2_support, f0=4, t1=-0.5, f1=4),
        signal.chirp(mode2_support, f0=2, t1=0.5, f1=3),
        signal.chirp(mode2_support, f0=1, t1=0.1, f1=2),
    )
)

# --- Mode 3: Gamma PDFs with shape parameters 7, 2 and 4 ------------------
mode3_support = np.linspace(0, 10, 20)
z1, z2, z3 = (gamma(shape).pdf(mode3_support) for shape in (7, 2, 4))

### Plot ground truth factors:

In [None]:
# Set up figure size:
fig = plt.figure(figsize=(15,8))

# One panel per mode: (subplot slot, sample support, three factor curves, title).
for _slot, _support, _curves, _title in (
    (1, mode1_support, (x1, x2, x3), 'Mode-1 factors'),
    (2, mode2_support, (y1, y2, y3), 'Mode-2 factors'),
    (3, mode3_support, (z1, z2, z3), 'Mode-3 factors'),
):
    plt.subplot(2, 2, _slot)
    # Single plot call with interleaved (support, curve) pairs, one line per factor.
    plt.plot(*[arr for curve in _curves for arr in (_support, curve)])
    plt.gca().set_title(_title)

### Cast factors to PyTorch and/or make positive:

In [None]:
def _to_gpu_factor(rows):
    """Stack the 1-D factors in ``rows`` into a (3, length) tensor on the GPU.

    ``eps`` (read from the notebook globals at call time) is added so every
    entry is strictly positive: all inputs here are already non-negative.
    """
    return torch.from_numpy(np.array(rows)).cuda() + eps

X = _to_gpu_factor([x1, x2, x3])  # Mode-1
Y = _to_gpu_factor([y1, y2, y3])  # Mode-2
Z = _to_gpu_factor([z1, z2, z3])  # Mode-3

### Construct ground truth tensor to factorize:

In [None]:
# TensorLy's Kruskal (CP) format is a (weights, factors) pair; weights=None
# leaves the components unweighted. Each factor is transposed from
# (3, mode length) to (mode length, 3).
_cp_factors = [factor.t() for factor in (X, Y, Z)]
ktens = (None, _cp_factors)

# Expand the rank-3 CP model into a dense 50 x 96 x 20 tensor.
data = kruskal_to_tensor(ktens)

In [None]:
# Simulate missing data. Pass ``band`` blanks the trailing int(1.5*band)+1
# mode-1 rows of every mode-3 slice after ``band``. Net effect: mode-3 slice
# b >= 1 loses its last int(1.5*(b-1))+1 rows, and slice 0 stays complete.
for band in range(data.shape[-1]):
    n_rows = int(1.5 * band) + 1
    data[-n_rows:, :, band + 1:] = np.nan

# Fraction of entries now missing (shown as the cell output).
_missing = np.isnan(data.cpu().numpy())
_missing.sum() / _missing.size

### Visualize some slices of the tensor in false color:

In [None]:
fig = plt.figure(figsize=(15,8))

def _false_color(cube):
    """Scale a 3-channel image by its largest non-NaN value; NaNs become 1.0."""
    scaled = cube / np.nanmax(cube)
    scaled[np.isnan(scaled)] = 1.0
    return scaled

# XY: mode-3 slices 0, 5, 10 used as the RGB channels.
plt.subplot(2,2,1)
XY = _false_color(data[:, :, [0, 5, 10]].data.cpu().numpy())
plt.imshow(XY)

# XZ: mode-2 slices 0, 5, 10, reordered to (mode-1, mode-3, channel).
plt.subplot(2,2,2)
XZ = _false_color(data[:, [0, 5, 10], :].data.cpu().numpy().transpose([0, 2, 1]))
plt.imshow(XZ)

# ZY: mode-1 slices 10, 15, 20, reordered to (mode-3, mode-2, channel).
plt.subplot(2,2,3)
ZY = _false_color(data[[10, 15, 20], :, :].data.cpu().numpy().transpose([2, 1, 0]))
plt.imshow(ZY)

## Part 2: Compare methods

The cells below run rNTF first with error tolerance 1e-3 and then with 1e-2; compare the results.

In [None]:
ERROR_TOLERANCE = 1e-3

# NOTE(review): the positional args look like rank 3, beta 2 and regularization
# weight 0.1 (the plots label the fit "Beta Divergence" and the regularizer
# "L_{2,1} Norm"). Confirm against RntfConfig.
# NOTE(review): the 1e-2 run also uses save_folder="./out". Check whether it
# overwrites this run's saved output.
cfg = RntfConfig(
    3, 2, 0.1, ERROR_TOLERANCE,
    max_iter=200000,
    print_every=100,
    save_every=100,
    save_folder="./out",
)
rntf = RobustNTF(cfg)
rntf.run(data)

# Keep the fitted factor matrices, the outlier estimate and the recorded stats.
rntf_01_factors, rntf_01_outlier = rntf.matrices, rntf.outlier
vals = rntf.stats

## Visualize objective, error, and reconstruction accuracy statistics.

With tolerance 1e-3, the relative-error curve should show local minima that dip below 1e-2; with a tolerance of 1e-2, the first such dip would stop the run early. Red dots mark the local minima of the relative-error curve. The dotted vertical line marks the estimated convergence point, where the second difference of the smoothed log regularization term peaks. This demonstrates how an unsuitable error tolerance, used as the convergence criterion, causes early stopping. Accuracy metrics ($L_2$ and $L_\infty$) are also shown for reference and comparison.

In [None]:
fig = plt.figure(figsize=(8,15))

# Size the x-axis from the same stats key the error panel uses, rather than
# a raw string key.
inds = list(range(len(vals[RntfStats.ERR])))
# log10(0) is -inf for the first entry; matplotlib drops that point.
with np.errstate(divide="ignore"):
    log_inds = np.log10(inds)
x_start = log_inds[1] + 0.02
x_end = log_inds[-1] - 0.02


def _convergence_index(log_curve, window_length=501, polyorder=3):
    """Return the index where the smoothed curve's second difference peaks.

    The curve is smoothed with a Savitzky-Golay filter. If fewer than
    ``window_length`` points were recorded, the window shrinks to the largest
    odd length that fits instead of raising. Non-positive second differences
    become NaN/-inf under log10; the NaNs are skipped by ``np.nanargmax``.
    The second difference at position k is centred on sample k + 1, so 1 is
    added to map the peak back onto the curve's own index.
    """
    n_points = len(log_curve)
    window = min(window_length, n_points if n_points % 2 else n_points - 1)
    smooth = signal.savgol_filter(log_curve, window_length=window, polyorder=polyorder)
    with np.errstate(divide="ignore", invalid="ignore"):
        log_curvature = np.log10(np.diff(smooth, n=2))
    return int(np.nanargmax(log_curvature)) + 1


# --- Panel 1: objective = fit + regularization ---------------------------
plt.subplot(4,1,1)
plt.ylim(top=5, bottom=-20)
plt.xlim(left=0, right=3.5)
obj = vals[RntfStats.OBJ].to_numpy()
plt.plot(log_inds, np.log10(obj))
plt.annotate("Objective", xy=(x_start, np.log10(obj[1])+1), horizontalalignment="left")
fit = vals[RntfStats.FIT].to_numpy()
plt.plot(log_inds, np.log10(fit), linestyle="dashed", color="gray")
plt.annotate("Fitness\n(Beta Divergence)", xy=(x_start, np.log10(fit[1])-1), horizontalalignment="left", verticalalignment="top")
reg = vals[RntfStats.REG].to_numpy()
log_reg = np.log10(reg)
plt.plot(log_inds, log_reg, linestyle=":", color="gray")
plt.annotate("Regularization\nTerm\n($L_{2,1}$ Norm)", xy=(x_end-0.2, log_reg[-1]-0.5), horizontalalignment="left", verticalalignment="top")
plt.title("Objective Function")
# Convergence marker: where the smoothed log regularization term bends most.
peak = _convergence_index(log_reg)
x = log_inds[peak]
plt.plot([x, x], plt.ylim(), "k:")
c_x_pos = x + 0.03
plt.annotate("Convergence", xy=(c_x_pos, 2), horizontalalignment="left", verticalalignment="center")

# --- Panel 2: relative change in objective (error) -----------------------
plt.subplot(4,1,2)
plt.ylim(top=0.5, bottom=-3.5)
plt.xlim(left=0, right=3.5)
err = vals[RntfStats.ERR].to_numpy()
log_err = np.log10(err)
plt.plot(log_inds, log_err)
# Local minima of the error curve (peaks of its negation), shown as red dots.
peaks = signal.find_peaks(-err)[0].tolist()
plt.plot(log_inds[peaks], log_err[peaks], "r.")
plt.plot([x, x], plt.ylim(), "k:")
plt.annotate("Convergence", xy=(c_x_pos, 0), horizontalalignment="left", verticalalignment="center")
# log10(0.01) = -2: dips below this line would stop a run with tolerance 1e-2.
plt.plot(plt.xlim(), [-2, -2], "k:")
plt.annotate("Early stopping below this line\nif tolerance set to $0.01$", xy=(0.03, -2.1), horizontalalignment="left", verticalalignment="top")
plt.title("Relative Change in Objective Function (Error)")

# --- Panels 3 and 4: accuracy metrics ------------------------------------
plt.subplot(4,1,3)
plt.xlim(left=0, right=3.5)
L2_acc = vals[RntfStats.L2_ACC].to_numpy()
plt.plot(log_inds, np.log10(L2_acc))
plt.title("Accuracy ($L_{2}$ Norm)")

plt.subplot(4,1,4)
plt.xlim(left=0, right=3.5)
Linf_acc = vals[RntfStats.LINF_ACC].to_numpy()
plt.plot(log_inds, np.log10(Linf_acc))
plt.title("Accuracy ($L_{inf}$ Norm)")

fig.tight_layout()

## Run early stopping.

In [None]:
ERROR_TOLERANCE = 1e-2

# Same configuration as the 1e-3 run except for the looser tolerance.
# NOTE(review): same save_folder ("./out") as the 1e-3 run. Confirm whether
# saved outputs collide.
cfg2 = RntfConfig(
    3, 2, 0.1, ERROR_TOLERANCE,
    max_iter=200000,
    print_every=100,
    save_every=100,
    save_folder="./out",
)
rntf2 = RobustNTF(cfg2)
rntf2.run(data)

# Fitted factor matrices, outlier estimate and recorded stats for this run.
rntf2_01_factors, rntf2_01_outlier = rntf2.matrices, rntf2.outlier
vals2 = rntf2.stats

## Plot some factors:
Here, the mode-3 factors estimated by rNTF with tolerances 1e-3 (dotted gray) and 1e-2 (solid black) are plotted, each normalized to unit $L_2$ norm. There are considerable differences between the two.

In [None]:
# Set up figure size:
fig = plt.figure(figsize=(15, 8))

y1 = normalize(rntf_01_factors[2], dim=0).data.cpu().numpy()
h1 = plt.plot(y1, color="gray", linestyle=":")
y2 = normalize(rntf2_01_factors[2], dim=0).data.cpu().numpy()
h2 = plt.plot(y2, color="k")
plt.gca().set_title('Mode-3 Results')
plt.legend(handles=[h1[0], h2[0]], labels=["Tolerance = $10^{-3}$", "Tolerance = $10^{-2}$"])

## Visualize outliers

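For each tolerance, mode-2 slice 25 of the estimated outlier tensor is shown as $\log_{10}(\cdot) + 20$ with a color range of 0–20. At 1e-2 the outliers are significantly larger than at 1e-3.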

In [None]:
## Plot results:
# Set up figure size:
fig = plt.figure(figsize=(15, 15))
sl = 25  # index along mode 2 (the 96-sample axis) of the slice to display


def _latex_sci(value):
    """Format ``value`` as LaTeX scientific notation with two decimals.

    Rounding is done by ``{:.2e}``, so 9.996e-5 becomes ``$1.00\\times10^{-4}$``
    rather than a "10.00" mantissa. Non-finite values are returned via ``str``.
    """
    if not np.isfinite(value):
        return str(value)
    mantissa, exponent = "{:.2e}".format(value).split("e")
    return "${:s}\\times10^{{{:d}}}$".format(mantissa, int(exponent))


def _plot_outlier_slice(outlier, tol_label):
    """Show ``log10(outlier[:, sl, :]) + 20`` in the current axes.

    Zero or negative entries (log10 gives -inf/NaN) are drawn as 0, and the
    color range is fixed to [0, 20], i.e. magnitudes 1e-20 .. 1. The title
    reports the slice's largest non-NaN outlier value.
    """
    XZr = outlier[:, sl, :].data.cpu().numpy()
    m = np.nanmax(XZr)
    with np.errstate(divide="ignore", invalid="ignore"):
        XZr = np.log10(XZr) + 20
    XZr[np.isnan(XZr) | np.isinf(XZr)] = 0.0
    plt.imshow(XZr)
    plt.gca().set_title('Tolerance = {:s} Outliers\nMax = {:s}'.format(tol_label, _latex_sci(m)))
    plt.clim(0, 20)


# Plot rNTF outliers (tolerance 1e-3):
plt.subplot(1,2,1)
_plot_outlier_slice(rntf_01_outlier, "$10^{-3}$")

# Plot rNTF outliers (tolerance 1e-2):
plt.subplot(1,2,2)
_plot_outlier_slice(rntf2_01_outlier, "$10^{-2}$")