# $$\textbf{Assigning loci to low-dimensional phenotypes using sparse regression} $$
The algorithm takes a matrix of additive effects, $F$, as input. Here $F \in \mathbb{R}^{E\times L}$, where $L$ is the
number of loci and $E$ is the number of environments. 


We'd like to split $F$ as $F = WM + b$, where $W \in \mathbb{R}^{E\times K}$, $M \in \mathbb{R}^{K\times L}$, $b \in \mathbb{R}^{1\times L}$  and $K$ is the number
of lower-dimensional phenotypes. The cost function approximately minimized here is 
\begin{align}
\mathcal{C}(W,M,b) = ||F - (WM + b)||^2_F + \lambda_W \rho_1(W) + \lambda_M \rho_2(M),
\end{align}
where $W$ and $M$ are regularized by $\rho_1$ and $\rho_2$ respectively. The regularizers have to be chosen
such that the symmetry $W \to WB, M \to B^{-1}M$ is broken for an arbitrary invertible matrix $B$. Otherwise the procedure will lead to multiple solutions. 
To optimize for $W$, $M$ and $b$, we 
use an alternating minimization algorithm, where we 1) optimize $W$  fixing $M$, 2) optimize $M,b$
fixing $W$, 3) repeat 1 and 2 until convergence. For each step, we use standard linear regression methods from SciPy.

# $$\textbf{Application on synthetic data (independent)} $$

The additive effects matrix, $F$, is computed from a generative model of $W$ and $M$. The elements of $W$ and $M$ are drawn i.i.d. and have probability $p$ of being non-zero. If non-zero, the values are drawn from a standard normal distribution. See the Methods section of the manuscript for more details.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import pickle
sys.path.insert(0, 'utils/')
from factorizer import *
import ssd
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import os

%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
# Create the folder figures/syn_ind if it does not exist.
# Download the pickled_factorizers folder from the repo.
# The pickled factorizers can also be regenerated locally (TODO: name the script that generates them).

# Load pickled data and set parameters used for decompositions

In [None]:
# Directory that receives every figure saved by this part of the notebook.
fig_loc = "figures/syn_ind"

# Create the figure directory together with any missing parent ("figures/").
# exist_ok=True makes re-running this cell a silent no-op, while real failures
# (e.g. permission errors) raise immediately instead of being swallowed.
os.makedirs(fig_loc, exist_ok=True)

mode = "syn_ind"
directory = "pickled_factorizers"

# NOTE(review): pickle.load can execute arbitrary code; only load files that
# come from a trusted source (e.g. the project's own repository).
# The context managers guarantee both file handles are closed after loading.
with open(f"{directory}/{mode}", "rb") as fcts_file:
    # Indexed by the cells below with keys of the form (name, i, j).
    fcts = pickle.load(fcts_file)
with open(f"{directory}/triples_{mode}", "rb") as triples_file:
    triples = pickle.load(triples_file)

In [None]:
# Regularization grids: 25 log-spaced values each, built from evenly spaced base-10 exponents.
n_lambda = 25
lamb2_exponents = np.linspace(np.log10(1e-3), np.log10(1.5), n_lambda)
lamb1_exponents = np.linspace(np.log10(1e-4), np.log10(1e-2), n_lambda)
lamb2_range = 10**lamb2_exponents
lamb1_range = 10**lamb1_exponents

# Single-value grids, used when only the other regularization strength is swept.
lamb1_fixed, lamb2_fixed = [1e-4], [1e-3]

svd_k = 6

# Values used to build the (mode, m, w, mws) keys of the stored factorizers.
# Presumably m / w are the non-zero probabilities of M / W -- TODO confirm.
m_sparse, m_not = 0.2, 1.0
w_sparse, w_not = .2, 1.0
mws = 0

## Plot rotation tests

In [None]:
# Axis limits and figure size shared by every rotation-test panel.
yrange, xrange, figsize = (0, .16), (1, 10), (3, 5)

