In [1]:
import os

import allel
from numba import njit
import malariagen_data
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

### In this notebook I will develop a numba-based method to detect IBD segments from unphased data.

The method is loosely based on the following studies:   
https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0034267  
https://www.sciencedirect.com/science/article/pii/S0002929720300549  

Essentially, we take two individuals and scan across their genomes. We aim to identify segments in so-called half-IBD (IBD1), where one chromosome is shared, and full-IBD or IBD2, where both chromosomes are shared. 

To detect IBD1 segment breakpoints, we record the position of any SNP where both individuals are homozygous but for different alleles. At such a site the pair cannot share any chromosome, so it has to be an IBD1 breakpoint.

To detect IBD2 segment breakpoints, we record the position of any SNP where the two genotypes differ, i.e. we are looking for stretches where the genotypes are identical at every site.

We then apply a size filter to remove small segments which are identical by state but probably not identical by descent. 

In [2]:
@njit()
def ibdscans(geno, pos):
    """
    Scan a pair of diploid genotype vectors and record IBD breakpoint sites.

    Parameters
    ----------
    geno : integer array indexed as geno[site, sample, allele]
        Genotype calls for exactly two samples (sample axis 0 and 1), two
        alleles per call. A value of -1 marks a missing allele call.
    pos : 1-D array
        Site positions; only its length is used, to bound the scan.

    Returns
    -------
    ibd1_breakpoints, ibd2_breakpoints : arrays of site indices
        Indices (into `pos` / the first axis of `geno`) of sites where
        * IBD1: both samples are homozygous, but for different alleles;
        * IBD2: the two genotypes differ in any way.
        Every IBD1 breakpoint is also an IBD2 breakpoint.
    """
    ibd1_breakpoints = []
    ibd2_breakpoints = []

    for pos_idx in range(pos.shape[0]):
        a0 = geno[pos_idx, 0, 0]
        a1 = geno[pos_idx, 0, 1]
        b0 = geno[pos_idx, 1, 0]
        b1 = geno[pos_idx, 1, 1]

        # Sites with a missing call in either sample carry no information.
        if a0 == -1 or a1 == -1 or b0 == -1 or b1 == -1:
            continue

        # The data are unphased, so allele order within a call is arbitrary:
        # 0/1 and 1/0 are the same genotype and must not count as a difference.
        same_genotype = (a0 == b0 and a1 == b1) or (a0 == b1 and a1 == b0)
        if same_genotype:
            continue

        ibd2_breakpoints.append(pos_idx)

        # Both homozygous and (from the check above) different => no allele
        # can be shared at this site, so it also breaks IBD1.
        if a0 == a1 and b0 == b1:
            ibd1_breakpoints.append(pos_idx)

    return np.array(ibd1_breakpoints), np.array(ibd2_breakpoints)

def _segments_between_breakpoints(breakpoints, pos, min_size):
    """Turn breakpoint site indices into a size-filtered DataFrame of candidate IBD segments."""
    n_sites = pos.shape[0]
    if n_sites == 0:
        boundaries = np.empty(0, dtype=np.intp)
    else:
        # Bound the scan with the first and last site so that the flanking
        # stretches, and pairs with zero or one breakpoint (i.e. sharing across
        # the whole contig), still produce segments. np.unique de-duplicates
        # when a flank site is itself a breakpoint.
        boundaries = np.unique(
            np.concatenate(([0], np.asarray(breakpoints, dtype=np.intp), [n_sites - 1]))
        )

    boundary_pos = pos[boundaries]
    segments = pd.DataFrame({'size': np.ediff1d(boundary_pos),
                             'n_snps': np.ediff1d(boundaries),
                             'start': boundary_pos[:-1],
                             'end': boundary_pos[1:]})
    return segments[segments['size'] > min_size]


