In [None]:
import allel
import pandas as pd
import numpy as np
import plotly.express as px

In [None]:
# --- Input / output locations (relative to this notebook) ---
wkdir = "../.."
metadata_path = "../../config/metadata.tsv"
bed_targets_path = "../../config/ag-vampir.bed"
vcf_path = "../../results/vcfs/targets/ampseq-vigg-01.annot.vcf"

# --- QC thresholds ---
# Samples whose summed depth is below this value are excluded.
sample_total_reads_threshold = 250
# Target SNPs whose summed depth is below this value are reported as low depth.
amplicon_total_reads_threshold = 1000

# Sample quality control 

In this notebook, we perform quality control on samples, removing samples with very low depth, a large number of missing genotype calls, or elevated heterozygosity.

In [None]:
# Sample metadata table (tab-separated, with a header row).
metadata = pd.read_csv(metadata_path, sep="\t")

# Amplicon panel definition: a headerless tab-separated file, so column names are supplied here.
panel_columns = ['contig', 'start', 'end', 'amplicon', 'mutation']
panel_metadata = pd.read_csv(bed_targets_path, sep="\t", header=None, names=panel_columns)

# Read every available field from the VCF into a dict of arrays.
vcf = allel.read_vcf(vcf_path, fields='*')

# Pull out the arrays used for QC.
samples = vcf['samples']
geno = allel.GenotypeArray(vcf['calldata/GT'])
contigs, pos, ref, depth, qual = (
    vcf[f"variants/{field}"] for field in ("CHROM", "POS", "REF", "DP", "QUAL")
)

## Coverage data

In [None]:
coverage_columns = ['contig', 'start', 'end', 'amplicon', 'depth', 'sampleID']
autosome_arms = ['2L', '2R', '3L', '3R']

target_covs = []
x_ratios = []
for sample in metadata.sampleID:
    # Per-amplicon depth table for this sample; the sampleID column is
    # (over)written with the sample name so the tables can be stacked later.
    target_cov = pd.read_csv(
        f"{wkdir}/results/coverage/{sample}.regions.bed.gz",
        sep="\t",
        header=None,
        names=coverage_columns,
    ).assign(sampleID=sample)
    target_covs.append(target_cov)

    # x-autosome ratio: summed depth over the autosome arms divided by summed depth on X
    depth_by_contig = target_cov.groupby('contig')[['depth']].sum()
    autosome_depth = depth_by_contig.loc[autosome_arms].sum()
    x_depth = depth_by_contig.loc['X']
    x_ratios.append((autosome_depth / x_depth).iloc[0])

# Stack all samples and attach the panel annotation (mutation name) to each amplicon row.
target_cov_df = pd.concat(target_covs, axis=0).merge(
    panel_metadata,
    how='left',
    on=['contig', 'start', 'end', 'amplicon'],
)

# Total depth per sample across all amplicons.
sample_cov_df = target_cov_df.groupby('sampleID')['depth'].sum().reset_index()

fig = px.histogram(
    sample_cov_df,
    x='depth',
    nbins=500,
    template='simple_white',
    width=800,
    height=300,
    title='Histogram of total read counts per sample',
)
fig.show()

How many samples fall below the threshold for total reads?

In [None]:
# Samples whose total depth is strictly below the per-sample read threshold.
low_depth_samples_mask = sample_cov_df['depth'] < sample_total_reads_threshold
exclude_samples_depth = sample_cov_df.loc[low_depth_samples_mask, 'sampleID']
print(f"Removing {len(exclude_samples_depth)} samples due to low total depth")

#### Total reads per target SNP

In [None]:
# Total depth per target SNP (summed over all samples).
amplicon_cov_df = target_cov_df.groupby('mutation')['depth'].sum().reset_index()

fig = px.histogram(
    amplicon_cov_df,
    x='depth',
    nbins=200,
    color='mutation',
    template='simple_white',
    width=800,
    height=350,
    title='Histogram of total read counts per SNP target',
)
fig.show()

Which target SNPs have lower total depth than the amplicon threshold?

In [None]:
# Target SNPs whose total depth is strictly below the amplicon read threshold.
low_depth_targets_mask = amplicon_cov_df['depth'] < amplicon_total_reads_threshold
exclude_targets_depth = amplicon_cov_df.loc[low_depth_targets_mask, 'mutation']
print(f"Removing {len(exclude_targets_depth)} target SNPs due to low total depth")

exclude_targets_depth.to_frame()

### Number of missing calls

In [None]:
# Number of missing genotype calls per sample (summed over the variant axis).
missing_calls_per_sample = geno.is_missing().sum(axis=0)
exclude_samples_missing_calls = samples[missing_calls_per_sample > 40]
print(f"{len(exclude_samples_missing_calls)} samples have more than 40 missing calls overall out of all possible target SNPs")

# how many samples are shared between the missing-calls and low-depth exclusion lists
overlap = len(set(exclude_samples_missing_calls).intersection(exclude_samples_depth))

print(f"{overlap}/{len(exclude_samples_missing_calls)} of these are also present in the low depth samples to be excluded")

### Autosome / Sex chromosome coverage ratios

Females will have a lower autosome:X depth ratio, and males will have a higher one. It's not yet clear whether we can use this to sex samples.

In [None]:
# One autosome/X ratio per sample (same order as the metadata), skipping low-depth samples.
x_ratio_df = pd.DataFrame(
    {'sampleID': metadata.sampleID, 'x_ratio': x_ratios}
).query("sampleID not in @exclude_samples_depth")

fig = px.histogram(
    x_ratio_df,
    x='x_ratio',
    color='sampleID',
    template='simple_white',
    nbins=1000,
    width=800,
    height=300,
)
fig.update_xaxes(range=(0,20), title=dict(text='Autosome / X depth ratio'))
fig.show()

### Sample heterozygosity

In [None]:
def calc_heterozygosity(gt, gt_samples):
    """Compute the mean observed heterozygosity of every sample.

    Parameters
    ----------
    gt
        Genotype array whose second axis indexes samples; it is sliced as
        ``gt[:, [i], :]`` one sample at a time.
    gt_samples
        Sample identifiers, one per column of ``gt``; used as the ``sampleID`` values.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``sampleID`` with a single ``heterozygosity`` column holding
        the NaN-ignoring mean of ``allel.heterozygosity_observed`` for that sample.
    """
    from tqdm.notebook import tqdm

    n_samples = gt.shape[1]
    het_per_sample = []
    for sample_idx in tqdm(range(n_samples)):
        # Index with a one-element list so the sample axis is kept.
        single_sample_gt = gt[:, [sample_idx], :]
        het_per_sample.append(np.nanmean(allel.heterozygosity_observed(single_sample_gt)))

    het_table = pd.DataFrame({'sampleID': gt_samples, 'heterozygosity': het_per_sample})
    return het_table.set_index("sampleID")

# Per-sample heterozygosity, joined to the sample metadata (merge on shared columns).
het_df = calc_heterozygosity(gt=geno, gt_samples=samples).reset_index()
het_df = het_df.merge(metadata)

# Options shared by both heterozygosity figures.
het_plot_kwargs = dict(color='location', template='simple_white', height=400, width=900)

# One bar per sample.
fig = px.bar(
    het_df,
    x='sampleID',
    y='heterozygosity',
    title="Individual sample heterozygosity",
    **het_plot_kwargs,
)

# Distribution of heterozygosity values, coloured by location.
fig2 = px.histogram(
    het_df,
    x='heterozygosity',
    title="Histogram of sample heterozygosity",
    **het_plot_kwargs,
)

fig.show()
fig2.show()

#### Locate heterozygosity outliers

Within each cohort (sampling location), we then flag samples whose heterozygosity exceeds the 75% quantile by more than 2.5 × IQR, so that samples with very high heterozygosity for their cohort are excluded.

In [None]:
from scipy.stats import iqr

iqr_multiplier = 2.5 # determines how strict we are in throwing out outliers 

exclude_samples_heterozygosity = []
for coh in het_df.location.unique():
    cohort_df = het_df.query("location == @coh")
    hets = cohort_df.heterozygosity

    # Upper fence: 75% quantile plus a multiple of the IQR (NaNs ignored in both).
    upper_quartile = np.nanquantile(hets, 0.75)
    threshold = upper_quartile + (iqr_multiplier * iqr(hets, nan_policy='omit'))

    # Extending with an empty list is a no-op, so no explicit "any outliers?" guard is needed.
    is_outlier = hets > threshold
    exclude_samples_heterozygosity.extend(cohort_df.loc[is_outlier, 'sampleID'].to_list())

    print(f"For {coh} the heterozygosity threshold is {np.round(threshold, 3)}, out of {len(hets)} samples, {is_outlier.sum()} are outliers")

print(f"\nRemoving {len(exclude_samples_heterozygosity)} samples in total due to high heterozygosity")

### Summary of samples to exclude

In [None]:
# Union of all samples failing any QC filter (low depth, high heterozygosity, many missing calls).
exclude_samples = np.unique(
    exclude_samples_depth.to_list()
    + exclude_samples_heterozygosity
    + list(exclude_samples_missing_calls)
)

# Number of removed samples per location. The index and value names are set
# explicitly because the labels produced by value_counts() differ between
# pandas versions; relying on the defaults breaks the table on older pandas.
removed_counts = (
    metadata.query("sampleID in @exclude_samples")
    .location.value_counts()
    .rename_axis('location')
    .rename('count')
)

# Append a grand-total row and present as a one-column table indexed by location.
total_row = pd.Series({'total': removed_counts.sum()}, name='count')
removed_metadata = (
    pd.concat([removed_counts, total_row])
    .rename_axis('location')
    .to_frame('count')
)

removed_metadata.reset_index()

In [None]:
# Keep only the samples that passed every QC filter.
new_metadata = metadata.query("sampleID not in @exclude_samples")

# index=False: the row index is only the positional index assigned when the
# metadata was read, so writing it would add a spurious unnamed leading column
# that the original metadata file does not have.
new_metadata.to_csv(f"{wkdir}/results/config/metadata.qcpass.tsv", sep="\t", index=False)

####  Sample QC complete!
A new metadata file with low-quality samples removed has been written to results/config/ :)

In [None]:
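### Variant Hardy-Weinberg equilibrium
# NOTE: everything below in this cell is a disabled, experimental draft (fully commented out);
# none of it is executed as part of sample QC.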
# from itertools import combinations

# possible_alleles = [[0,0], [0,1], [1,1], [-1,-1], [1,2]]
# possible_alleles = np.unique(np.array(list(combinations(np.repeat([-1,0,1,2,3], 2), 2))), axis=0)

# from collections import defaultdict
# from tqdm.notebook import tqdm

# def count_alleles_to_df(geno, pos, contig):
#     from collections import defaultdict
#     from tqdm.notebook import tqdm

#     assert geno.shape[0] == pos.shape[0]

#     di = {}
#     for i, p in tqdm(enumerate(pos)):
#         counter = defaultdict(int)

#         for allele in possible_alleles:
#             allele_str = '/'.join(allele.astype(str))
#             for idx in range(geno.shape[1]):
#                 if all(geno[i, idx] == allele):
#                     counter[allele_str] += 1
#                 else:
#                     counter[allele_str] += 0
#                 di[f"{contig[i]}:{p}"] = counter
                
#     return pd.DataFrame(di).reset_index().rename(columns={'index':'genotype'})

# geno_count_df = count_alleles_to_df(geno=geno, pos=pos, contig=contigs)
# geno_count_df = geno_count_df.query("~genotype.str.contains('-1')")

# df = geno_count_df.set_index('genotype')

# import snphwe

# snphwe.snphwe(gn_counts[1], gn_counts[0],  gn_counts[2])

# not bulletproof - takes 3 most common counts for a given snp
# if quite multiallelic things probably go wrong 
# for var in df.columns:
#     allele_idxs = np.argpartition(df[var], -3)[-3:]
#     gn_counts = df[var].iloc[allele_idxs].sort_index()
    
#     if (gn_counts != 0).sum() == 1:
#         res = 'NaN'
#     else:
#         res = snphwe.snphwe(gn_counts[1], gn_counts[0],  gn_counts[2])
#     print(gn_counts, res, "\n")