for m in (m_sparse, m_not):
    for w in (w_sparse, w_not):
        print(f"\nm = {m}, w = {w}")
        name = (mode, m, w, mws)
        fct = fcts[(name, None, None)]
        K = fct.computed_params(printout=False)[0][1]
        print(K)

        # One row per rotation test:
        # (rotate flag, keys of the rotated factorizers, lamb1 grid, lamb2 grid, colour, save path)
        panels = (
            ("loci", [(name, None, i) for i in [3, 4, 5]], lamb1_fixed, lamb2_range,
             'tab:red', f"{fig_loc}/loci_rot_test_{m}_{w}.svg"),
            ("env", [(name, i, None) for i in [0, 1, 2]], lamb1_range, lamb2_fixed,
             'tab:blue', f"{fig_loc}/trait_rot_test_{m}_{w}.svg"),
        )
        for rotate, rot_keys, lamb1s, lamb2s, colour, svg_path in panels:
            fct_rots = [fcts[key] for key in rot_keys]
            # First pass: drawn with the function's default labels/legend, nothing saved.
            # Second pass: labels and legend switched off, written to svg_path.
            for save_name, extra in ((None, {}), (svg_path, {"labels": False, "legend": False})):
                plot_rotation_test_w_error(
                    fct, fct_rots, K, lamb1s, lamb2s, xrange, yrange, rotate,
                    fs=16, svd_k=K, true_line=True, save_name=save_name,
                    figsize=figsize, xticks=[3, 6, 9], oc=colour, rotc='dimgrey',
                    **extra,
                )



In [None]:
# Points to highlight, keyed by the (m, w) pair of each synthetic data set.
# Each value is a list of 2-tuples; the plotting cells pass the list (or its
# first entry) on to the plotting helpers.
circled_points = {
    (.2, .2): [(1.7, 1.6)],
    (.2, 1.0): [(1.7, 5.5)],
    (1.0, .2): [(5.3, 1.5)],
    (1.0, 1.0): [(5.3, 5.3)],
}

In [None]:


# Plot settings that are the same for every (m, w) combination.
scatter = True
ve = .25                                # margin added around the [1, 6] window
minx, maxx = 1 - ve, 6 + ve
miny, maxy = 1 - ve, 6 + ve
vmax, vmin = 0.11, 0.065                # presumably the colour-scale limits -- TODO confirm

for m in (m_sparse, m_not):
    for w in (w_sparse, w_not):
        print(f"\nm = {m}, w = {w}")
        name = (mode, m, w, mws)
        fct = fcts[(name, None, None)]
        K = fct.computed_params(printout=False)[0][1]
        print(K)

        # First pass: drawn with the function's default labels/legend, nothing saved.
        # Second pass: labels and legend switched off, written to the SVG file.
        variants = (
            (None, {}),
            (f"{fig_loc}/solution_space_{m}_{w}.svg", {"labels": False, "legend": False}),
        )
        for save_name, extra in variants:
            # The first six lamb2 values are left out of the scan (lamb2_range[6:]).
            plot_solution_space(
                fct, K, lamb1_range, lamb2_range[6:], minx, maxx, miny, maxy, vmax, vmin,
                scatter=scatter, scatter_restricted=True, restrict_in_range=False,
                k_labeled_points=None, circled_points=circled_points[(m, w)],
                save_name=save_name, **extra,
            )


In [None]:
# Parameters identifying the stored factorizers (re-declared so the cell is self-contained).
mws = 0
m_sparse, m_not = 0.2, 1.0
w_sparse, w_not = 0.2, 1.0
fs = 16
k = 6

# All four (m, w) combinations, in the same order as the nested loops elsewhere.
mw_pairs = [(m_val, w_val) for m_val in (m_sparse, m_not) for w_val in (w_sparse, w_not)]

for m, w in mw_pairs:
    print(f"\nm = {m}, w = {w}")
    name = (mode, m, w, mws)
    # Use the first highlighted point registered for this (m, w) pair.
    pt = circled_points[(m, w)][0]
    # Drawn once with default labels, then once without labels and saved to disk.
    bar_plot_M_W_error(fcts, name, pt, k,
                       colors=['tab:red', 'dimgrey', 'tab:blue', 'dimgrey'])
    bar_plot_M_W_error(fcts, name, pt, k,
                       colors=['tab:red', 'dimgrey', 'tab:blue', 'dimgrey'],
                       labels=False, save_name=f"{fig_loc}/bar_plot_{m}_{w}.svg")
    plt.show()

## The figures below are not for the paper; they serve as a sanity check and help explain the imperfect case of the SSD W

In [None]:
# Parameters identifying the stored factorizers (re-declared so the cell is self-contained).
mws = 0
m_sparse, m_not = 0.2, 1.0
w_sparse, w_not = 0.2, 1.0
fs = 16
k = 6

