In [1]:
import os
import pysam
import pickle as pkl
import numpy as np
import pandas as pd
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

from scipy.stats import norm
# import statsmodels.api as sm

from sklearn.pipeline import Pipeline

import ipywidgets as widgets
from ipywidgets import interact
from IPython import embed

import re

In [2]:
from genome_helpers import (
#     # compute_af,
#     inverse_rank_normalization,
    read_annotation_data,
    process_plant_phenotype,
#     get_genome_metadata
)

from gwas_helpers import (
    adj_phenotypes_for_gwas,
    annotate_results,
    run_gwas,
    qqplot,
    manhattan_static,
    manhattan_interactive
)

Para obtener la frecuencia de los alelos alternativo, usamos el campo `DP4` del archivo VCF, con las siguientes componentes:

| Campo   | Descripción                                      |
|---------|------------------------------------------------|
| DP4[0]  | Reads forward para el alelo de referencia.    |
| DP4[1]  | Reads reverse para el alelo de referencia.    |
| DP4[2]  | Reads forward para el alelo alternativo.      |
| DP4[3]  | Reads reverse para el alelo alternativo.      |


In [3]:
def compute_af(record):
    dp4 = record.info['DP4']
    if sum(dp4) == 0:
        return None
    allele_freq = (dp4[2]+dp4[3]) / sum(dp4)
    return allele_freq


def get_dp4(record):
    
    dp4 = record.info['DP4']
    if sum(dp4) == 0:
        return None
    
    return dp4


def get_freqs_from_vcf(vcf_file):
    
    bgi_id = os.path.basename(vcf_file).replace(".vcf.gz", "")

    try:
        vcf_file = pysam.VariantFile(vcf_file)
    except Exception as e:
        print(e)
        return {bgi_id: None}
        
    variant_freq   = {(record.chrom, record.pos): compute_af(record) for record in vcf_file.fetch()}
    variant_counts = {(record.chrom, record.pos): get_dp4(record)    for record in vcf_file.fetch()}

    # variant_dict.update({sample_info: set(variant_positions)})
    # variant_freq_dict.update({bgi_id: variant_freq})

    return {bgi_id: variant_freq}, {bgi_id: variant_counts}


def get_sample_info(sample_id, metadata):

    sample_info = metadata.loc[metadata.BGI_ID == sample_id, ["generation", "rep", "treatment"]]    

    if len(sample_info) == 1:
        sample_info = sample_info.iloc[0].to_list()
        sample_info = tuple(sample_info)
    elif len(sample_info) == 0:
        print(ValueError(f"Sample with {sample_id} has no corresponding metadata.")) 
        return None
    else:
        raise(ValueError(f"Sample with {sample_id} has more than one associated sample."))

    return sample_info


def process_vcf_folder(vcf_folder, cache_file="vcf_dictionary.pkl", as_dataframe=False):

    """
    
    """
    
    files = [ file for file in os.listdir(vcf_folder) if file.endswith(".vcf.gz") or file.endswith(".vcf") ]

    variant_freq_dicts = {}
    variant_count_dicts = {}
    
    for file in tqdm(files):
           
        sample_id = file.replace(".vcf.gz", "")
        # sample_info = get_sample_info(sample_id, genome_metadata)        
        variant_freq_dict, variant_count_dict = get_freqs_from_vcf(f"{vcf_folder}/{file}")

        variant_freq_dicts.update(variant_freq_dict)
        variant_count_dicts.update(variant_count_dict)
                
    if as_dataframe:
        return pd.DataFrame(variant_freq_dicts), pd.DataFrame(variant_count_dicts)
    else:
        return variant_freq_dicts, variant_count_dicts

