In [None]:
import os 
from numba import njit
import malariagen_data
import zarr
import allel
import numpy as np
import polars as pl
from tqdm.notebook import tqdm

from bokeh.io import output_notebook # enables plot interface in J notebook
output_notebook(hide_banner=True)

### IBD segment detection with unphased data

The method is loosely based on the following studies:   
https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0034267  
https://www.sciencedirect.com/science/article/pii/S0002929720300549  

Essentially, we take two individuals and scan across their genomes. We aim to identify segments in so-called half-IBD (IBD1), where one chromosome is shared, and full-IBD or IBD2, where both chromosomes are shared. 

To detect IBD1 segment breakpoints, we record the position of any SNP where the two individuals are both homozygous but for different alleles - they cannot share a chromosome at such a site, so it has to be an IBD1 breakpoint.

To detect IBD2 segment breakpoints, we record the position of any SNP where the two genotypes differ at all, i.e. we are looking for stretches where everything is identical.

We then apply a size filter to remove small segments which are identical by state but probably not identical by descent. At the moment, this threshold is somewhat arbitrary, and we need to develop sensible thresholds or decide whether we want to try to merge broken fragments. Although sequencing errors and mutations will currently break up IBD segments, for sibship inference at higher degrees of relatedness this will probably be OK.

In [None]:
def load_genotypes(self, contig, sample_sets, sample_query, site_mask='gamb_colu'):
    """
    Load SNP calls for one contig and discard sites that cannot inform the IBD scan.

    Invariant sites, and sites where any of the alternate alleles 1-3 is a
    singleton or doubleton, are thrown out, which massively speeds up the
    IBD scanning algorithm.

    Parameters
    ----------
    self : object
        Data-access object exposing ``snp_calls`` (called here with ``region``,
        ``sample_sets``, ``sample_query`` and ``site_mask``).
    contig : str
        Passed to ``snp_calls`` as ``region``.
    sample_sets, sample_query, site_mask
        Forwarded unchanged to ``snp_calls``.

    Returns
    -------
    geno : numpy.ndarray
        Genotype calls at the retained sites (variants on the first axis).
    pos : numpy.ndarray
        ``variant_position`` values at the retained sites.
    sample_ids : numpy.ndarray
        ``sample_id`` values of the loaded samples.
    """

    print("Loading genotypes")
    ds_snps = self.snp_calls(
        region=contig,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=site_mask
    )

    # Use `.data` rather than `.values`: `.values` materialises the complete
    # call set as a numpy array before any filtering, whereas `.data` hands the
    # underlying (possibly lazy) array straight to scikit-allel.
    geno = allel.GenotypeDaskArray(ds_snps['call_genotype'].data)
    sample_ids = ds_snps['sample_id'].values

    print("computing allele counts")
    ac = geno.count_alleles().compute()

    # A site is "rare" if any alternate allele is seen only once or twice.
    seg = ac.is_segregating()
    rare = np.zeros(seg.shape, dtype=bool)
    for allele in (1, 2, 3):
        rare |= ac.is_singleton(allele=allele)
        rare |= ac.is_doubleton(allele=allele)

    mask = np.logical_and(seg, ~rare)
    print(f"retaining {mask.sum()} variants and removing {(~mask).sum()}")

    geno = geno[mask].compute().values
    pos = ds_snps['variant_position'].values[mask]

    return geno, pos, sample_ids


@njit()
def ibdscans(geno, pos):
    """
    Scan a pair of individuals and return the indices of IBD1 breakpoints.

    A site is an IBD1 breakpoint when both individuals are homozygous, neither
    genotype contains a missing call (-1), and the two genotypes differ at
    every allele position.

    `geno` holds the two individuals on its second axis (index 0 and 1);
    `pos` is only used for its length, i.e. the number of sites to scan.
    The returned array contains site indices, not genomic positions.

    Note: IBD2 breakpoint detection (any allele differing) is not performed here.
    """
    n_sites = pos.shape[0]

    ibd1_breakpoints = []
    for site in range(n_sites):
        first = geno[site, 0]
        second = geno[site, 1]

        # Heterozygous calls can never rule out sharing one chromosome.
        if first[0] != first[1] or second[0] != second[1]:
            continue
        # Skip sites where either individual has a missing call.
        if -1 in first or -1 in second:
            continue
        # Both homozygous and called: a breakpoint only if the alleles differ.
        if (first != second).all():
            ibd1_breakpoints.append(site)

    return np.array(ibd1_breakpoints)

def detect_ibd_segments(self, contig, sample_sets, sample_query, cohort_id, out_dir, site_mask='gamb_colu', min_size=5000):
    """
    Run the IBD1 breakpoint scan on every pair of individuals for one contig.

    Genotypes are loaded with `load_genotypes`, then for each pair of sample
    indices the breakpoint indices from `ibdscans` are saved with `zarr.save`
    under ``{out_dir}/{cohort_id}/breakpoints/{contig}/{x}-{y}``. The retained
    variant positions (``.../POS``) and the sample ids (``sample_ids.zarr``)
    are saved alongside.

    `min_size` is accepted but not used by this function.

    Returns the number of individuals scanned.
    """
    from itertools import combinations

    geno, pos, sample_ids = load_genotypes(
        self=self,
        contig=contig,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=site_mask,
    )
    n_samples = geno.shape[1]

    # Shared metadata needed to turn breakpoint indices back into positions/ids.
    zarr.save(f"{out_dir}/{cohort_id}/breakpoints/{contig}/POS", pos)
    zarr.save(f"{out_dir}/{cohort_id}/sample_ids.zarr", sample_ids)

    sample_pairs = combinations(range(n_samples), 2)
    for x, y in tqdm(sample_pairs):
        pair_geno = geno.take([x, y], axis=1)
        pair_breakpoints = ibdscans(pair_geno, pos)
        zarr.save(f"{out_dir}/{cohort_id}/breakpoints/{contig}/{x}-{y}", ibd1=pair_breakpoints)

    return n_samples

In [None]:
def load_ibd_zarr_to_dataframe(self, contigs, out_dir, cohort_id, n_inds=20):
    """
    Load the per-pair breakpoint stores written by `detect_ibd_segments` into
    a single polars DataFrame of segments.

    Parameters
    ----------
    self : object
        Only consulted when `contigs` is None, in which case
        ``self.virtual_contigs + ('X',)`` is used (the previous hard-coded
        behaviour).
    contigs : iterable of str or None
        Contigs to load. These must be the contigs that were scanned.
    out_dir : str
        Root results directory; a trailing slash is optional.
    cohort_id : str
        Sub-directory of `out_dir` holding this cohort's results.
    n_inds : int
        Number of individuals that were scanned; all pairs of
        ``range(n_inds)`` are loaded.

    Returns
    -------
    polars.DataFrame
        Concatenation of the frames built by `breakpoints_to_polars_df`.
    """
    from itertools import combinations

    if contigs is None:
        contigs = self.virtual_contigs + ('X', )

    # Build paths exactly as the writer does, so that `out_dir` works with or
    # without a trailing slash.
    cohort_dir = f"{out_dir}/{cohort_id}"
    sample_ids = zarr.load(f"{cohort_dir}/sample_ids.zarr")

    dfs = []
    for contig in contigs:
        contig_dir = f"{cohort_dir}/breakpoints/{contig}"
        pos = zarr.load(f"{contig_dir}/POS")

        for x, y in tqdm(combinations(range(n_inds), 2)):
            arr = zarr.load(f"{contig_dir}/{x}-{y}")

            df1 = breakpoints_to_polars_df(array=arr['ibd1'], contig=contig, pos=pos, x=x, y=y, sample_ids=sample_ids)
            dfs.append(df1)

    return pl.concat(dfs)
        
def breakpoints_to_polars_df(array, pos, contig, x, y, sample_ids):
    """
    Convert the breakpoint indices of one pair on one contig into a table of
    segments.

    The first and last variant of the contig are treated as segment boundaries
    in addition to the breakpoints. Without this, the stretches before the
    first and after the last breakpoint are lost, and a pair with no
    breakpoints at all produces no segment.

    Parameters
    ----------
    array : array-like of int
        Indices into `pos` of the breakpoints.
    pos : array-like
        Positions of the scanned variants.
    contig : str
        Contig label stored in the ``contig`` column.
    x, y : int
        Indices of the two individuals; also used to look up `sample_ids`.
    sample_ids : sequence
        Sample identifiers, indexed by `x` and `y`.

    Returns
    -------
    polars.DataFrame
        One row per segment with columns ``size`` (difference of boundary
        positions), ``n_snps`` (difference of boundary indices), ``start``,
        ``end``, ``contig``, ``idx1``, ``idx2``, ``sample_id1``, ``sample_id2``.
    """
    breakpoints = np.asarray(array, dtype=np.int64)
    pos = np.asarray(pos)
    n_sites = pos.shape[0]

    if n_sites:
        # np.unique sorts and removes duplicates, so a breakpoint at the very
        # first or last variant does not create a zero-length segment.
        bounds = np.unique(np.concatenate((
            np.array([0], dtype=np.int64),
            breakpoints,
            np.array([n_sites - 1], dtype=np.int64),
        )))
    else:
        bounds = breakpoints

    pos_ibd = pos[bounds]
    n_snps = np.ediff1d(bounds)
    sizes = np.ediff1d(pos_ibd)
    df = pl.DataFrame({'size': sizes,
                       'n_snps': n_snps,
                       'start': pos_ibd[:-1],
                       'end': pos_ibd[1:]})

    df = df.with_columns(pl.lit(contig).alias("contig"),
                         pl.lit(x).alias("idx1"),
                         pl.lit(y).alias("idx2"),
                         pl.lit(sample_ids[x]).alias("sample_id1"),
                         pl.lit(sample_ids[y]).alias("sample_id2"))

    return df

def fraction_genome_ibd(self, ibd_df, cohort_id=None, min_ibd1_snps=10_000, out_dir=None, contigs=('2RL', '3RL', 'X')):
    """
    Compute, for every pair of individuals, the fraction of the genome covered
    by IBD1 segments.

    Parameters
    ----------
    self : object
        Data-access object exposing ``genome_sequence(contig)``; the summed
        lengths of `contigs` are used as the genome size.
    ibd_df : polars.DataFrame
        Segment table with at least ``idx1``, ``idx2``, ``size`` and ``n_snps``.
    cohort_id : str, optional
        Sub-directory of `out_dir` to write to. Required when `out_dir` is given.
    min_ibd1_snps : int
        Only segments with strictly more SNPs than this are counted.
    out_dir : str, optional
        If given, the result is also written to
        ``<out_dir>/<cohort_id>/ibd1_fraction.tsv`` (tab separated).
    contigs : iterable of str
        Contigs whose lengths make up the genome size.

    Returns
    -------
    polars.DataFrame
        Columns ``idx1``, ``idx2``, ``size`` (summed segment size) and
        ``ibd1_fraction``, sorted by pair. Pairs with no segment passing the
        filter do not appear.

    Raises
    ------
    ValueError
        If `out_dir` is given without `cohort_id`.
    """
    import os
    import polars as pl

    genome_size = sum(int(self.genome_sequence(contig).shape[0]) for contig in contigs)

    long_segments = ibd_df.filter(pl.col('n_snps') > min_ibd1_snps)

    # `groupby` was renamed to `group_by` in newer polars; support both.
    if hasattr(long_segments, "group_by"):
        grouped = long_segments.group_by(['idx1', 'idx2'])
    else:
        grouped = long_segments.groupby(['idx1', 'idx2'])

    df = grouped.agg(pl.sum("size")).with_columns(
        (pl.col("size") / genome_size).alias("ibd1_fraction")).sort(['idx1', 'idx2'])

    if out_dir:
        if cohort_id is None:
            raise ValueError("cohort_id is required when out_dir is given")
        # os.path.join copes with out_dir with or without a trailing slash.
        out_path = os.path.join(out_dir, cohort_id, "ibd1_fraction.tsv")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        df.write_csv(out_path, separator="\t")

    return df

In [None]:
def plot_ibd_segments_contig(
    self,
    df,
    contig,
    min_ibd1_snps=10_000,
    show=False,
    width=None,
    height=400
):
    """
    Draw the IBD1 segments of one contig as horizontal lines, one row (level)
    per pair of individuals.

    Parameters
    ----------
    self : object
        Data-access object exposing ``genome_sequence(contig)``, used for the
        contig length (x range and default width).
    df : polars.DataFrame
        Segment table with columns ``contig``, ``n_snps``, ``start``, ``end``,
        ``idx1``, ``idx2``, ``sample_id1`` and ``sample_id2``.
    contig : str
        Contig to plot.
    min_ibd1_snps : int
        Only segments with strictly more SNPs than this are drawn.
    show : bool
        If True, display the figure with ``bokeh.plotting.show``.
    width : int, optional
        Figure width; defaults to the contig length divided by 200,000.
    height : int
        Figure height. This argument was previously ignored (500 was
        hard-coded).

    Returns
    -------
    bokeh figure
    """
    from itertools import combinations
    import bokeh.plotting as bkplt
    import bokeh.models as bkmod

    df = df.filter(pl.col('contig') == contig).filter(pl.col('n_snps') > min_ibd1_snps)

    print("finding levels")
    # The number of individuals must cover both index columns: idx1 alone never
    # reaches the last individual, which would drop all of its pairs in the
    # join below. An empty frame yields no levels instead of crashing.
    idx_maxima = [m for m in (df['idx1'].max(), df['idx2'].max()) if m is not None]
    n_inds = int(max(idx_maxima)) + 1 if idx_maxima else 0
    pairs = list(combinations(range(n_inds), 2))
    level_df = pl.DataFrame({
        'idx1': [pair[0] for pair in pairs],
        'idx2': [pair[1] for pair in pairs],
        'level': np.linspace(0, 1, len(pairs)),
    }).with_columns(
        # Match the dtypes of the segment table so the join keys are compatible.
        pl.col('idx1').cast(df.schema['idx1']),
        pl.col('idx2').cast(df.schema['idx2']),
    )
    df = df.join(level_df, on=['idx1', 'idx2'])

    # One two-point line per segment: (start, level) -> (end, level).
    starts = df['start'].to_list()
    ends = df['end'].to_list()
    levels = df['level'].to_list()
    xs = [np.array([start, end]) for start, end in zip(starts, ends)]
    ys = [np.array([level, level]) for level in levels]

    source = bkmod.ColumnDataSource(data={
        'index1': df['idx1'].to_numpy(),
        'index2': df['idx2'].to_numpy(),
        'sample_id1': df['sample_id1'].to_numpy(),
        'sample_id2': df['sample_id2'].to_numpy(),
        'chromosome': df['contig'].to_numpy(),
        'start': df['start'].to_numpy(),
        'end': df['end'].to_numpy(),
        'xs': xs,
        'ys': ys,
        'y_pos': df['level'].to_numpy(),
    })

    hover = bkmod.HoverTool(tooltips=[
            ("index1", '@index1'),
            ("index2", '@index2'),
            ("sample_id1", '@sample_id1'),
            ("sample_id2", '@sample_id2'),
            ("segment span", "@start{,} - @end{,}"),
        ])

    print("making figure")
    contig_length = self.genome_sequence(contig).shape[0]
    if not width:
        width = int(contig_length / 200000)
    tools = ["tap", "box_zoom", "xpan", "xzoom_in", "xzoom_out", "xwheel_zoom", "reset", hover]
    fig1 = bkplt.figure(title=contig,
                        width=width,
                        height=height,
                        tools=tools,
                        toolbar_location='above', active_drag='xpan', active_scroll='xwheel_zoom')

    glyph = bkmod.MultiLine(xs='xs', ys='ys', line_color="grey", line_alpha=.8, line_width=2)
    fig1.add_glyph(source, glyph)

    fig1.x_range = bkmod.Range1d(0, contig_length, bounds='auto')
    fig1.y_range = bkmod.Range1d(0, 1, bounds='auto')
    fig1.x_range.max_interval = contig_length
    fig1.yaxis.visible = False
    fig1.ygrid.visible = False
    _bokeh_style_genome_xaxis(fig1, contig)

    if show:
        bkplt.show(fig1)

    return fig1

def plot_ibd_segments(
        self,
        df,
        out_dir,
        cohort_id,
        contigs=('2RL', '3RL', 'X'),
        min_ibd1_snps=10000,
        show=True,
        title=None,
    ):
    """
    Plot IBD1 segments for several contigs side by side and save them as HTML.

    One figure per contig is built with `plot_ibd_segments_contig`, the figures
    are laid out in a single row, and the layout is saved to
    ``<out_dir>/<cohort_id>_segments.html``.

    Parameters
    ----------
    self : object
        Forwarded to `plot_ibd_segments_contig`.
    df : polars.DataFrame
        Segment table, forwarded to `plot_ibd_segments_contig`.
    out_dir : str
        Directory for the HTML file; a trailing slash is optional.
    cohort_id : str
        Prefix of the HTML file name.
    contigs : iterable of str
        Contigs to plot, one grid column each.
    min_ibd1_snps : int
        Forwarded to `plot_ibd_segments_contig`.
    show : bool
        If True, also display the layout after saving.
    title : str, optional
        Passed to ``bokeh.plotting.output_file`` as the document title.
    """
    import os
    import bokeh.layouts as bklay
    import bokeh.plotting as bkplt

    figs = [
            plot_ibd_segments_contig(
                self=self,
                df=df,
                contig=contig,
                min_ibd1_snps=min_ibd1_snps,
                )
            for contig in tqdm(contigs)
            ]

    # Join path components instead of concatenating strings, so an out_dir
    # without a trailing slash still writes inside out_dir.
    out_path = os.path.join(out_dir, cohort_id + "_segments.html")
    bkplt.output_file(filename=out_path, title=title)

    fig = bklay.gridplot(
        figs,
        ncols=len(contigs),
        toolbar_location="above",
        merge_tools=True,
    )

    bkplt.save(fig)
    if show:
        bkplt.show(fig)
    
def _bokeh_style_genome_xaxis(fig, contig):
    import bokeh.models as bkmod
    """Standard styling for X axis of genome plots."""
    fig.xaxis.axis_label = f"Contig {contig} position (bp)"
    fig.xaxis.ticker = bkmod.AdaptiveTicker(min_interval=1)
    fig.xaxis.minor_tick_line_color = None
    fig.xaxis[0].formatter = bkmod.NumeralTickFormatter(format="0,0")

In [None]:
def detect_and_plot_ibd_segments(
    self,
    sample_sets,
    sample_query,
    cohort_id,
    site_mask,
    min_ibd1_snps,
    out_dir="../../results/ibd/",
    contigs=('2RL', '3RL', 'X'),
):
    """
    Run the whole IBD1 workflow for a cohort: detect breakpoints per contig,
    load them into a segment table, compute per-pair genome fractions and save
    the segment plot.

    Parameters
    ----------
    self : object
        Data-access object forwarded to every step.
    sample_sets, sample_query, site_mask
        Forwarded to `detect_ibd_segments`.
    cohort_id : str
        Name of the results sub-directory / file prefix.
    min_ibd1_snps : int
        Minimum-SNP filter forwarded to the fraction and plotting steps.
    out_dir : str
        Root results directory.
    contigs : str or iterable of str
        Contigs to analyse. The same contigs are used for detection, loading,
        the genome-size denominator and plotting.

    Returns
    -------
    polars.DataFrame
        The table returned by `fraction_genome_ibd`.

    Raises
    ------
    ValueError
        If `contigs` is empty.
    """
    # Accept a single contig name as well as an iterable of names.
    if isinstance(contigs, str):
        contigs = (contigs, )
    contigs = tuple(contigs)
    if not contigs:
        raise ValueError("contigs must contain at least one contig")

    # Detect IBD1 segments
    for contig in contigs:

        n_inds = detect_ibd_segments(
            self=self,
            contig=contig,
            sample_sets=sample_sets,
            sample_query=sample_query,
            cohort_id=cohort_id,
            out_dir=out_dir,
            site_mask=site_mask,
        )

    # Load IBD zarr data
    ibd_df = load_ibd_zarr_to_dataframe(
        self=self,
        contigs=contigs,
        cohort_id=cohort_id,
        n_inds=n_inds,
        out_dir=out_dir,
    )

    # Calculate genome fraction in IBD1 (over the contigs actually analysed)
    ibd_fractions_df = fraction_genome_ibd(
        self=self,
        ibd_df=ibd_df,
        min_ibd1_snps=min_ibd1_snps,
        out_dir=out_dir,
        cohort_id=cohort_id,
        contigs=contigs,
    )

    # Plot IBD segments
    plot_ibd_segments(
        self=self,
        df=ibd_df,
        min_ibd1_snps=min_ibd1_snps,
        show=False,
        out_dir=out_dir,
        cohort_id=cohort_id,
        contigs=contigs,
        )

    return ibd_fractions_df

In [None]:
# Data-access handle used as `self` by all the functions above.
# NOTE(review): `pre=True` presumably opts in to pre-release data - confirm.
ag3 = malariagen_data.Ag3(
    pre=True,
    results_cache="malariagen_data_cache",
    simple_cache={"cache_storage": "gcs_cache"},
)

#### Pre-processing

The functions are much, much quicker (roughly 10x) if we remove invariant sites, singletons and doubletons. Doing this reduces the size of the array from around 80 million SNPs to about 5 million.

We don't need singletons, as by definition they cannot break a scan which looks for positions where the two individuals are homozygous for different alleles. Similarly, doubletons could only break it if both copies of the doubleton allele are found in the same individual mosquito - very unlikely.

### Run algorithm on each contig
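This writes the breakpoints for each pair of individuals to zarr stores on disk, and the per-pair IBD1 fractions to a .tsv file (large cohorts use too much memory to hold everything in a pd.DataFrame).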

In [None]:
# Cohort definition
cohort_id = "gaard_tz_muleba"
sample_sets = "1246-VO-TZ-KABULA-VMF00185"
sample_query = "location == 'Muleba'"
site_mask = "arab"

# Contigs to analyse (`contig` is a single contig for ad-hoc use)
contigs = ("2RL", "3RL", "X")
contig = "X"

# Root directory for all IBD results
out_dir = "../../results/ibd/"

In [None]:
dataset = cohort_id

# Use the settings from the configuration cell instead of repeating the
# literals here, so the two cannot drift apart.
ibd_fraction_df = detect_and_plot_ibd_segments(
        self=ag3,
        sample_sets=sample_sets,
        sample_query=sample_query,
        cohort_id=cohort_id,
        min_ibd1_snps=1000,
        site_mask=site_mask,
        out_dir=out_dir,
        contigs=contigs,
    )

### Bokeh plots

In [None]:
# plot_ibd_segments(
#     ag3,
#     ibd1_df,
#     min_ibd1_snps=1000,
#     sort_by='size'
# )