# All four (m, w) combinations, in the same order as the nested loops elsewhere.
mw_pairs = [(m_val, w_val) for m_val in (m_sparse, m_not) for w_val in (w_sparse, w_not)]

for m, w in mw_pairs:
    print(f"\nm = {m}, w = {w}")
    name = (mode, m, w, mws)
    # Use the first highlighted point registered for this (m, w) pair.
    pt = circled_points[(m, w)][0]
    compare_Ws(fcts, name, pt)
    plt.show()

## For different levels of sparsity

In [None]:
fig_loc = "figures/syn_ind"
directory = "pickled_factorizers"

# The sparsity-sweep results are stored under their own file name. A dedicated
# variable is used for it so that `mode` never holds the file name: the cells
# below build their lookup keys from `mode`, which must be "syn_ind".
pickle_name = "syn_ind_range"

# NOTE(review): pickle.load can execute arbitrary code; only load files that
# come from a trusted source (e.g. the project's own repository).
# The context managers guarantee both file handles are closed after loading.
with open(f"{directory}/{pickle_name}", "rb") as fcts_file:
    fcts = pickle.load(fcts_file)
with open(f"{directory}/triples_{pickle_name}", "rb") as triples_file:
    triples = pickle.load(triples_file)

# Regularization grids: 25 log-spaced values each, plus single-value grids
# used when only the other regularization strength is swept.
lamb2_range = 10**(np.linspace(np.log10(1e-3), np.log10(1.5), 25))
lamb1_range = 10**(np.linspace(np.log10(1e-4), np.log10(1e-2), 25))
lamb1_fixed = [1e-4]
lamb2_fixed = [1e-3]

svd_k = 6
mws = 0

# Mode string used inside the (mode, m, w, mws) keys of the loaded factorizers.
mode = "syn_ind"

In [None]:
# Axis limits and figure size shared by every panel of this sweep.
yrange, xrange, figsize = (0, .16), (1, 10), (3, 5)

# Sweep over m while w stays at 1.0.
for m in [0.2, 0.4, 0.6, 0.8]:
    for w in [1.0]:
        print(f"\nm = {m}, w = {w}")
        name = (mode, m, w, mws)
        fct = fcts[(name, None, None)]
        K = fct.computed_params(printout=False)[0][1]
        print(K)

        # Rotated factorizers are stored under third-slot indices 3, 4 and 5.
        fct_rots = [fcts[(name, None, i)] for i in [3, 4, 5]]
        rotate = "loci"

        # First pass: drawn with the function's default labels/legend, nothing saved.
        # Second pass: labels and legend switched off, written to the SVG file.
        variants = (
            (None, {}),
            (f"{fig_loc}/supp_loci_rot_test_{m}_{w}.svg", {"labels": False, "legend": False}),
        )
        for save_name, extra in variants:
            plot_rotation_test_w_error(
                fct, fct_rots, K, lamb1_fixed, lamb2_range, xrange, yrange, rotate,
                fs=16, svd_k=K, true_line=True, save_name=save_name,
                figsize=figsize, xticks=[3, 6, 9], oc='tab:red', rotc='dimgrey',
                **extra,
            )


In [None]:
# Axis limits and figure size shared by every panel of this sweep.
yrange, xrange, figsize = (0, .16), (1, 10), (3, 5)

# Sweep over w while m stays at 1.0.
for w in [0.2, 0.4, 0.6, 0.8]:
    for m in [1.0]:
        print(f"\nm = {m}, w = {w}")
        name = (mode, m, w, mws)
        fct = fcts[(name, None, None)]
        K = fct.computed_params(printout=False)[0][1]
        print(K)

        # Rotated factorizers are stored under second-slot indices 0, 1 and 2.
        fct_rots = [fcts[(name, i, None)] for i in [0, 1, 2]]
        rotate = "env"

        # First pass: drawn with the function's default labels/legend, nothing saved.
        # Second pass: labels and legend switched off, written to the SVG file.
        variants = (
            (None, {}),
            (f"{fig_loc}/supp_trait_rot_test_{m}_{w}.svg", {"labels": False, "legend": False}),
        )
        for save_name, extra in variants:
            plot_rotation_test_w_error(
                fct, fct_rots, K, lamb1_range, lamb2_fixed, xrange, yrange, rotate,
                fs=16, svd_k=K, true_line=True, save_name=save_name,
                figsize=figsize, xticks=[3, 6, 9], oc='tab:blue', rotc='dimgrey',
                **extra,
            )
