In [10]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def gen_pc_data(N, M, R, percent_moonlight):
    
    # First generate W
    W = np.zeros((N,R))
    moonlight_genes = np.random.uniform(0,1,N)
    moonlight_genes = moonlight_genes<percent_moonlight
    
    # Pick the modules for each gene
    for gene_idx in range(N):
        if moonlight_genes[gene_idx]:
            num_mods = np.random.randint(2,R+1)
            gene_mods = np.random.randint(0,R,num_mods)
            W[gene_idx, gene_mods] = 1
        else: 
            gene_mod = np.random.randint(0,R)
            W[gene_idx, gene_mod] = 1
    # Pick the relative expression levels of each gene for each module
    for mod_idx in range(R):
        # Handel the case where a module has no genes (very unlikely for N >> R)
        if W[:,mod_idx].sum() == 0:
            W[np.random.randint(0,N),mod_idx] = 1
            
        else:
            col_arr = W[:,mod_idx]
            col_sum = int(col_arr.sum())
            dir_params = np.ones(col_sum)
            gene_expressions_arr = np.random.dirichlet(dir_params)
            idx_arr = np.argwhere(col_arr != 0)
            col_arr[idx_arr] = gene_expressions_arr[:,np.newaxis]
            W[:, mod_idx] = col_arr
            
    # Then generate H
    H = np.ones((R,M))
    H = np.apply_along_axis(np.random.dirichlet, 0, H)
    
    # Now generate V from W and H
    V = 1e6*np.matmul(W, H)
    V = np.random.poisson(V)
            
    return W, H, V

In [3]:
# Generate simulated data where we are observing 100 genes in 60 samples with 5 gene batteries and ~5 percent
# of genes are moonlighting genes
W1_true, H1_true, V1 = gen_pc_data(100,60,3,0.2)
W2_true, H2_true, V2 = gen_pc_data(100,60,3,0.05)
W3_true, H3_true, V3 = gen_pc_data(100,60,3,0)

In [4]:
def iterate_NMF(V, R, N):
    
    iteration_num_list = []
    dims = V.shape
    C = V.sum(axis=0)
    C = C[np.newaxis,:]
    
    best_LL = -1
    best_W = -1
    best_H = -1
    
    for epoch in range(N):
        
        # Initiate random W and H
        epoch_W, epoch_H = np.ones((dims[0],R)), np.ones((R,dims[1]))
        epoch_W = np.apply_along_axis(np.random.dirichlet, 0, epoch_W)
        epoch_H = np.apply_along_axis(np.random.dirichlet, 0, epoch_H)
        epoch_LL = -1

        for iteration in range(100000):

            # Get matrices for W and H update
            Lamda = np.multiply(C,np.matmul(epoch_W,epoch_H))
            P = np.divide(V, Lamda)
            CH = np.multiply(C,epoch_H)

            # Broadcast into higher dimensions and perform element-wise multiplication for W update
            P_3d = P[:,np.newaxis,:]
            CH_3d = CH[np.newaxis,:,:]
            L_W = np.multiply(P_3d,CH_3d).sum(axis=2)

            # Get new W
            W_new = np.multiply(epoch_W,L_W)
            norm_term = W_new.sum(axis=0)[np.newaxis,:]
            W_new = np.divide(W_new,norm_term)

            # Broadcast into higher dimensions and perform element-wise multiplication for H update
            W_3d = epoch_W[:,:,np.newaxis]
            L_H = np.multiply(P_3d,W_3d).sum(axis=0)

            # Get new H
            H_new = np.multiply(epoch_H,L_H)

            # Update W and H
            epoch_W = W_new
            epoch_H = H_new
            
            # Calculate the negative log likelyhood and end iteration if it has not changed significantly
            if iteration%100 == 0:
                iteration_LL = np.sum(np.subtract(np.multiply(V, np.log(Lamda)), Lamda))
                
                if epoch_LL == -1:
                    epoch_LL = iteration_LL
                    
                elif np.abs(np.subtract(iteration_LL, epoch_LL)) < 0.1:
                    epoch_LL = iteration_LL
                    iteration_num_list.append(iteration)
                    break
                
                else: epoch_LL = iteration_LL
         
        # Update the best parameters if the log-likelyhood is better
        if epoch_LL > best_LL or best_LL == -1:
            best_LL = epoch_LL
            best_W = epoch_W
            best_H = epoch_H

    return best_W, best_H, best_LL, iteration_num_list

def find_best_fit_r(V, R_range_vec, N):
    W_list = []
    H_list = []
    LL_list = []
    
    for R in range(R_range_vec[0],R_range_vec[1]+1):
        results = iterate_NMF(V, R, N)
        W_list.append(results[0])
        H_list.append(results[1])
        LL_list.append(results[2])
        
    return W_list, H_list, LL_list
    

In [5]:
test_1 = find_best_fit_r(V1, (3,6), 5)
test_2 = find_best_fit_r(V2, (3,6), 5)
test_3 = find_best_fit_r(V3, (3,6), 5)

In [6]:
# So for these r=3 definetly looks better...
print(np.abs(1e6*np.matmul(test_1[0][0],test_1[1][0])-V1))
print(np.abs(1e6*np.matmul(test_1[0][3],test_1[1][3])-V1))
# But the sum for r=6 is lower..?
print(np.sum(np.abs(1e6*np.matmul(test_1[0][0],test_1[1][0])-V1)))
print(np.sum(np.abs(1e6*np.matmul(test_1[0][3],test_1[1][3])-V1)))