In [4]:
def _process_bases(row):
    
    base_calls = row["bases"]
    quality_scores = phred_quality(row["qual"])
    ref_base = row["ref"].upper()

    base_calls = base_calls.replace("$", "").replace("^]", "").replace("^I", "")
    
    processed_bases = []
    passing_quality_scores = []
    
    i = 0
    while i < row["depth"]:
        char = base_calls[i]

        # Filtrar por calidad (descartar bases con calidad < 20)
        if quality_scores[i] < 20:
            i += 1
            continue

        passing_quality_scores.append(quality_scores[i])
        
        # Reemplazar referencia
        if char in ".,":  
            processed_bases.append(ref_base)

        # Contar bases normales
        elif char.upper() in "ACTG":
            processed_bases.append(char.upper())

        # Contar deleciones en la referencia (`*`)
        elif char == "*":
            processed_bases.append("D")

        # Detectar inserciones (`+nX`)
        elif char == "+":
            match = re.match(r"\+(\d+)", base_calls[i:])
            if match:
                num_bases = int(match.group(1))
                inserted_seq = base_calls[i+len(match.group(1))+1:i+len(match.group(1))+1+num_bases]
                # processed_bases.append(f"INS_{inserted_seq.upper()}")
                processed_bases.append("I")
                i += len(match.group(1)) + num_bases

        # Detectar deleciones (`-nX`)
        elif char == "-":
            match = re.match(r"\-(\d+)", base_calls[i:])
            if match:
                num_bases = int(match.group(1))
                deleted_seq = base_calls[i+len(match.group(1))+1:i+len(match.group(1))+1+num_bases]
                processed_bases.append("D")
                # processed_bases.append(f"DEL_{deleted_seq.upper()}")
                i += len(match.group(1)) + num_bases

        i += 1

    if len(passing_quality_scores) == 0:
         allele_counts = { 'A': 0, 'C': 0, 'T': 0, 'G': 0, 'DEL': 0, 'INS': 0, "Avg_Qual": None, "depth": 0, "depth_high_q": 0 }
         return allele_counts

    allele_counts = {
        'A': processed_bases.count('A'),
        'C': processed_bases.count('C'),
        'T': processed_bases.count('T'),
        'G': processed_bases.count('G'),
        'DEL': processed_bases.count('D'),
        'INS': processed_bases.count('I')
    }

    # Contar inserciones y deleciones específicas
    for item in processed_bases:
        if item.startswith("INS_"):
            allele_counts[item] = allele_counts.get(item, 0) + 1
        if item.startswith("DEL_"):
            allele_counts[item] = allele_counts.get(item, 0) + 1

    # Calcular calidad promedio
    # allele_counts["Avg_Qual"] = np.mean(quality_scores) if quality_scores else 0
    allele_counts["Avg_Qual"] = sum(passing_quality_scores) / len(passing_quality_scores) # if quality_scores else 0
    allele_counts["depth"] = row["depth"]
    allele_counts["depth_high_q"] = len(passing_quality_scores)
    allele_counts["bases"] = processed_bases
    allele_counts["original_bases"] = row["bases"]
    allele_counts["ref_base"] = ref_base
    allele_counts["quality_scores"] = quality_scores
    
    return allele_counts


def mismatch(row):
    alleles = [ row[x] for x in ['A', 'C', 'T', 'G', 'DEL', 'INS'] ] #, row.C, row.T, row.G]#, row.DEL, row.INS]
    # print(sum(alleles) != row.depth_high_q)
    return sum(alleles) != row.depth_high_q

In [6]:
phenotypes = ["PR_Length", "LR_number", "LR_Density"]
phenotypes_df, _, _, phenotypes_df_red_nb = process_plant_phenotype()
phenotypes_df = adj_phenotypes_for_gwas(phenotypes_df, phenotypes_df_red_nb, phenotypes)

- **Input**: matriz con cuentas para cada alelo, donde `NaN` significa que el SNV no tiene _variaciones o cobertura_.
- Ver cobertura para SNVs que no aparecen en los VCFs.
- Filtrar variantes.
- filas:(planta,SNV),columnas:None -> filas:planta,columnas:SNV.

In [10]:
EXPERIMENTS = [(1, 'K'), (1, 'MS'), (2, 'MS'), (2, 'K'), (3, 'K'), (3, 'MS')]

In [11]:
def filter_variants(freq_df, freq_threshold=0.98, non_missing=450):
    
    freq_df = freq_df[ freq_df.apply(lambda row: ~(row.dropna() > freq_threshold).all(), axis=1) ]
    freq_df = freq_df[ freq_df.apply(lambda row: row.isna().sum(), axis=1) < non_missing ]
    
    return freq_df

In [12]:
def display_df(df, n=5, disable=False, text="", info=True):
    
    if not disable:        
        if text: print(text)
        display(df.sample(n))
        print(f"{df.shape=}")
        if info: print(df.info())
        print("-"*100)

    return df