def detect_ibd_segments(contig, sample_sets=None, sample_query=None, prefix="", site_mask='gamb_colu', min_size=5000):
    """
    Load genotype data, run the IBD scan on each pair of individuals and write
    the resulting segments to .tsv files.

    For every pair of samples, segments longer than `min_size` (in position
    units, `end - start`) are appended to "{prefix}.{contig}.ibd1.tsv" and
    "{prefix}.{contig}.ibd2.tsv". Each row holds size, n_snps (difference
    between the bounding site indices), start, end, contig, the two sample
    indices (index1, index2) and the two sample ids (sample1, sample2).

    Note: output is written in append mode and the header is only emitted when
    the file does not exist yet, so re-running for the same prefix/contig adds
    duplicate rows unless the old files are removed first.

    Parameters
    ----------
    contig, sample_sets, sample_query, site_mask
        Passed through to `load_genotypes`.
    prefix : str
        Prefix of the output file names.
    min_size : int
        Segments with size <= min_size are discarded.
    """
    from itertools import combinations

    geno, pos, sample_ids = load_genotypes(
        contig=contig,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=site_mask
    )

    n_samples = geno.shape[1]
    # combinations() has no length, so give tqdm the pair count explicitly.
    n_pairs = n_samples * (n_samples - 1) // 2
    out_paths = {
        'ibd1': f"{prefix}.{contig}.ibd1.tsv",
        'ibd2': f"{prefix}.{contig}.ibd2.tsv",
    }

    for x, y in tqdm(combinations(range(n_samples), 2), total=n_pairs):
        ibd1, ibd2 = ibdscans(geno.take([x, y], axis=1), pos)

        for kind, breakpoints in (('ibd1', ibd1), ('ibd2', ibd2)):
            path = out_paths[kind]
            segments_df = _segments_between_breakpoints(breakpoints, pos, min_size)
            segments_df.assign(
                contig=contig,
                index1=x,
                index2=y,
                sample1=sample_ids[x],
                sample2=sample_ids[y],
            ).to_csv(path, sep="\t", mode='a', header=not os.path.exists(path))

        
def load_genotypes(contig, sample_sets, sample_query, site_mask='gamb_colu'):
    """
    Load SNP calls and throw out invariant, singleton and doubleton sites,
    which massively speeds up the IBD scanning algorithm.

    Relies on the module-level `ag3` client being defined before it is called.

    Parameters
    ----------
    contig : str
        Region passed to `ag3.snp_calls`.
    sample_sets, sample_query, site_mask
        Passed through to `ag3.snp_calls`.

    Returns
    -------
    geno : numpy array
        Genotype calls at the retained sites.
    pos : numpy array
        Positions of the retained sites.
    sample_ids : numpy array
        Values of the `sample_id` variable of the SNP calls dataset.
    """

    print("Loading genotypes")
    ds_snps = ag3.snp_calls(region=contig, sample_sets=sample_sets, sample_query=sample_query, site_mask=site_mask)
    # Use .data (the underlying lazy array) rather than .values, which would
    # load every genotype call into memory before any filtering happens.
    geno = allel.GenotypeDaskArray(ds_snps['call_genotype'].data)
    sample_ids = ds_snps['sample_id'].values
    print("computing allele counts")
    ac = geno.count_alleles().compute()

    # Flag sites where any non-reference allele is a singleton or doubleton.
    # Iterate over the columns actually present instead of assuming alleles 1-3.
    rare = np.zeros(ac.shape[0], dtype=bool)
    for allele in range(1, ac.shape[1]):
        rare = np.logical_or(rare, ac.is_singleton(allele=allele))
        rare = np.logical_or(rare, ac.is_doubleton(allele=allele))

    mask = np.logical_and(ac.is_segregating(), ~rare)
    print(f"retaining {mask.sum()} variants and removing {(~mask).sum()}")

    geno = geno[mask].compute().values
    pos = ds_snps['variant_position'].values[mask]

    return geno, pos, sample_ids


In [3]:
# Optional Google Colab setup, left disabled: mounts Google Drive and creates
# a directory for results.
# from google.colab import drive
# import os

# drive.mount("drive")
# results_dir = "drive/MyDrive/Colab Data/ibd-detection"
# os.makedirs(results_dir, exist_ok=True)