[[ 56.65873321 159.46354873  98.60171215 ... 224.58736809 367.95271392
  244.77677242]
 [224.85811856 139.07386407 107.6076536  ...  27.3566821   16.78638887
   49.03112452]
 [ 12.92686767  79.01089751  21.84342772 ...  26.66812826  32.43638249
   16.0751336 ]
 ...
 [ 27.8280846   99.37041231  53.30303765 ...  49.8028319   33.79935804
   47.83389497]
 [ 63.77903127  37.83372974 266.58249998 ... 173.03054663  19.12150188
  187.37186526]
 [ 18.77349248 186.04480332   7.33016943 ...  46.89195823  94.75829256
  165.35965972]]
[[3.22103399e+01 1.12707707e+02 3.00656263e+01 ... 2.63876548e+02
  2.58453411e+02 2.19055996e+02]
 [1.82471857e+02 1.11292983e+02 7.69338899e+01 ... 6.72551920e+00
  1.01377791e+00 2.96439560e+01]
 [4.07826769e-01 7.94530940e+01 1.89058207e+01 ... 7.06620387e+00
  1.30877387e+01 4.33924968e+01]
 ...
 [1.72659224e+01 1.01211371e+02 4.91685123e+01 ... 4.94505033e+01
  2.44564729e+01 5.67509105e+01]
 [7.69758104e+01 3.62898891e+01 2.68069329e+02 ... 1.73357263e+02
  1.3

In [7]:
# So for these r=3 definetly looks better...
print(np.abs(1e6*np.matmul(test_2[0][0],test_2[1][0])-V2))
print(np.abs(1e6*np.matmul(test_2[0][3],test_2[1][3])-V2))
# But the sum for r=6 is lower..?
print(np.sum(np.abs(1e6*np.matmul(test_2[0][0],test_2[1][0])-V2)))
print(np.sum(np.abs(1e6*np.matmul(test_2[0][3],test_2[1][3])-V2)))

[[1.40089814e+02 6.50169178e+01 7.15459454e+01 ... 2.46652019e+02
  2.38939769e+00 2.32614623e+00]
 [7.76460323e+01 2.95130742e+02 1.62403995e+01 ... 8.02502464e+01
  7.62607005e-01 4.10076680e+00]
 [1.45940309e+02 6.04858010e+01 4.59380670e+01 ... 5.96137126e+01
  3.63200645e+00 3.63370689e-01]
 ...
 [1.08935616e+02 3.84102467e+01 7.24613398e+00 ... 9.39403623e+01
  4.76537582e+00 2.41210233e+00]
 [1.14949567e+01 2.13125563e+02 7.44350275e+01 ... 6.81051373e+00
  1.44097956e+02 3.18993334e+01]
 [4.91409088e+00 2.94770400e-01 1.41635452e+02 ... 1.54880173e+01
  2.01129052e+02 1.76129516e+02]]
[[1.42046241e+02 6.20816142e+01 5.92747778e+01 ... 2.42249662e+02
  4.02486782e+00 2.78866086e+00]
 [7.39488493e+01 2.97718446e+02 9.57823047e+00 ... 7.68583841e+01
  2.10824106e-01 2.43707337e+00]
 [1.43398136e+02 5.92380215e+01 4.54289660e+01 ... 5.33222433e+01
  4.41616427e+00 2.02976474e-02]
 ...
 [1.11805776e+02 3.97530318e+01 1.72046902e+01 ... 9.15866211e+01
  1.51656800e+00 5.16643820e-01]

In [8]:
# So for these r=3 definetly looks better...
print(np.abs(1e6*np.matmul(test_3[0][0],test_3[1][0])-V3))
print(np.abs(1e6*np.matmul(test_3[0][3],test_3[1][3])-V3))
# But the sum for r=6 is lower..?
print(np.sum(np.abs(1e6*np.matmul(test_3[0][0],test_3[1][0])-V3)))
print(np.sum(np.abs(1e6*np.matmul(test_3[0][3],test_3[1][3])-V3)))

[[8.74932347e-02 1.15486836e+02 2.63317809e+01 ... 6.65507897e+01
  4.19960444e+00 1.43177219e+02]
 [5.68506960e+00 2.97440972e+01 6.07459612e+01 ... 7.77442340e+01
  1.24609897e+02 1.10284586e+02]
 [1.89544334e+02 1.50262441e+02 1.71552322e+01 ... 2.58381280e+02
  1.63452964e+01 1.53543139e+02]
 ...
 [1.28973604e+02 7.26695183e+01 2.35817094e+01 ... 1.47177104e+01
  5.67470016e+01 2.80776154e+02]
 [6.55448760e+01 8.32769953e+01 4.35209957e+01 ... 1.07636921e+02
  1.88683603e+01 2.20886264e+02]
 [1.31260166e+01 6.22461211e+00 2.51058928e+01 ... 2.87346137e+01
  1.83258745e+01 3.75509813e+01]]
[[  4.68776044 105.62675421  24.38912042 ...  75.34191368  30.95062484
  131.82206335]
 [ 10.87081712  49.33639147  65.90392838 ...  67.81276738  88.90996274
   64.27319876]
 [184.60983925 132.6310775    9.16266774 ... 246.47319008   9.3522187
  200.79437298]
 ...
 [135.9384065   71.90094401  24.25513978 ...   7.03451562  25.40329429
  278.05667896]
 [ 49.5810995   49.61190395   8.57741209 ...  63

In [9]:
print(test[2].index(max(test[2])))

NameError: name 'test' is not defined

In [None]:
print(W_true.sum(axis=1))
print(test[0].sum(axis=1))
print(test[2])
# print(np.mean(test[1], axis=1))

In [None]:
print(H_true)
print(test[1])

In [None]:
a = 1e6*np.matmul(test[0], test[1])
b = V
print(a)
print(b)
print(a-b)

In [None]:
# broadcast testing
a = np.array([[1,2,3],[4,5,6]])
a.sum(axis=1)