In [None]:
VERBOSE = False
VCF_DIR = "data/genomes/alignments_paired_end_new/"

# VCF_DIR = "./data/genomes/alignments_paired_end/"

batch_mapping = get_genome_metadata(as_dataframe=False)
freq_df, counts_df = process_vcf_folder(vcf_folder=VCF_DIR, as_dataframe=True)
depth_df  = counts_df.map(lambda x: sum(x) if isinstance(x, tuple) else x)

freq_df = (
    filter_variants(freq_df)
    .pipe(display_df, text="Pre-variant filtering", disable=not VERBOSE)
    .pipe(filter_variants)
    .pipe(display_df, text="Filtered variants", disable=not VERBOSE)
    .melt(ignore_index=False)
    .pipe(display_df, text="Reshaped", disable=not VERBOSE)
    .reset_index()
    .pipe(display_df, text="Reset index", disable=not VERBOSE)
    .rename({"level_0": "contig", "level_1": "position", "value": "freq"}, axis=1)
    .pipe(display_df, text="Renamed columns", disable=not VERBOSE)
)

# We map the batch to the sample so that we can pair it up with the phenotype values from the other file
samples = freq_df.variable.apply(lambda x: batch_mapping.get(x, (None, None, -1)))
samples_df = pd.DataFrame(samples.to_list(), columns=["treatment", "replica", "generation"])

freq_df = pd.concat([ freq_df, samples_df ], axis=1)
freq_df = freq_df.query("generation != -1")
freq_df.replica = freq_df.replica.astype(int)
freq_df = freq_df[~freq_df.freq.isna()]

  0%|          | 0/492 [00:00<?, ?it/s]

In [15]:
filtered_df = freq_df.pivot(
    columns=['replica', 'treatment'], 
    index=['contig', 'position', 'generation'], 
    values='freq'
).reset_index()

variant_names = pd.Series(zip(filtered_df['contig'], filtered_df['position']))
generations = filtered_df.generation

allele_freqs = filtered_df.loc[:, [1, 2, 3]]

data = [ variant_names, generations, allele_freqs ]
data = pd.concat(data, axis=1).sort_values("generation")
data.columns = ['Variant', 'Generation'] + data.columns[2:].to_list()

# Number of generations where a given SNV is present
filtered_variants = data.groupby('Variant')['Generation'].nunique()

# Keep only those variants that are present in at least 10 generations
filtered_variants = filtered_variants[filtered_variants >= 10].index
filtered_data = data.set_index("Variant").loc[filtered_variants].reset_index()

TOP_N_VARIANTS = 200
changing_variants = filtered_data.groupby("Variant")[[(1, 'MS')]].std().iloc[:,0].sort_values(ascending=False)[:TOP_N_VARIANTS].index
top_changing_variants_df = filtered_data.set_index("Variant").loc[changing_variants]
freq_data = top_changing_variants_df.reset_index()

variants_lst    = [list(x) for x in top_changing_variants_df.index.unique()]
display_options = [f'{item[0]} - {item[1]}' for item in sorted(variants_lst)]
value_dict      = {f'{item[0]} - {item[1]}': item for item in sorted(variants_lst)}

genotype_data = freq_data.melt(
    id_vars=["Variant", "Generation"], 
    value_vars=EXPERIMENTS, 
    var_name="replica", value_name="af"
)

genotype_data["id"] = genotype_data.apply(lambda x: (x.replica[1], x.replica[0], x.Generation), axis=1)
genotype_data = genotype_data.drop(["Generation", "replica"], axis=1)
genotype_data_wide = genotype_data.pivot(index="id", columns="Variant")
# genotype_data_wide[genotype_data_wide.isna()] = 0
genotype_data_wide = genotype_data_wide.reset_index()
genotype_data_wide.id = genotype_data_wide.id.apply(lambda x: (x[0], int(x[1]), int(x[2])))
genotype_data_wide = genotype_data_wide.set_index("id")

# Remove rows for which all values are NaN, which means that the VCF was not present.
genotype_data_wide = genotype_data_wide[~genotype_data_wide.isna().all(axis=1)]

common_rows = genotype_data_wide.reset_index().id.isin(set(phenotypes_df.index))

In [16]:
# genotype_data_wide_K = genotype_data_wide[['K' in x for x in genotype_data_wide.index]]