# Shared malariagen_data client. load_genotypes() and plot_ibd_segments() read
# this module-level `ag3` name, so this cell must run before they are called.
# NOTE(review): `pre=True` presumably enables pre-release data and
# `results_cache` presumably names a local cache directory - confirm against
# the malariagen_data documentation.
ag3 = malariagen_data.Ag3(pre=True, results_cache="malariagen_data_cache")

#### Pre-processing

The functions are much quicker (roughly 10x) if we remove invariant sites, singletons and doubletons. Doing this reduces the size of the array from around 80 million SNPs to about 5 million.

We don't need singletons, as by definition they cannot break a scan which looks for positions where both individuals are homozygous for different alleles. Similarly, a doubleton could only break it if both copies of the doubleton allele are found in the same individual mosquito - very unlikely.

### Run algorithm on each contig
Results are written to a .tsv file per contig (large cohorts use too much memory to store as a pd.DataFrame).

In [None]:
# Run the pairwise IBD scan on every contig in ag3.virtual_contigs plus 'X'.
# Each call writes "obuasi.<contig>.ibd1.tsv" and "obuasi.<contig>.ibd2.tsv".
# NOTE: detect_ibd_segments appends to existing files, so delete old outputs
# before re-running this cell or the rows will be duplicated.
for contig in ag3.virtual_contigs + ('X',):
    detect_ibd_segments(
        contig=contig, 
        sample_sets='AG1000G-BF-A', 
        sample_query="taxon == 'gambiae'",
        site_mask='gamb_colu', 
        prefix="obuasi"
    )

Loading genotypes


### Summarising functions (WIP)

In [None]:
def summarise_ibd_data(prefix='coluzzii', ibd1_min_size=10000, ibd2_min_size=100):
    """
    Summarise the per-contig IBD segment files into per-pair totals.

    Reads "{prefix}.{contig}.ibd1.tsv" and "{prefix}.{contig}.ibd2.tsv" for each
    contig in ag3.virtual_contigs plus 'X', keeps rows whose `n_snps` exceeds
    `ibd1_min_size` / `ibd2_min_size` respectively, and sums `size` and
    `n_snps` per (index1, index2) pair.

    Writes three files to the working directory:
    * "ibd1.per_contig.tsv" and "ibd2.per_contig.tsv": per-pair, per-contig sums
    * "ibd.summary.tsv": per-pair total size and that size divided by the
      summed length of all contigs, for IBD1 and IBD2 side by side

    Returns the DataFrame written to "ibd.summary.tsv".
    """
    from dask import dataframe as dd
    import malariagen_data

    ag3 = malariagen_data.Ag3("gs://vo_agam_release/", pre=True)
    contigs = ag3.virtual_contigs + ('X',)
    genome_size = np.sum([ag3.genome_sequence(contig).shape[0] for contig in contigs])

    pair_keys = ['index1', 'index2']
    per_contig_totals = {'size': 'sum', 'n_snps': 'sum'}

    ibd1_list, ibd2_list = [], []
    for contig in contigs:
        print(f"reading ibd1 {contig}...")
        ibd1_df = dd.read_csv(f"{prefix}.{contig}.ibd1.tsv", sep="\t").query(f"n_snps > {ibd1_min_size}")
        print(f"reading ibd2 {contig}...")
        ibd2_df = dd.read_csv(f"{prefix}.{contig}.ibd2.tsv", sep="\t").query(f"n_snps > {ibd2_min_size}")

        # Lazy per-pair sums for this contig, tagged with the contig name.
        ibd1_list.append(ibd1_df.groupby(pair_keys).agg(per_contig_totals).assign(contig=contig))
        ibd2_list.append(ibd2_df.groupby(pair_keys).agg(per_contig_totals).assign(contig=contig))

    print("computing...")
    ibd1_stats = dd.concat(ibd1_list).compute()
    ibd2_stats = dd.concat(ibd2_list).compute()

    ibd1_stats.to_csv("ibd1.per_contig.tsv", sep="\t")
    ibd2_stats.to_csv("ibd2.per_contig.tsv", sep="\t")

    # Collapse across contigs and express each pair's total as a genome fraction.
    ibd1_stats = (
        ibd1_stats.reset_index()
        .rename(columns={'size': 'ibd1_size'})
        .groupby(pair_keys)
        .agg({'ibd1_size': 'sum'})
        .assign(ibd1_fraction=lambda x: x['ibd1_size'] / genome_size)
    )
    ibd2_stats = (
        ibd2_stats.reset_index()
        .rename(columns={'size': 'ibd2_size'})
        .groupby(pair_keys)
        .agg({'ibd2_size': 'sum'})
        .assign(ibd2_fraction=lambda x: x['ibd2_size'] / genome_size)
    )
    ibd_stats = pd.concat([ibd1_stats, ibd2_stats], axis=1)

    ibd_stats.to_csv("ibd.summary.tsv", sep="\t")
    return ibd_stats

