In [1]:
import numpy as np
from scipy import linalg, stats
from sklearn import linear_model
from sklearn import decomposition
import time

In [2]:
class dataset:
    def __init__(self, n, m, k, h2):
        self.n = n 
        self.m = m 
        self.k = k 
        self.h2 = h2

        self.simGeno()
        self.simEffects()
        self.simPheno()

    def simGeno(self):
        geno = np.zeros([self.n, self.m])
        for i in range(self.m):
            p = np.random.beta(2, 2)
            snps = np.random.binomial(2, p, self.n)
            geno.T[i] = snps - (2*p)/np.sqrt(2*p*(1-p))

        inter = linalg.khatri_rao(geno.T, geno.T).T
        self.geno, self.inter = geno, inter

    def simEffects(self):
        u = np.random.normal(0, 1, self.m * self.k).reshape(self.m, -1) 
        v = np.random.normal(0, 1, self.m * self.k).reshape(self.m, -1) 
        omega = u @ v.T + v @ u.T 

        self.uv = np.column_stack((u, v)) 
        self.vu = np.column_stack((v, u)) 

        self.u = u 
        self.v = v 
        self.omegaMat = omega
        self.omega = omega.reshape(1, -1)[0]
        
    
    def simPheno(self):
        mean = np.matmul(self.inter, self.omega)
        var = np.var(mean) * (1 - self.h2)/self.h2
        sd = np.sqrt(var)
        noise = np.random.normal(0, sd, self.n)
        self.pheno = mean + noise

In [3]:
class decomp:
    def __init__(self):
        self.data = None
    def simData(self, n, m, k, h2):
        self.data = dataset(n, m, k, h2)
    def fitMarginal(self):
        lm = linear_model.LinearRegression(fit_intercept=True)
        omegaHat = list()
        for i in range(self.data.m ** 2):
            lm.fit(self.data.inter.T[i].reshape(-1, 1), self.data.pheno)
            omegaHat.append(lm.coef_[0])

        omegaHat = np.array(omegaHat)
        omegaHatMat = omegaHat.reshape(self.data.m, -1)

        self.omegaHat, self.omegaHatMat = omegaHat, omegaHatMat

    def fitRidge(self):
        lm = linear_model.Ridge()
        lm.fit(self.data.inter, self.data.pheno)
        omegaHat = lm.coef_
        omegaHatMat = omegaHat.reshape(self.data.m, -1)

        self.omegaHat, self.omegaHatMat = omegaHat, omegaHatMat



    def fitSVD(self):
        rank = self.data.k * 2
        lm = linear_model.LinearRegression(fit_intercept=True)

        A, singular, B = linalg.svd(self.omegaHatMat)

        A = A[:, :rank]
        singular = singular[:rank]
        B = B[:, :rank]


        lm.fit(self.data.uv, A)
        transform = lm.coef_.T
        
        self.u = (A @ linalg.inv(transform))[:,:self.data.k]
        self.v = (A @ linalg.inv(transform))[:,self.data.k:]

    def symmetricDecomp(self):
        m = self.data.m
        
        thresh = 0.0001
                
        u = np.random.rand(m, 1)
        v = np.random.rand(m, 1)
        
        uprev = np.copy(u)
        vprev = np.copy(v)
        
        i = 0
        
        while(True):
            if i % 2 == 0:
                u = linalg.inv(v @ v.T + (v.T @ v)*np.eye(m)) @ self.omegaHatMat @ v
            else:
                v = linalg.inv(u @ u.T + (u.T @ u)*np.eye(m)) @ self.omegaHatMat @ u
            
            udiff = linalg.norm(u - uprev)
            vdiff = linalg.norm(v - vprev)
            
            if udiff < thresh and vdiff < thresh:
                break
            else:
                uprev = np.copy(u)
                vprev = np.copy(v)
            i += 1

        self.u = u
        self.v = v        

    def evalNorm(self):
        
        trueU = self.data.u.reshape(1, -1)[0]
        trueV = self.data.v.reshape(1, -1)[0]
        
        estU = self.u.reshape(1, -1)[0]
        estV = self.v.reshape(1, -1)[0]
        
        original = ("original", 
                    stats.pearsonr(estU, trueU)[0] ** 2,
                    stats.pearsonr(estV, trueV)[0] ** 2)
        switched = ("switched", 
                    stats.pearsonr(estV, trueU)[0] ** 2, 
                    stats.pearsonr(estU, trueV)[0] ** 2)
        
        
        if original[1] + original[2] > switched[1] + switched[2]:
            return original
        else: 
            return switched


In [None]:
data = list()
for m in range(10, 40 + 1, 10):
    print(m)
    for n in range(10, 60 + 1, 5):
        
        mfloat = m/10
        nfloat = n/10
        
        benchmark = decomp()
        decompAcc = list()
        decompTime = list()
        benchmark.simData(round(10 ** nfloat), round(10 ** mfloat), 1, 0.9)
        benchmark.fitRidge()
        
        for i in range(10):
            start = time.time()
            benchmark.symmetricDecomp()
            end = time.time()
            _, ufit, vfit = benchmark.evalNorm()
            decompTime.append(end - start)
            decompAcc.append(ufit)
    
        data.append((m, n, np.mean(decompAcc), np.mean(decompTime)))

10
20


In [None]:
data[:5]

In [8]:
dataList = list()
for i in range(100):
    test = decomp()
    test.simData(1000, 10, 1, 0.9)
    test.fitMarginal()
    original = test.data.omegaMat
    
    for j in range(5):
        test.symmetricDecomp()
        _, ufit, vfit = test.evalNorm()
        reconstructed = test.u @ test.v.T + test.v @ test.u.T
        
        dataList.append((i, ufit, vfit, linalg.norm(reconstructed - original)))

In [9]:
dataList

[(0, 0.8732005336007931, 0.8150387562359245, 7.748754665245238),
 (0, 0.8731658270991238, 0.8149815657998305, 7.748966011320285),
 (0, -0.873196914254487, -0.8150347692542804, 7.7487689949954826),
 (0, 0.8731800455705931, 0.8150161874143138, 7.748835756453184),
 (0, 0.8731745376493459, 0.8149922155962024, 7.748927461573315),
 (1, -0.7972197642866155, -0.4561050909791109, 51.47689072799281),
 (1, -0.7972171330225017, -0.4561027512274884, 51.47688900761164),
 (1, -0.7972015926630538, -0.45608567849467413, 51.47687836741048),
 (1, -0.7972271099653014, -0.4561130514968764, 51.47689583602731),
 (1, -0.7972204632813988, -0.45610584265720955, 51.476891206709624),
 (2, -0.6385180303472898, 0.23488951831502972, 184.47263049929168),
 (2, -0.6385179555117847, 0.23488966580769216, 184.47263048255675),
 (2, -0.6385196312137488, 0.23488668394698745, 184.47263073858556),
 (2, -0.6385174617333502, 0.23489052295478557, 184.47263041182848),
 (2, -0.6385268029118406, 0.23487398757071085, 184.472631711010

In [118]:
import csv

with open('./symdecomp.csv','w') as out:
    csv_out=csv.writer(out)
    csv_out.writerow(['iteration','ufit', 'vfit', 'error'])
    for row in dataList:
        csv_out.writerow(row)