In [17]:
assert isinstance(genotype_data_wide.columns[0], tuple) and isinstance(genotype_data_wide.columns[0][1], tuple), "It seems like this has already been run"

genotype_data_wide = (genotype_data_wide
  .reset_index()
  .loc[common_rows]
  .set_index('id')
  .pipe(lambda df: df.set_axis([df.columns[i][1] for i in range(TOP_N_VARIANTS)], axis=1) )
)

genotype_data_wide

Unnamed: 0_level_0,"(contig000001, 146)","(contig000001, 227)","(contig000001, 5306)","(contig000001, 91539)","(contig000001, 91557)","(contig000001, 91560)","(contig000001, 91603)","(contig000001, 91654)","(contig000001, 91737)","(contig000001, 91827)",...,"(contig000029, 128)","(contig000029, 302)","(contig000029, 320)","(contig000029, 329)","(contig000029, 344)","(contig000029, 695)","(contig000032, 8535)","(contig000038, 58)","(contig000038, 105)","(contig000038, 155)"
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(K, 1, 1)",,0.819048,,,,,,,,,...,,,,,,0.994898,,,0.423077,
"(K, 1, 4)",0.985915,1.000000,1.000000,,,,,,,,...,,,,,,1.000000,,,0.900000,
"(K, 1, 5)",,,1.000000,,,,,,,,...,,,,,,1.000000,0.870813,,0.606061,
"(K, 1, 6)",0.941860,0.981982,0.770833,,,,,,,0.713483,...,,,,,,0.998450,,,0.722222,
"(K, 1, 7)",1.000000,1.000000,1.000000,,,,,,,,...,,,,,,1.000000,,,,0.962963
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(MS, 3, 78)",0.869565,0.878788,,0.794326,0.758065,0.756972,0.755396,,,,...,,,,,,1.000000,,,0.480000,
"(MS, 3, 79)",1.000000,1.000000,1.000000,,,,,,,,...,,,,,,1.000000,,,0.986111,
"(MS, 3, 80)",,0.873846,1.000000,,,,,,,,...,,,,,,0.998273,,,0.514286,
"(MS, 3, 81)",0.871134,0.959381,,,,,,,,,...,,,,,,0.997468,,,0.611940,


In [56]:
all_data = pd.merge(phenotypes_df, genotype_data_wide, left_index=True, right_index=True)
# all_data

### GWAS: one plant, one data point

In [19]:
snps = all_data.columns[3:]

In [20]:
gwas_results = run_gwas(all_data, snps, phenotypes)
gwas_results.shape

(600, 7)

In [55]:
# qqplot(gwas_results.query("phenotype == 'PR_Length'").p_value)
# qqplot(gwas_results.query("phenotype == 'LR_number'").p_value)
# qqplot(gwas_results.query("phenotype == 'LR_Density'").p_value)

In [33]:
@interact
def select_association(index=widgets.IntSlider(min=0,max=100)):
    
    SNP       = gwas_results.SNP.iloc[index]
    phenotype = gwas_results.phenotype.iloc[index]
    p_value   =  gwas_results.p_value.iloc[index]
    
    plt.figure(figsize=(15, 5))
    plt.scatter(all_data[[SNP]], all_data[[phenotype]]);
    plt.title(f"{SNP}\n{phenotype} ({p_value:.1e})");

interactive(children=(IntSlider(value=0, description='index'), Output()), _dom_classes=('widget-interact',))

___

In [23]:
# gff = "/home/rodrigo/01_repos/plant-microbiota-interaction/data/genomes/reference_2/annotations/annotations2.gff3"

In [39]:
gff = "/home/rodrigo/01_repos/plant-microbiota-interaction/data/genomes/reference/annotations_translated_from_tg1e1.gff"
gff_data = read_annotation_data(gff)

          seqid   source  type  start   end score strand phase  \
0  contig000001  Liftoff  gene   2802  3266     .      -     .   
1  contig000001  Liftoff   CDS   2802  3266     .      -     .   
2  contig000001  Liftoff  gene   3483  3905     .      +     .   
3  contig000001  Liftoff   CDS   3483  3905     .      +     .   
4  contig000001  Liftoff  gene   4159  5214     .      -     .   

                                          attributes  
