In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm as cm
from scipy.stats import binned_statistic
from scipy.stats import sem
import random
import copy
import warnings
import plot_utils as plu

## LD 
Throughout the notebook, LD is measured by $r^2$, the squared (Pearson) correlation coefficient between pairs of SNPs.

In [None]:
# Progress marker: announces the start of the LD section of the notebook.
print("* Computing and plotting LD...")

#### Compute correlation between all pairs of SNPs for each generated/real dataset

In [None]:
# categ = infiles.keys()
# hcor_snp = dict()
# for i,cat in enumerate(categ):
#     print(cat)
#     with np.errstate(divide='ignore', invalid='ignore'): 
#         # Catch warnings due to fixed sites in dataset (the correlation value will be np.nan for pairs involving these sites)
#         hcor_snp[cat] = np.corrcoef(datasets[cat], rowvar=False)**2  # r2

### Plot LD or COV binned by distance between SNPs 
- LD is binned by distance between SNPs into `nbins` bins
- Bin edges are spaced on a log scale if `logscale=True`
- For regression plots of non-binned LD, we randomly subsample `nsamplesets` SNP pairs for computation/visualization convenience


In [None]:
# _, region_len, snps_on_same_chrom = plu.get_dist(position_fname['Real'], region_len_only=True,  kept_preprocessing=keptsnpdic['Real'])

# nbins=50
# nsamplesets=10000
# logscale=True
# bins = nbins
# binsPerDist = nbins
# if logscale: binsPerDist = np.logspace(np.log(1), np.log(region_len), nbins)

In [None]:
# # Compute LD binned by distance
# # Take only sites that are SNPs in all datasets (intersect)
# # (eg intersection of SNPs in Real, SNPs in GAN, SNPs in RBM etc)
# # -> Makes sense only if there is a correspondence between sites

# if matching_SNPs:    
#     binnedLD = dict()
#     binnedPerDistLD = dict()
#     kept_snp = ~is_fixed
#     n_kept_snp = np.sum(kept_snp)
#     realdist = plu.get_dist(position_fname['Real'], kept_preprocessing=keptsnpdic['Real'], kept_snp=kept_snp)[0]
#     mat = hcor_snp['Real']
#     # filter and flatten
#     flatreal = (mat[np.ix_(kept_snp,kept_snp)])[np.triu_indices(n_kept_snp)]
#     isnanReal = np.isnan(flatreal)
#     i=1
#     realdist = plu.get_dist(position_fname['Real'], kept_preprocessing=keptsnpdic['Real'], kept_snp=kept_snp)[0]
#     hcor_snp_ref = hcor_snp.pop('Real')
#     fig, axs = plt.subplots(nrows=2, ncols=len(hcor_snp), 
#                         figsize = (len(hcor_snp)*5, 2*5), constrained_layout=True)
#     for ite, (cat, mat) in enumerate(hcor_snp.items()):
#         print(ite)
#         flathcor = (mat[np.ix_(kept_snp,kept_snp)])[np.triu_indices(n_kept_snp)]
#         isnan = np.isnan(flathcor)
#         curr_dist = plu.get_dist(position_fname[cat],  kept_preprocessing=keptsnpdic[cat], kept_snp=kept_snp)[0]

#         # For each dataset LD pairs are stratified by SNP distance and cut into 'nbins' bins
#         # bin per SNP distance
#         ld = binned_statistic(curr_dist[~isnan], flathcor[~isnan], statistic = 'mean', bins=binsPerDist)
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore", category=RuntimeWarning) # so that empty bins do not raise a warning
#             binnedPerDistLD[cat] = pd.DataFrame({'bin_edges':ld.bin_edges[:-1],
#                                           'LD': ld.statistic,
#                                           #'sd': binned_statistic(curr_dist[~isnan], flathcor[~isnan], statistic = 'std', bins=binsPerDist).statistic,
#                                           'sem': binned_statistic(curr_dist[~isnan], flathcor[~isnan], statistic = sem, bins=binsPerDist).statistic,
#                                           'cat': cat, 'logscale': logscale})

