In [None]:
from dalila.dictionary_learning import StabilityDictionaryLearning,DictionaryLearning
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from joblib import Parallel, delayed
import multiprocessing


In [None]:
from utils import *

# Load data

In [None]:
from scipy.io import loadmat
from unicodedata import normalize
# Read the MATLAB file; appendmat=False means the path is used exactly as given.
filename = "data/breast_cancer_data.mat"
data = loadmat(filename, appendmat=False)
v = data["originalGenomes"]
types = data["types"]

# Decompose each label with NFKD and drop every character that is not ASCII.
l = len(types)
types_1 = [
    normalize('NFKD', entry[0][0]).encode('ascii', 'ignore')
    for entry in types[:l]
]

# From here on `data` is the transposed matrix (no longer the loaded dict)
# and `types` is an array of the cleaned labels.
data = v.T
types = np.asarray(types_1)

Remove "weak" mutations from the dataset

In [None]:
# Run the filter with threshold 0.01, then pull out the two entries used later.
res = remove_weak_mutations(data, 0.01)
X, removed_cols = res["mutational_catalogue"], res["removed_cols"]

# Extract mutational signatures
For each candidate number of signatures, extract the dictionary and the coefficients using Non-negative Matrix Factorization, clustering the atoms until convergence.

In [None]:
def process_input(k, X):
    """Fit a non-negative StabilityDictionaryLearning model with k atoms on X.

    Parameters
    ----------
    k : number of atoms handed to the estimator.
    X : data matrix the estimator is fitted on.

    Returns
    -------
    tuple
        ``(C, D, stability, error)``: ``C`` and ``D`` as returned by
        ``estimator.decomposition()``, the estimator's ``stability``
        attribute, and the sum of squared entries of ``X - C.dot(D)``.
    """
    print("NUMBER OF ATOMS: ", k)
    model = StabilityDictionaryLearning(k=k, non_negativity="both")
    model.fit(X)
    C, D = model.decomposition()
    # Squared reconstruction error of the factorisation.
    residual = X - C.dot(D)
    squared_error = np.sum(np.square(residual))
    return C, D, model.stability, squared_error
    

In [None]:
# One worker per CPU reported by the OS.
num_cores = multiprocessing.cpu_count()
# Candidate numbers of atoms: 2, 3, ..., 8.
possible_atoms = np.arange(2, 9)
# One process_input call per candidate, dispatched through joblib.
jobs = (delayed(process_input)(k, X) for k in possible_atoms)
results = Parallel(n_jobs=num_cores)(jobs)

## Plot the stabilities and the reconstruction errors

In [None]:
# Each entry of `results` is a (C, D, stability, error) tuple; split them
# into four parallel lists that keep the order of `results`.
Cs = [run[0] for run in results]
Ds = [run[1] for run in results]
stabilities = [run[2] for run in results]
errors = [run[3] for run in results]

In [None]:
# rcParams only affect axes created afterwards, so the background colour has
# to be set before plt.subplots() for it to show up in this figure.
plt.rcParams['axes.facecolor'] = (0.294, 0.294, 0.5, 0.3)

fig, ax1 = plt.subplots(figsize=(5,5))
# Position (within the plotted series) of the point highlighted by a marker.
markers_on = [4]
n_signatures = np.arange(2,9)

# Left axis (blue): stability for each number of signatures.
ax1.plot(n_signatures, stabilities, '-bD', markevery=markers_on, label="Stability")
# Make the y-axis label, ticks and tick labels match the line color.
ax1.set_ylabel('Stability', color='b')
ax1.tick_params('y', colors='b')
ax1.set_xlabel('Number of mutational signatures')

# Right axis (red): reconstruction error on its own scale.
ax2 = ax1.twinx()
ax2.plot(n_signatures, errors, '-rD', markevery=markers_on, label="Reconstruction error")
ax2.set_ylabel('Reconstruction error', color='r')
ax2.tick_params('y', colors='r')

# Both legends sit above the plot area, side by side.
ax1.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3)
ax2.legend(bbox_to_anchor=(0.4, 1.02, 1., .102), loc=3)
fig.tight_layout()
plt.show()

Select the best number of signatures given the plot above

In [None]:
# Position 4 in the sweep over 2..8 atoms, i.e. the 6-atom decomposition
# (the same position that is highlighted by the marker in the plot).
D, C = Ds[4], Cs[4]

Re-insert the removed "weak" columns

In [None]:
# Rebind D to the output of add_removed_cols, called with the selected
# dictionary and the `removed_cols` entry saved when filtering.
# NOTE(review): D is both the input and the output, so re-running this cell
# on its own feeds the already-extended D back into the helper — re-run the
# selection cell first, or confirm the helper tolerates a second pass.
D = add_removed_cols(D, removed_cols)
# NOTE(review): presumably rearranges D according to the labels in `types`;
# confirm against the definition of ordering_for_types.
D_ordered = ordering_for_types(D, types)

# Plot the resulting atoms

In [None]:
# Plot every atom (row) of the ordered dictionary.
# NOTE(review): the original indexed an undefined name `our_D`, which raises
# NameError on a fresh kernel. D_ordered is the only atom matrix computed
# above that was otherwise unused — confirm it is the intended input.
for i in range(D_ordered.shape[0]):
    plot_atom(D_ordered[i, :])

# Analysis of the coefficients 

In [None]:
# Normalise each row of C by its own sum; rows whose sum is zero are left
# as zeros instead of being divided.
row_totals = np.sum(C, axis=1, keepdims=True)
percentages = np.zeros_like(C)
np.divide(C, row_totals, out=percentages, where=row_totals != 0)

print(percentages)

In [None]:
# Zero out every share below 25% of its sample's total.
percentages[percentages < 0.25] = 0

# For each atom (column), count the samples in which it still has a non-zero share.
n_atoms = D.shape[0]
frequencies = np.zeros(n_atoms)
frequencies[:percentages.shape[1]] = np.count_nonzero(percentages, axis=0)

# Horizontal weighted histogram: one bin per atom, weighted by its sample count.
plt.figure(figsize=(15,10))
plt.hist(np.arange(n_atoms), weights=frequencies, bins=n_atoms, orientation="horizontal");
plt.show()