Skip to content

Commit

Permalink
make the dim spec of logodds and related docs more rigorous in parall…
Browse files Browse the repository at this point in the history
…el computing scripts; this change does not affect existing results though
  • Loading branch information
xiangzhu committed May 10, 2018
1 parent 2a84e63 commit f5c8380
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src_vb/rss_varbvsr_bigmem_squarem.m
Expand Up @@ -6,7 +6,7 @@
% INPUT:
% file: the path of mat file that contains cell arrays of betahat, se and SiRiS, string
% sigb: the prior SD of the regression coefficients (if included), scalar
% logodds: the prior log-odds (i.e. log(prior PIP/(1-prior PIP))) of inclusion for each SNP, p by 1
% logodds: the prior log-odds (i.e. log(prior PIP/(1-prior PIP))) of inclusion for each SNP, p by 1 or scalar
% options: user-specified behaviour of the algorithm, structure
% - max_walltime: scalar, the maximum wall time (unit: seconds) for this program
% - tolerance: scalar, convergence tolerance
Expand Down Expand Up @@ -55,6 +55,11 @@
% Get the number of analyzed SNPs in the whole genome (p).
p = length(cell2mat(betahat));

% Convert logodds to a vector if the input is a scalar.
if isscalar(logodds)
logodds = repmat(logodds,p,1);
end

% Set initial estimates of variational parameters.
if isfield(options,'alpha')
alpha = double(options.alpha(:));
Expand Down Expand Up @@ -92,13 +97,15 @@
SiRiSr_cell = cell(C, 1);
q_cell = cell(C, 1);
sesquare_cell = cell(C, 1);
logodds_cell = cell(C, 1);
sigb_square = sigb * sigb;

for c = 1:C
chr_start = chrpar(c,1);
chr_end = chrpar(c,2);
alpha_cell{c,1} = alpha(chr_start:chr_end);
mu_cell{c,1} = mu(chr_start:chr_end);
logodds_cell{c, 1} = logodds(chr_start:chr_end);
end

% Compute a few useful quantities for the main loop.
Expand Down Expand Up @@ -126,7 +133,7 @@
parfor c = 1:C
r = alpha_cell{c,1} .* mu_cell{c,1};

lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds,alpha_cell{c,1});
lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds_cell{c,1},alpha_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) - 0.5*(1./sesquare_cell{c,1})'*betavar(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) + intklbeta_rssbvsr(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1},sigb_square);
end
Expand Down Expand Up @@ -188,11 +195,11 @@
end

% Run the first fixed-point mapping step (line 1 of Table 1).
[alpha_tmp1,mu_tmp1,SiRiSr_tmp1] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds,betahat{c,1},se{c,1}, ...
[alpha_tmp1,mu_tmp1,SiRiSr_tmp1] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds_cell{c,1},betahat{c,1},se{c,1}, ...
alpha0_cell{c,1},mu0_cell{c,1},SiRiSr0_cell{c,1},I);

% Run the second fixed-point mapping step (line 2 of Table 1).
[alpha_tmp2,mu_tmp2,SiRiSr_tmp2] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds,betahat{c,1},se{c,1}, ...
[alpha_tmp2,mu_tmp2,SiRiSr_tmp2] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds_cell{c,1},betahat{c,1},se{c,1}, ...
alpha_tmp1, mu_tmp1, SiRiSr_tmp1, I);

% Compute the step length (line 3-4 of Table 1).
Expand Down Expand Up @@ -258,7 +265,7 @@
mu_tmp3 = mu_tmp1_cell{c,1};
SiRiSr_tmp3 = SiRiSr_tmp1_cell{c,1};

[alpha_tmp,mu_tmp,SiRiSr_tmp] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds,betahat{c,1},se{c,1}, ...
[alpha_tmp,mu_tmp,SiRiSr_tmp] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds_cell{c,1},betahat{c,1},se{c,1}, ...
alpha_tmp3,mu_tmp3,SiRiSr_tmp3,I);
alpha_cell{c,1} = alpha_tmp;
mu_cell{c,1} = mu_tmp;
Expand All @@ -268,7 +275,7 @@
SiRiSr_cell{c,1} = SiRiSr_tmp;

% Compute the lower bound to the marginal log-likelihood of Chr. c.
lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds,alpha_cell{c,1});
lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds_cell{c,1},alpha_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) - 0.5*(1./sesquare_cell{c,1})'*betavar(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) + intklbeta_rssbvsr(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1},sigb_square);

Expand Down Expand Up @@ -308,14 +315,14 @@
mu_tmp3 = mu0_cell{c,1} - 2*mtp*mu_r_cell{c,1} + (mtp^2)*mu_v_cell{c,1};
SiRiSr_tmp3 = full(SiRiS_tmp * (alpha_tmp3 .* mu_tmp3));

[alpha_tmp,mu_tmp,SiRiSr_tmp] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds,betahat{c,1},se{c,1}, ...
[alpha_tmp,mu_tmp,SiRiSr_tmp] = rss_varbvsr_update(SiRiS_tmp,sigb,logodds_cell{c,1},betahat{c,1},se{c,1}, ...
alpha_tmp3, mu_tmp3, SiRiSr_tmp3, I);
alpha_cell{c,1} = alpha_tmp;
mu_cell{c,1} = mu_tmp;
r = alpha_tmp .* mu_tmp;
SiRiSr_cell{c,1} = SiRiSr_tmp;

lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds,alpha_cell{c,1});
lnZ_cell(c) = (q_cell{c,1})'*r - 0.5*r'*SiRiSr_cell{c,1} + intgamma(logodds_cell{c,1},alpha_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) - 0.5*(1./sesquare_cell{c,1})'*betavar(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1});
lnZ_cell(c) = lnZ_cell(c) + intklbeta_rssbvsr(alpha_cell{c,1},mu_cell{c,1},s_cell{c,1},sigb_square);
end
Expand Down

0 comments on commit f5c8380

Please sign in to comment.