#         # For each dataset LD pairs are stratified by LD values in Real and cut into 'nbins' bins
#         # binnedLD contains the average, std of LD values in each bin
#         isnan = np.isnan(flathcor) |  np.isnan(flatreal) 
#         ld = binned_statistic(flatreal[~isnan], flathcor[~isnan], statistic = 'mean', bins=bins)
#         with warnings.catch_warnings():
#             warnings.simplefilter("ignore", category=RuntimeWarning) # so that empty bins do not raise a warning
#             binnedLD[cat] = pd.DataFrame({'bin_edges':ld.bin_edges[:-1],
#                                           'LD': ld.statistic,
#                                           'sd': binned_statistic(flatreal[~isnan], flathcor[~isnan], statistic = 'std', bins=bins).statistic,
#                                           'sem': binned_statistic(flatreal[~isnan], flathcor[~isnan], statistic = sem, bins=bins).statistic,
#                                           'cat': cat, 'logscale': logscale})

#         # Plotting quantiles ?
#         plu.plotregquant(x=flatreal, y=flathcor, 
#                          keys=['Truth',cat], statname='LD', col=colpal[cat], 
#                          step=0.0001,
#                          ax=axs[1][ite])
#         axs[1][ite].set_title(str(f'Quantiles LD {cat} vs Truth'))

#         if matching_SNPs:
#             # removing nan values and subsampling before doing the regression to have a reasonable number of points
#             isnanInter = isnanReal | isnan
#             keepforplotreg = random.sample(list(np.where(~isnanInter)[0]), nsamplesets)
#             plu.plotreg(x=flatreal[keepforplotreg], y=flathcor[keepforplotreg], 
#                         keys=['Truth',cat], statname='LD', col=colpal[cat], 
#                         ax=axs[0][ite])
#             i+=1
#             axs[0][ite].set_title(f'LD {cat} vs Truth')
#     plt.savefig(outDir + "LD_generated_vs_real_intersectSNP.pdf", dpi=300, bbox_inches='tight')

In [None]:
# # Plot LD as a function of binned distances
# # except when SNPs are spread across different chromosomes
# if matching_SNPs: #(position_fname['Real']!="1kg_real/805snps.legend"):
#     tmp_real = colpal.pop('Real')
#     print(colpal)
#     sns.set_palette(colpal.values())

#     sns.palplot(sns.color_palette())
#     plt.figure(figsize=(7,7))
#     # binnedPerDistLD.pop('Real')
#     print(binnedPerDistLD.keys())
#     for cat, bld in binnedPerDistLD.items():
#         plt.errorbar(bld.bin_edges.values, bld.LD.values, bld['sem'].values, label=cat, alpha=.65,linewidth=3)
#     plt.title("Binned LD +/- 1 sem")
#     if (logscale): plt.xscale('log')   
#     #plt.yscale('log')
#     plt.xlabel("Distance between SNPs (bp) [Left bound of distance bin]")
#     plt.ylabel("Average LD in bin")
#     plt.legend()
#     plt.savefig(outDir + "correlation_vs_dist_intersectSNP.pdf", dpi=300, bbox_inches='tight')
#     colpal['Real'] = tmp_real

#     # Zoomed-in plot for short distances (1 to 1000 bp)
#     plt.figure(figsize=(7,7))
#     for cat, bld in binnedPerDistLD.items():
#         plt.errorbar(bld.bin_edges.values, bld.LD.values, bld['sem'].values, label=cat, alpha=.65, linewidth=3)
#     plt.title("Binned LD +/- 1 sem (Zoom: 1–1000 bp)")
#     plt.xscale('log')
#     plt.xlim(1, 1000)  # Focus on 10^0 to 10^3
#     plt.xlabel("Distance between SNPs (bp) [Left bound of distance bin]")
#     plt.ylabel("Average LD in bin")
#     plt.legend()
#     plt.savefig(outDir + "correlation_vs_dist_zoomed_1_1000bp.pdf", dpi=300, bbox_inches='tight')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import binned_statistic, sem
import warnings
# Install a process-wide filter that hides every RuntimeWarning from here on
# (it is not scoped to this cell, so later cells are affected as well).
warnings.filterwarnings("ignore", category=RuntimeWarning)