0  ID=gene-C2I27_06630;Name=C2I27_06630;gbkey=Gen...  
1  ID=cds-PVC74168.1;Parent=gene-C2I27_06630;Dbxr...  
2  ID=gene-C2I27_06635;Name=C2I27_06635;gbkey=Gen...  
3  ID=cds-PVC74169.1;Parent=gene-C2I27_06635;Dbxr...  
4  ID=gene-C2I27_06640;Name=C2I27_06640;gbkey=Gen...  


In [54]:
gwas_results

Unnamed: 0,SNP,phenotype,p_value,beta,r_squared,contig,position,annotation,Chromosome,annotation_as_str
447,"(contig000024, 14027)",PR_Length,6.435899e-14,-0.581154,3.593671e-02,contig000024,14027,,contig000024,
564,"(contig000028, 15902)",PR_Length,2.075721e-12,5.104474,3.116754e-02,contig000028,15902,,contig000028,
354,"(contig000024, 13733)",PR_Length,3.583436e-11,-52.074402,1.338920e-02,contig000024,13733,,contig000024,
543,"(contig000026, 12532)",PR_Length,5.100639e-10,-10.822059,1.844451e-01,contig000026,12532,,contig000026,
518,"(contig000024, 20950)",LR_Density,1.310338e-09,12.457027,1.316118e-02,contig000024,20950,,contig000024,
...,...,...,...,...,...,...,...,...,...,...
169,"(contig000022, 6781)",LR_number,9.755961e-01,-0.020619,4.861155e-06,contig000022,6781,"{'ID': 'gene-C2I27_09295', 'Name': 'C2I27_0929...",contig000022,
130,"(contig000010, 110356)",LR_number,9.783733e-01,0.019803,2.573943e-06,contig000010,110356,,contig000010,
40,"(contig000001, 757080)",LR_number,9.819319e-01,-0.017914,1.723985e-06,contig000001,757080,,contig000001,
201,"(contig000024, 13039)",PR_Length,9.980114e-01,0.009901,1.898761e-09,contig000024,13039,,contig000024,


In [42]:
gwas_results["annotation"] = annotate_results(gwas_results, gff_data)

Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to any row.
Variant does not belong to a

In [43]:
# manhattan_static(gwas_results)
manhattan_interactive(gwas_results)

In [None]:
# results_df.phenotype.unique()
# results_per_pheno_df = results_df.query("phenotype == 'LR_Density'")
# results_per_pheno_df = results_per_pheno_df.apply(lambda x: str(x))

In [83]:
from Bio import SeqIO

def get_contig_range(input_fasta, contig_name, start=None, end=None):
    """
    Retrieves a specific contig (and optionally a range within it) as a string.

    Parameters:
    - input_fasta: str, path to the input FASTA file.
    - contig_name: str, the name of the contig to query.
    - start: int or None, start position (1-based, inclusive). Default is None for full contig.
    - end: int or None, end position (1-based, inclusive). Default is None for full contig.

    Returns:
    - str: The extracted sequence as a string.
    - None: If the contig is not found.
    """
    for record in SeqIO.parse(input_fasta, "fasta"):
        if record.id == contig_name:
            if start is not None and end is not None:
                return str(record.seq[start - 1:end])
            return str(record.seq)
    print(f"Contig '{contig_name}' not found in {input_fasta}")
    return None

In [86]:
# Example usage
input_file = "data/genomes/reference_2/full_sequence.fasta"
contig = "PRKV01000004.1"
start_position = 394240 - 200
end_position = 394240 + 200

get_contig_range(input_file, contig, start_position, end_position)

'ATAAAAAAGCCAAAAAAACCTAAAATGCTGACACAGATAATACTAAAAAAGCGGATCAGTGAAATTGTTCCTATAAGGGACTGATTCGCTAGCGTAAAGGTAGCAATCGTAGAGCCTGCAATGACAACAAGCATGGCAGGACTTGTTAAACCGGCTCTAATGGCTGCATCACCGATAATAAGACCTCCGATTACACTGAGTGTCTGCCCGACAGAAGTCGGGAGTCTAAACCCTGCCTCTCGAAACAATTCAAATAAAAGAAGCATTAGTATCGCTTCAAGAGAGGTAGGAAACGGGACACCTCTTCTTGCCTCAACAATCGTAGCTAATAAGCTGAGCGGAAGCTGATTTTGATGAAAAGCCGTCATTGCAACCCAAAAACCAGGCAAAAAAGCAGCGAT'