In [None]:
# summarise_ibd_data returns a single per-pair summary DataFrame (IBD1 and IBD2
# columns side by side), so it must not be unpacked into two names.
ibd_stats = summarise_ibd_data(prefix='obuasi')

### Plotting function (WIP)

TODO: extend this into a version which plots all chromosomes together.

In [None]:
def plot_ibd_segments(contig, ibd1_df, ibd2_df, title=None, figsize=(12,2)):
    """
    Plot IBD segments along one contig, one horizontal track per sample pair.

    IBD1 segments are drawn in grey and IBD2 segments are drawn on top of them
    in the default line colour. Relies on the module-level `ag3` client (for
    the contig length) and on `format_func` (for the x tick labels).

    Parameters
    ----------
    contig : str
        Contig whose length sets the x axis.
    ibd1_df, ibd2_df : pandas.DataFrame
        Segment tables with columns index1, index2, start and end. They are not
        filtered by contig here, so pass rows for `contig` only.
    title : str, optional
        Axes title.
    figsize : tuple
        Figure size passed to matplotlib.
    """
    from matplotlib.lines import Line2D
    import matplotlib.pyplot as plt
    from matplotlib import ticker

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    contig_len = ag3.genome_sequence(contig).shape[0]

    # One track per distinct (index1, index2) pair seen in either table.
    pair_cols = ['index1', 'index2']
    pairs = pd.concat([ibd1_df[pair_cols], ibd2_df[pair_cols]]).drop_duplicates()
    y_pos = np.linspace(0.05, 1, len(pairs))

    ax.set_xlim(0, contig_len)
    # Leave headroom so the top track is not drawn on the axes edge.
    ax.set_ylim(0, 1.05)
    ax.set_yticks([])
    print("plotting now")

    def _draw_segments(segments_df, x, y, line_height, **line_kwargs):
        # Draw every segment of pair (x, y) as a horizontal line at line_height.
        pair_rows = segments_df[(segments_df['index1'] == x) & (segments_df['index2'] == y)]
        for start, end in zip(pair_rows['start'], pair_rows['end']):
            ax.add_line(Line2D(xdata=(start, end), ydata=(line_height, line_height), **line_kwargs))

    for line_height, (x, y) in zip(y_pos, pairs.itertuples(index=False, name=None)):
        _draw_segments(ibd1_df, x, y, line_height, color='grey')
        # IBD2 is drawn once per pair, after IBD1, so it sits on top.
        _draw_segments(ibd2_df, x, y, line_height)

    if title:
        ax.set_title(title)
    ax.set_xticks(np.arange(0, contig_len, 10e6))
    # Use the FuncFormatter class to set the tick labels
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_func))

    # Hide the top, right and left spines (seaborn is not imported in this notebook).
    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)
    fig.show()


# Define a function to format the tick labels
def format_func(value, tick_number):
    """
    Tick formatter: render a genomic coordinate in base pairs as megabases.

    `value` is the tick location; `tick_number` is required by matplotlib's
    FuncFormatter signature but is not used. Returns e.g. "1.5 Mb".
    """
    mb_value = value / 1_000_000.0
    return f"{mb_value:.1f} Mb"