def load_positions(legend_file):
    """Read a space-delimited legend file and return its ``position`` column.

    The file is parsed with ``pandas.read_csv`` (separator: a single space).
    The values found under the ``position`` header are cast to ``int`` and
    returned as a 1-D NumPy array.
    """
    legend = pd.read_csv(legend_file, sep=" ")
    position_column = legend["position"]
    return position_column.to_numpy().astype(int)

def compute_ld_r2(geno):
    """Return the matrix of squared pairwise correlations (r²) between SNPs.

    ``geno`` is handed to ``np.corrcoef`` with ``rowvar=False``, i.e. columns
    are the variables (SNPs) and rows the observations (samples).
    Divide-by-zero and invalid-value floating-point warnings are silenced
    while the matrix is computed, so entries that come out as NaN (e.g. for
    constant columns) do not trigger warnings.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = np.corrcoef(geno, rowvar=False)
        return np.square(corr)

def flatten_upper(mat):
    """Collect the entries strictly above the main diagonal of ``mat``.

    The diagonal itself is skipped (``k=1``).  Entries come back as a 1-D
    array ordered row by row, left to right.
    """
    size = mat.shape[0]
    row_idx, col_idx = np.triu_indices(size, k=1)
    return mat[row_idx, col_idx]

def compute_pairwise_distances(positions):
    """Absolute position difference for every unordered pair of SNPs.

    ``positions`` is a 1-D NumPy array.  Pairs ``(i, j)`` with ``i < j`` are
    enumerated in the order of ``np.triu_indices(n, k=1)``, i.e. the same
    order in which ``flatten_upper`` reads a matrix, so the result lines up
    element-wise with a flattened LD matrix.
    """
    n_sites = len(positions)
    first, second = np.triu_indices(n_sites, k=1)
    return np.abs(positions[first] - positions[second])

def compute_binned_ld(dists, ld, nbins=50, logscale=True):
    """Average LD within distance bins, together with its standard error.

    Parameters
    ----------
    dists : array-like
        Distance of each SNP pair, aligned element-wise with ``ld``.
    ld : array-like
        LD value of each SNP pair.  NaN entries are dropped (together with
        their distances) before binning.
    nbins : int, default 50
        Number of bins, i.e. number of rows in the result.
    logscale : bool, default True
        If True, cut ``nbins`` log-spaced bins spanning
        ``[max(1, smallest distance), largest distance]``; pairs closer than
        the lower edge (e.g. distance 0) fall outside every bin.  If False,
        ``binned_statistic`` cuts ``nbins`` equal-width bins itself.

    Returns
    -------
    pandas.DataFrame
        One row per bin with columns ``bin_left`` (left edge of the bin),
        ``LD_mean`` and ``LD_sem``.  The frame is empty when no pair has a
        non-NaN LD value.
    """
    dists = np.asarray(dists)
    ld = np.asarray(ld)

    valid = ~np.isnan(ld)
    kept_dists = dists[valid]
    kept_ld = ld[valid]

    if kept_dists.size == 0:
        # Nothing to bin: min()/max() below would raise on an empty array.
        empty = np.array([], dtype=float)
        return pd.DataFrame({"bin_left": empty, "LD_mean": empty, "LD_sem": empty})

    if logscale:
        lower = max(1, kept_dists.min())
        upper = kept_dists.max()
        # nbins bins require nbins + 1 edges.  geomspace (unlike logspace fed
        # with log10 values) returns the two end points exactly, so rounding
        # cannot push the largest distance out of the last bin.
        bins = np.geomspace(lower, upper, nbins + 1)
    else:
        bins = nbins

    ld_mean = binned_statistic(kept_dists, kept_ld, statistic="mean", bins=bins)
    ld_sem = binned_statistic(kept_dists, kept_ld, statistic=sem, bins=bins)

    return pd.DataFrame({
        "bin_left": ld_mean.bin_edges[:-1],
        "LD_mean": ld_mean.statistic,
        "LD_sem": ld_sem.statistic
    })

# --- small helper: safe color lookup ---
def get_color(name, colpal, default="#333333"):
    """Look up ``name`` in the ``colpal`` mapping; use ``default`` when it is absent."""
    if name in colpal:
        return colpal[name]
    return default

def plot_ld_decay_with_colors(binned, colpal, outfile, logscale=True, zoom=None, order=None):
    """Draw one LD-decay curve (mean ± SEM per distance bin) per dataset and save it.

    Parameters
    ----------
    binned : dict
        Maps a dataset label to a DataFrame exposing the columns
        ``bin_left``, ``LD_mean`` and ``LD_sem``.
    colpal : dict
        Maps a dataset label to its colour; labels without an entry get the
        default colour of ``get_color``.
    outfile : str
        Path the figure is written to (saved at 300 dpi).
    logscale : bool, default True
        Use a logarithmic x axis.
    zoom : tuple or None, default None
        If given (and non-empty), unpacked into the x-axis limits.
    order : iterable of str or None, default None
        Labels in drawing/legend order; labels missing from ``binned`` are
        skipped.  When omitted, the notebook-level ``plot_order`` variable is
        used if it exists (the behaviour this function had before the
        parameter was added), otherwise the iteration order of ``binned``.
    """
    if order is None:
        # Backward compatibility: this function used to read the global
        # `plot_order` directly, which made it fail with a NameError whenever
        # that variable had not been defined yet.
        order = globals().get("plot_order")
    if order is None:
        order = list(binned)

    fig, ax = plt.subplots(figsize=(7, 7))
    try:
        n_curves = 0
        for label in order:
            if label not in binned:
                continue
            df = binned[label]
            ax.errorbar(df.bin_left, df.LD_mean, df.LD_sem,
                        label=label, alpha=0.8, linewidth=3,
                        color=get_color(label, colpal))
            n_curves += 1

        if logscale:
            ax.set_xscale("log")
        if zoom:
            ax.set_xlim(*zoom)

        ax.set_xlabel("Distance between SNPs (bp)")
        ax.set_ylabel("Average LD (r²)")
        # ax.set_title("LD decay (+/- 1 SEM)")

        if n_curves:
            # Only build a legend when something was drawn, so matplotlib
            # does not warn about a legend without labelled artists.
            leg = ax.legend()
            # `legend_handles` is the current attribute name; probe it first
            # so the deprecated `legendHandles` alias is only touched on
            # matplotlib versions that lack the new one.
            handles = getattr(leg, "legend_handles", None)
            if handles is None:
                handles = getattr(leg, "legendHandles", [])
            for handle in handles:
                # Not every legend artist has a line width; skip those
                # instead of swallowing arbitrary exceptions.
                set_linewidth = getattr(handle, "set_linewidth", None)
                if callable(set_linewidth):
                    set_linewidth(3.0)

        fig.tight_layout()
        fig.savefig(outfile, dpi=300)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

# Step 0: read the SNP positions and keep only the sites not flagged as fixed
positions_all = load_positions(realposfname)
kept_snp_mask = ~is_fixed
positions = positions_all[kept_snp_mask]

print(f"Number of unfixed SNPs: {len(positions)}")

# Step 1: r² matrix of every dataset, restricted to the kept SNPs
ld_matrices = {
    name: compute_ld_r2(g[:, kept_snp_mask])
    for name, g in datasets.items()
}

# Step 2: upper triangle (diagonal excluded) of each r² matrix as a flat vector
flat_ld = {name: flatten_upper(matrix) for name, matrix in ld_matrices.items()}

# Step 3: distances between SNP pairs, in the same pair order as the flat LD vectors
pairwise_distances = compute_pairwise_distances(positions)

# Step 4: LD decay curve (binned mean and SEM) for every dataset
binned_ld = {
    name: compute_binned_ld(pairwise_distances, ld_values, nbins=20)
    for name, ld_values in flat_ld.items()
}

# "Real" stays available in binned_ld for reference but is not drawn
binned_ld_no_real = dict(binned_ld)
binned_ld_no_real.pop("Real", None)

# Deterministic plot order: follow infiles.keys() first, then add any
# remaining dataset that infiles does not list
plot_order = [k for k in infiles.keys() if k in binned_ld_no_real]
plot_order += [k for k in binned_ld_no_real if k not in plot_order]

# Full-range figure, then a zoom on distances from 1 to 1000 bp (Real excluded in both)
for pdf_name, x_limits in (
    ("LD_decay.pdf", None),
    ("LD_decay_zoom_1_1000bp.pdf", (1, 1000)),
):
    plot_ld_decay_with_colors(
        binned_ld_no_real,
        colpal,
        outfile=outDir + pdf_name,
        logscale=True,
        zoom=x_limits,
    )

In [None]:
# For each dataset LD pairs were stratified by LD values in Real, cut into nbins bins
# binnedLD contains the average LD in each bin
# Plot generated average LD as a function of the real average LD in the bins
# if matching_SNPs:
#     plt.figure(figsize=(10,10))
#     for cat, bld in binnedLD.items():
#         plt.errorbar(bld.bin_edges.values, bld.LD.values, bld['sem'].values, label=cat, alpha=1, marker='o')
#     plt.title("Binned LD +/- 1 sem")
#     #if (logscale): plt.xscale('log')
#     plt.xlabel("Bins (LD in Real)")
#     plt.ylabel("Average LD in bin")
#     plt.legend()
#     plt.savefig(outDir+'LD_{}bins_{}fixedremoved.pdf'.format(nbins,'logdist_' if logscale else ''))

#### Plotting LD (block) matrix
- LD is measured by $r^2$
- if mirror=True plot the symmetrical matrix for each category
- if mirror=False plot the Generated LD in upper triangle versus the Real LD in lower triangle
- if diff=True plot the difference between generated and true LD values, else plot regular LD values

In [None]:
# print("* Plotting LD block matrices...")

# # Set edges of the region for which to plot LD block matrix (l=0, f='end') for full region
# # not used for now, apart from in the output filename
# l_bound=None
# r_bound=None

In [None]:

# # mirror (bool): plot symmetrical matrix or generated vs real?
# # diff (bool): plot LD values or generated minus real ?

# # (mirror=False, diff=False) = upper triangle = generated LD, lower triangle = real LD, raw values from 0 to 1.

# hcor_snp['Real'] = hcor_snp_ref
# print(hcor_snp.keys())
# # for snpcode in ("fullSNP", "intersectSNP"):
# #     for (mirror, diff) in ((True,False),(True,True), (False,False)):
# hcor_snp.pop('WGAN')
# hcor_snp.pop('RBM')
# hcor_snp.pop('Indep')
# hcor_snp.pop('Markov')
# hcor_snp.pop('HMM')
# hcor_snp.pop('Truth')

# for snpcode in ("fullSNP", "intersectSNP"):
#     for (mirror, diff) in ((True,False),(True,True),(False,False)):
#         if (not matching_SNPs) and (snpcode=="fullSNP") and (diff or not mirror):
#             print(f'Warning: not plotting LD for {snpcode} mirror={mirror} diff={diff}',
#                   ' because SNP have no one-to-one correspondence and matrices might have different sizes')
#         elif (matching_SNPs) and (snpcode=="fullSNP"):
#             pass
#         else:
#             outfilename = f"LD_HEATMAP_{snpcode}_bounds={l_bound}-{r_bound}_mirror={mirror}_diff={diff}.pdf" 
#             #print(outfilename)
#             fig = plt.figure(figsize=(7,7))
#             plu.plotLDblock(hcor_snp, 
#                             left=l_bound, right=r_bound,  # None, None -> takes all SNPs
#                             mirror=mirror, diff=diff,
#                             is_fixed=is_fixed, is_fixed_dic=is_fixed_dic,
#                             suptitle_kws={'t':outfilename}
#                            )
#             plt.tight_layout()
#             plt.savefig(outDir+outfilename, dpi=300, bbox_inches='tight')
#             plt.show()


In [None]:
# Closing banner for the LD section; the placeholder is filled with `outDir`.
print('****************************************************************\n*** Computation and plotting LD DONE. Figures saved in {} ***\n****************************************************************'.format(outDir))