In [None]:
import os
import pysam
import pickle as pkl
import numpy as np
import pandas as pd
from tqdm import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

from scipy.stats import norm

from sklearn.pipeline import Pipeline

import ipywidgets as widgets
from ipywidgets import interact
from IPython import embed

import re
import random

In [None]:
from genome_helpers import (
#     # compute_af,
#     inverse_rank_normalization,
    read_annotation_data,
    process_plant_phenotype,
    get_genome_metadata
)

from gwas_helpers import (
    adj_phenotypes_for_gwas,
    annotate_results,
    run_gwas,
    qqplot,
    manhattan_static,
    manhattan_interactive
)

In [None]:
def compute_af(record):
    """Alternate-allele frequency of a VCF record from its DP4 field.

    DP4 holds (ref forward, ref reverse, alt forward, alt reverse) read counts;
    the frequency is (DP4[2] + DP4[3]) / sum(DP4). Returns None when DP4 sums to 0.
    """
    dp4 = record.info['DP4']
    total_reads = sum(dp4)
    if total_reads == 0:
        return None
    alt_reads = dp4[2] + dp4[3]
    return alt_reads / total_reads


def get_dp4(record):
    """Return the DP4 read counts of a VCF record, or None if they sum to zero.

    DP4 is (ref forward, ref reverse, alt forward, alt reverse).
    """
    counts = record.info['DP4']
    return None if sum(counts) == 0 else counts


def get_freqs_from_vcf(vcf_file):
    """Read per-site alternate-allele frequencies and DP4 counts from one VCF.

    Parameters
    ----------
    vcf_file : str
        Path to a ``.vcf.gz`` or ``.vcf`` file. The sample id is the file name
        with that extension removed.

    Returns
    -------
    (dict, dict)
        ``{sample_id: {(chrom, pos): frequency}}`` (values from `compute_af`) and
        ``{sample_id: {(chrom, pos): dp4}}`` (values from `get_dp4`). If the file
        cannot be opened, the error is printed and both inner values are None.
    """
    bgi_id = os.path.basename(vcf_file)
    # Strip whichever VCF extension is present so ".vcf" files get a clean id too.
    for suffix in (".vcf.gz", ".vcf"):
        if bgi_id.endswith(suffix):
            bgi_id = bgi_id[: -len(suffix)]
            break

    try:
        vcf = pysam.VariantFile(vcf_file)
    except Exception as e:
        # Best-effort: an unreadable file yields None rather than aborting the whole folder.
        print(e)
        return {bgi_id: None}, {bgi_id: None}

    variant_freq = {}
    variant_counts = {}
    # Single pass over the records; the context manager closes the file handle.
    with vcf:
        for record in vcf.fetch():
            site = (record.chrom, record.pos)
            variant_freq[site] = compute_af(record)
            variant_counts[site] = get_dp4(record)

    return {bgi_id: variant_freq}, {bgi_id: variant_counts}


def get_sample_info(sample_id, metadata):
    """Look up the (generation, rep, treatment) tuple for a BGI sample id.

    Parameters
    ----------
    sample_id : str
        Value matched against ``metadata.BGI_ID``.
    metadata : pandas.DataFrame
        Must contain ``BGI_ID``, ``generation``, ``rep`` and ``treatment`` columns.

    Returns
    -------
    tuple or None
        The matching row as a tuple, or None (after printing a message) when no
        row matches.

    Raises
    ------
    ValueError
        If more than one row matches ``sample_id``.
    """
    matches = metadata.loc[metadata.BGI_ID == sample_id, ["generation", "rep", "treatment"]]
    n_matches = len(matches)

    if n_matches == 0:
        print(ValueError(f"Sample with {sample_id} has no corresponding metadata."))
        return None
    if n_matches > 1:
        raise ValueError(f"Sample with {sample_id} has more than one associated sample.")

    return tuple(matches.iloc[0].to_list())


def process_vcf_folder(vcf_folder, cache_file="vcf_dictionary.pkl", as_dataframe=False):
    """Read allele frequencies and DP4 counts from every VCF in a folder.

    Parameters
    ----------
    vcf_folder : str
        Directory scanned (non-recursively) for files ending in ".vcf.gz" or ".vcf".
    cache_file : str, optional
        Currently unused; kept only for backward compatibility with existing callers.
    as_dataframe : bool, default False
        If True, return two DataFrames (rows indexed by (contig, position),
        one column per sample) instead of nested dicts.

    Returns
    -------
    (freqs, counts)
        ``{sample_id: {(contig, pos): value}}`` mappings as produced by
        `get_freqs_from_vcf`, or the equivalent DataFrames.
    """
    # Sorted so sample/column order does not depend on filesystem listing order.
    files = sorted(
        name for name in os.listdir(vcf_folder)
        if name.endswith((".vcf.gz", ".vcf"))
    )

    variant_freq_dicts = {}
    variant_count_dicts = {}

    for name in tqdm(files):
        variant_freq_dict, variant_count_dict = get_freqs_from_vcf(os.path.join(vcf_folder, name))
        variant_freq_dicts.update(variant_freq_dict)
        variant_count_dicts.update(variant_count_dict)

    if as_dataframe:
        return pd.DataFrame(variant_freq_dicts), pd.DataFrame(variant_count_dicts)
    return variant_freq_dicts, variant_count_dicts
    

def filter_variants(freq_df, freq_threshold=0.98, non_missing=450):
    """Drop near-fixed and poorly covered variants from a wide frequency matrix.

    Parameters
    ----------
    freq_df : pandas.DataFrame
        Rows are variants, columns are samples, values are alternate-allele
        frequencies (NaN where a sample has no value for that variant).
    freq_threshold : float, default 0.98
        A variant is dropped when every observed (non-NaN) frequency exceeds this
        value. Rows with no observed values are dropped as well.
    non_missing : int, default 450
        Despite the name, this is an upper bound on *missing* values: only rows
        with fewer than ``non_missing`` NaNs are kept.

    Returns
    -------
    pandas.DataFrame
        The rows of ``freq_df`` passing both filters (vectorized; also safe on
        an empty frame).
    """
    missing = freq_df.isna()
    # NaN comparisons are False, so "missing or above threshold" for every column is
    # equivalent to "all observed values above threshold".
    near_fixed = (missing | freq_df.gt(freq_threshold)).all(axis=1)
    too_sparse = missing.sum(axis=1) >= non_missing
    return freq_df.loc[~near_fixed & ~too_sparse]


def display_df(df, n=5, disable=False, text="", info=True):
    """Preview ``df`` and return it unchanged, for use inside ``.pipe`` chains.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to preview.
    n : int, default 5
        Number of random rows to show (capped at the number of rows available).
    disable : bool, default False
        If True, do nothing and just return ``df``.
    text : str, default ""
        Optional header printed before the preview.
    info : bool, default True
        Whether to also print ``df.info()``.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in.
    """
    if disable:
        return df

    if text:
        print(text)
    # DataFrame.sample raises if asked for more rows than exist.
    display(df.sample(min(n, len(df))))
    print(f"{df.shape=}")
    if info:
        # info() prints its own report and returns None; wrapping it in print() adds "None".
        df.info()
    print("-" * 100)
    return df

In [None]:
def _process_bases(row, min_quality=20):
    """Count high-quality alleles at one pileup position.

    Parses a samtools-mpileup style read-base string and tallies alleles for
    reads whose base quality is at least ``min_quality``.

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide "bases" (read-base string), "qual" (quality string, decoded
        by ``phred_quality``), "ref" (reference base) and "depth".
    min_quality : int, default 20
        Minimum base quality for a read to be counted.

    Returns
    -------
    dict
        Counts for 'A', 'C', 'T', 'G', 'DEL', 'INS', plus "Avg_Qual" (mean quality
        of passing reads, None if none pass), "depth" (``row["depth"]``),
        "depth_high_q" (number of passing reads), "bases", "original_bases",
        "ref_base" and "quality_scores".

    Notes
    -----
    Each read contributes exactly one base character (``.``, ``,``, a base,
    ``*``, ...) and one quality score. ``^X`` (read start plus mapping quality),
    ``$`` (read end) and ``+nSEQ`` / ``-nSEQ`` (an indel following the previous
    read's base) have no quality score. They are therefore parsed separately, so
    the base and quality cursors stay aligned. An indel counts as one 'INS' or
    'DEL' event, and only if the read it follows passed the quality filter. At
    indel sites, A+C+T+G+DEL+INS can therefore exceed depth_high_q.
    """
    base_calls = row["bases"]
    quality_scores = phred_quality(row["qual"])
    ref_base = row["ref"].upper()
    indel_re = re.compile(r"[+-](\d+)")

    processed_bases = []
    passing_quality_scores = []
    previous_read_passed = False

    i = 0  # cursor in the read-base string
    q = 0  # cursor in the quality scores (one entry per read)
    while i < len(base_calls):
        char = base_calls[i]

        if char == "^":
            # Read start: the next character is the mapping quality, not a base.
            i += 2
            continue
        if char == "$":
            # Read end marker.
            i += 1
            continue
        if char in "+-":
            match = indel_re.match(base_calls, i)
            if match is None:
                i += 1
                continue
            if previous_read_passed:
                processed_bases.append("I" if char == "+" else "D")
            # Skip the sign, the length digits and the inserted/deleted sequence.
            i += len(match.group(0)) + int(match.group(1))
            continue

        # Anything else is one read's base call and owns one quality score.
        if q >= len(quality_scores):
            break  # malformed row: more read bases than quality scores
        score = quality_scores[q]
        q += 1
        i += 1

        previous_read_passed = score >= min_quality
        if not previous_read_passed:
            continue
        passing_quality_scores.append(score)

        if char in ".,":
            # Match to the reference on the forward/reverse strand.
            processed_bases.append(ref_base)
        elif char.upper() in "ACTG":
            processed_bases.append(char.upper())
        elif char in "*#":
            # Placeholder for a base deleted in this read.
            processed_bases.append("D")
        # Other characters (e.g. N, '>'/'<' reference skips) count only toward depth_high_q.

    allele_counts = {
        'A': processed_bases.count('A'),
        'C': processed_bases.count('C'),
        'T': processed_bases.count('T'),
        'G': processed_bases.count('G'),
        'DEL': processed_bases.count('D'),
        'INS': processed_bases.count('I'),
    }
    allele_counts["Avg_Qual"] = (
        sum(passing_quality_scores) / len(passing_quality_scores)
        if passing_quality_scores else None
    )
    allele_counts["depth"] = row["depth"]
    allele_counts["depth_high_q"] = len(passing_quality_scores)
    allele_counts["bases"] = processed_bases
    allele_counts["original_bases"] = row["bases"]
    allele_counts["ref_base"] = ref_base
    allele_counts["quality_scores"] = quality_scores

    return allele_counts


def mismatch(row):
    """True when the per-allele counts do not add up to ``row.depth_high_q``.

    Sums the 'A', 'C', 'T', 'G', 'DEL' and 'INS' entries of ``row`` (a pandas
    Series, e.g. one row of ``_process_bases`` output) and compares the total
    with the number of high-quality reads.
    """
    allele_total = 0
    for allele in ('A', 'C', 'T', 'G', 'DEL', 'INS'):
        allele_total += row[allele]
    return allele_total != row.depth_high_q

___

Para obtener la frecuencia del alelo alternativo, $f = \frac{DP4[2] + DP4[3]}{\sum_{i=0}^{3} DP4[i]}$, usamos el campo `DP4` del archivo VCF, cuyas componentes son las siguientes:

| Campo   | Descripción                                      |
|---------|------------------------------------------------|
| DP4[0]  | Reads forward para el alelo de referencia.    |
| DP4[1]  | Reads reverse para el alelo de referencia.    |
| DP4[2]  | Reads forward para el alelo alternativo.      |
| DP4[3]  | Reads reverse para el alelo alternativo.      |


- **Input**: matriz con las cuentas de cada alelo, donde `NaN` significa que el SNV no presenta _variación o cobertura_ en esa muestra.
- **TODO**: revisar la cobertura de los SNVs que no aparecen en los VCFs.
- Filtrar variantes.
- Pasar de formato largo (filas: (planta, SNV); columnas: ninguna) a formato ancho (filas: planta; columnas: SNV).

In [None]:
# Set True to print each intermediate frame of the VCF-processing chain.
VERBOSE = False
# Folder scanned for per-sample VCFs.
VCF_DIR = "data/genomes/alignments_paired_end_new/"
# Cached long-format and wide-format frequency tables.
CACHED_FREQ_PKL, CACHED_FREQ_WIDE_PKL = "freq_dataframe_ref2.pkl", "freq_dataframe_wide_ref2.pkl"

In [None]:
# assert not os.path.exists(CACHED_FREQ_PKL), f"Skipping this cell as {CACHED_FREQ_PKL} already exists."

batch_mapping = get_genome_metadata(as_dataframe=False)
freq_df, counts_df = process_vcf_folder(vcf_folder=VCF_DIR, as_dataframe=True)
# Total read depth per (site, sample): DP4 tuples are summed, non-tuple entries are kept as-is.
depth_df = counts_df.map(lambda x: sum(x) if isinstance(x, tuple) else x)

# Long format: one row per (contig, position, sample) with its allele frequency.
freq_df = ( freq_df
    .pipe(display_df, text="Pre-variant filtering", disable=not VERBOSE)
    .pipe(filter_variants)
    .pipe(display_df, text="Filtered variants", disable=not VERBOSE)
    .melt(ignore_index=False)
    .pipe(display_df, text="Reshaped", disable=not VERBOSE)
    .reset_index()
    .pipe(display_df, text="Reset index", disable=not VERBOSE)
    .rename({"level_0": "contig", "level_1": "position", "value": "freq"}, axis=1)
    .pipe(display_df, text="Renamed columns", disable=not VERBOSE)
)

# We map the batch to the sample so that we can pair it up with the phenotype values from the other file.
# Samples missing from batch_mapping get generation -1 and are dropped below.
samples = freq_df.variable.apply(lambda x: batch_mapping.get(x, (None, None, -1)))
samples_df = pd.DataFrame(
    samples.to_list(),
    columns=["treatment", "replica", "generation"],
    index=freq_df.index,
)

# Chained (no attribute assignment on a filtered frame -> no SettingWithCopyWarning).
freq_df = (
    pd.concat([freq_df, samples_df], axis=1)
    .query("generation != -1")
    .astype({"replica": int})
    .dropna(subset=["freq"])
)

# to_pickle opens and closes the file itself (no leaked handle).
freq_df.to_pickle(CACHED_FREQ_PKL)

# assert not os.path.exists(CACHED_FREQ_PKL), f"Skipping this cell as {CACHED_FREQ_PKL} already exists."
# NOTE(review): the next cell repeats this pivot step; one of the two copies can be removed.

def merge_contig_and_position(df):
    """Return ``df`` with a ``variant_id`` column holding (contig, position) tuples."""
    return df.assign(variant_id=df[["contig", "position"]].apply(tuple, axis=1))


# Pivot the long table into a variants x (treatment, replica, generation) matrix.
freq_df = (
    pd.read_pickle(CACHED_FREQ_PKL)
    # Must be piped: `.merge_contig_and_position` is attribute access and raises AttributeError.
    .pipe(merge_contig_and_position)
    .drop(["contig", "position"], axis=1)
    .pivot(
        columns=['treatment', 'replica', 'generation'],
        index=['variant_id'],
        values='freq'
    )
    .astype(float)
)

freq_df.to_pickle(CACHED_FREQ_WIDE_PKL)

In [None]:
def merge_contig_and_position(df):
    """Return ``df`` with a ``variant_id`` column holding (contig, position) tuples."""
    return df.assign(variant_id=df[["contig", "position"]].apply(tuple, axis=1))


# Pivot the long table into a variants x (treatment, replica, generation) matrix.
freq_df = (
    pd.read_pickle(CACHED_FREQ_PKL)
    .pipe(merge_contig_and_position)
    .drop(["contig", "position"], axis=1)
    .pivot(
        columns=['treatment', 'replica', 'generation'],
        index=['variant_id'],
        values='freq'
    )
    .astype(float)
)

# to_pickle/read_pickle manage the file handle (the bare open() calls leaked it).
freq_df.to_pickle(CACHED_FREQ_WIDE_PKL)

In [None]:
def eliminate_generations_with_many_nans(freq_df, min_n_of_nan):
    """Keep rows with strictly more than ``min_n_of_nan`` non-missing values.

    Despite the parameter name, the threshold is compared against the count of
    *non*-NaN entries per row.
    """
    observed_per_row = freq_df.notna().sum(axis=1)
    keep = observed_per_row > min_n_of_nan
    return freq_df.loc[keep]


def extract_top_n_variable_variants(freq_df, n):
    """Return the ``n`` rows of ``freq_df`` with the largest row-wise standard deviation.

    Rows are returned in descending order of standard deviation.
    """
    spread = freq_df.std(axis=1)
    ranked_labels = spread.sort_values(ascending=False).index
    return freq_df.loc[ranked_labels[:n]]


def join_treatment_rep_gen(freq_df):
    """Replace the (multi-level) row index with a flat index of tuples named "id".

    Each row's index levels (e.g. treatment, replica, generation) are packed
    into one tuple, so the result has a single-level object index.
    """
    level_rows = freq_df.index.to_frame(index=False).to_numpy()
    flat_ids = pd.Index([tuple(r) for r in level_rows], name="id", tupleize_cols=False)
    return freq_df.set_index(flat_ids)

In [None]:
MIN_NUMBER_OF_NON_NAN = 200
TOP_N_VARIANTS = 200

# Keep well-covered variants, then the TOP_N_VARIANTS with the largest spread, and
# transpose so rows are (treatment, replica, generation) ids and columns are variants.
freq_df = (
    pd.read_pickle(CACHED_FREQ_WIDE_PKL)  # closes the file (bare open() leaked the handle)
    .astype(float)
    .pipe(eliminate_generations_with_many_nans, MIN_NUMBER_OF_NON_NAN)
    .pipe(extract_top_n_variable_variants, TOP_N_VARIANTS)
    .transpose()
    .pipe(join_treatment_rep_gen)
    .sort_index()
)

In [None]:
# Phenotype columns analysed by the GWAS cells below.
PHENOTYPES = ["PR_Length", "LR_number", "LR_Density"]
EXPERIMENTS = [(1, 'K'), (1, 'MS'), (2, 'MS'), (2, 'K'), (3, 'K'), (3, 'MS')]

(phenotypes_df, _, _, phenotypes_df_red_nb) = process_plant_phenotype()
phenotypes_df = adj_phenotypes_for_gwas(
    phenotypes_df,
    phenotypes_df_red_nb,
    PHENOTYPES,
)

### GWAS: one plant, one data point

In [None]:
# FIXME(review): `gwas_results` is first assigned by the `run_gwas` cell further down;
# on a fresh kernel (Run All) this line raises NameError. Move this cell after that one.
gwas_results.set_index(["SNP", "phenotype"])

In [None]:
# Candidate SNVs inspected previously:
# SNV = ("PRKV01000002.1", 42521)
# SNV = ("PRKV01000002.1", 496961)
SNV = ("PRKV01000009.1", 124951)
N_PERMUTATIONS = 1000
PERMUTATION_SEED = 0

# Null distribution for one SNV: shuffle its frequencies across samples (breaking any
# genotype-phenotype link) and re-run the GWAS on each shuffled copy.
permutation_rng = np.random.default_rng(PERMUTATION_SEED)
snp_freq_df = freq_df[[SNV]]
snps = snp_freq_df.columns.to_list()

shuffled_results = []
for _ in range(N_PERMUTATIONS):
    # rng.permutation returns a row-shuffled copy; shuffling `.values` in place fails
    # on read-only views under pandas Copy-on-Write.
    freq_shuffled_df = pd.DataFrame(
        permutation_rng.permutation(snp_freq_df.to_numpy()),
        index=snp_freq_df.index,
        columns=snp_freq_df.columns,
    )
    all_shuffled_data = pd.merge(phenotypes_df, freq_shuffled_df, left_index=True, right_index=True)
    shuffled_results.append(run_gwas(all_shuffled_data, snps, phenotypes=PHENOTYPES))

shuffled_results = pd.concat(shuffled_results)

lr_number_p = shuffled_results.query("phenotype == 'LR_number'").p_value
ax = (-np.log10(lr_number_p)).hist(bins=50)
ax.set(title=f"Permutation null for {SNV}: LR_number", xlabel="-log10(p)", ylabel="count");
# shuffled_results.query("phenotype == 'LR_Density'").p_value.apply(np.log10).hist(bins=50);

In [None]:
# Join phenotypes with per-sample variant frequencies on the shared sample index.
all_data = pd.merge(
    left=phenotypes_df,
    right=freq_df,
    left_index=True,
    right_index=True,
)
snps = list(freq_df.columns)

In [None]:
# Association test of every variant in `snps` against each phenotype in PHENOTYPES.
gwas_results = run_gwas(
    all_data,
    snps,
    phenotypes=PHENOTYPES,
)

In [None]:
# One QQ plot per analysed phenotype, in PHENOTYPES order (replaces three copy-pasted calls).
for phenotype_name in PHENOTYPES:
    qqplot(gwas_results.loc[gwas_results.phenotype == phenotype_name, "p_value"])

In [None]:
# Wide frequency matrix: one row per (treatment, replica, generation) id, one column per variant.
freq_df

In [None]:
# Phenotypes joined with variant frequencies on the shared sample index.
all_data

In [None]:
# Frequency distribution for one hand-picked association (row 3 of gwas_results).
# Uses its own name: reassigning `snps` here clobbered the variant list passed to run_gwas.
snp = gwas_results.SNP.iloc[3]
all_data[snp].value_counts()

In [None]:
# GWAS association table from run_gwas; later cells use its SNP, phenotype and p_value columns.
gwas_results

In [None]:
def permutation_test(df, num_permutations=1000, seed=None):
    """Restricted (within-batch) permutation test for GWAS coefficients.

    - Permutes ``phenotype_diff`` within each ``batch_id``, so values never move between batches.
    - Builds a null distribution of coefficients for each genotype.
    - Returns the observed results with empirical p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``batch_id`` and ``phenotype_diff`` plus whatever ``run_gwas`` needs.
    num_permutations : int, default 1000
        Number of permuted datasets.
    seed : int or None, default None
        Seed for the permutation RNG; pass an int for reproducible p-values.

    Returns
    -------
    pandas.DataFrame
        ``run_gwas(df)`` with an added ``empirical_p`` column, computed one-sided as
        ``(100 - percentileofscore(null_coefs, observed_coef)) / 100``.

    FIXME(review): elsewhere in this notebook ``run_gwas`` is called as
    ``run_gwas(data, snps, phenotypes=...)``, and its output has ``SNP``/``beta``/``p_value``
    columns, not the ``genotype``/``coef`` columns assumed here. Confirm the intended
    ``run_gwas`` before using this function.
    """
    # Was used without being imported (NameError); scipy is already a notebook dependency.
    from scipy.stats import percentileofscore

    rng = np.random.default_rng(seed)
    observed_results = run_gwas(df)
    permuted_coefs = {genotype: [] for genotype in observed_results['genotype']}

    batch_ids = df['batch_id']
    batches = batch_ids.unique()
    for _ in range(num_permutations):
        permuted_df = df.copy()
        for batch in batches:
            # Boolean mask (not index labels) so duplicated index labels cannot misalign rows.
            in_batch = (batch_ids == batch).to_numpy()
            permuted_df.loc[in_batch, 'phenotype_diff'] = rng.permutation(
                df.loc[in_batch, 'phenotype_diff'].to_numpy()
            )

        permuted_results = run_gwas(permuted_df)
        for genotype, coef in zip(permuted_results['genotype'], permuted_results['coef']):
            permuted_coefs[genotype].append(coef)

    # Empirical p-values.
    observed_results['empirical_p'] = [
        (100 - percentileofscore(permuted_coefs[genotype], coef)) / 100
        for genotype, coef in zip(observed_results['genotype'], observed_results['coef'])
    ]
    return observed_results

# Example usage
# df = pd.read_csv("datos_gwas.csv")  # load the real data
# results = permutation_test(df, num_permutations=1000)
# results.to_csv("gwas_permutation_results.csv", index=False)

In [None]:
@interact
def select_association(index=widgets.IntSlider(min=0, max=max(len(gwas_results) - 1, 0))):
    """Scatter-plot variant frequency vs phenotype for one row of ``gwas_results``.

    Parameters
    ----------
    index : int
        Row position in ``gwas_results``. The slider spans every row; a fixed max of
        100 could run past the end of the table and raise IndexError.
    """
    association = gwas_results.iloc[index]
    snp = association.SNP
    phenotype_name = association.phenotype
    p_value = association.p_value

    genotype = all_data[snp]
    phenotype = all_data[phenotype_name]
    n_points = genotype.notna().sum()

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.scatter(genotype, phenotype)
    ax.set(
        title=f"{snp}\n{phenotype_name} ({p_value:.1e}, {n_points})",
        xlabel="alternate allele frequency",
        ylabel=phenotype_name,
    )
    plt.show()

___

In [None]:
# Paths are relative to the repository root, like VCF_DIR and the reference FASTA used below
# (the previous absolute /home/... paths only worked on one machine).
# Alternative annotation files:
#   data/genomes/reference_BM_TG1E1/annotations/annotations2.gff3
#   data/genomes/reference/annotations_translated_from_tg1e1.gff
gff = "data/genomes/reference_BM_TG1E1/annotations/annotations.gff3"
gff_data = read_annotation_data(gff)

In [None]:
# gwas_results.SNP

In [None]:
# Annotate each association using only the CDS features of the GFF.
gwas_results["annotation"] = annotate_results(
    gwas_results,
    gff_data[gff_data["type"] == "CDS"],
)

In [None]:
def compute_distance(signed_dist1, signed_dist2):
    """Distance from a position to an interval, given signed offsets to both ends.

    Opposite signs mean the position lies strictly between the two ends, so the
    distance is 0. Otherwise it is the smaller absolute offset.
    """
    straddles = signed_dist1 * signed_dist2 < 0
    return 0 if straddles else min(abs(signed_dist1), abs(signed_dist2))


def get_closest_gene(df):
    """Return the row of ``df`` whose feature lies closest to the SNP position.

    ``df`` holds candidate features for one SNP, with ``start``, ``end`` and
    ``position`` columns. The distance is 0 when the position lies inside
    [start, end] (boundaries included). Otherwise it is the distance to the nearer
    boundary. Ties resolve to the first row.

    Vectorized: avoids a Python-level row apply that indexed an unnamed-Series
    frame by integer labels.
    """
    dist_to_start = df.start - df.position
    dist_to_end = df.end - df.position
    # Opposite signs mean the position lies strictly between start and end.
    inside = (dist_to_start * dist_to_end) < 0
    distance = np.minimum(dist_to_start.abs(), dist_to_end.abs()).mask(inside, 0)
    return df.iloc[distance.argmin()]


In [None]:
# Pair every association with all CDS features on the same contig, then keep the closest per SNP.
annotated_gwas_same_contig = gwas_results.merge(
    gff_data[gff_data["type"] == "CDS"],
    left_on="contig",
    right_on="seqid",
)
mapped_genes = (
    annotated_gwas_same_contig
    .groupby("SNP")
    .apply(get_closest_gene, include_groups=False)
)

In [None]:
# Export the minimal association table. Selecting the columns directly (instead of a
# redundant drop first) avoids a KeyError when the optional `annotation` column is absent.
gwas_results[["contig", "position", "phenotype", "p_value"]].to_csv("gwas_results.csv", index=False)

In [None]:
# Compact per-SNP summary: [SNP, start, end, strand, product].
mapped_genes.reset_index().apply(
    lambda gene: [gene.SNP, gene.start, gene.end, gene.strand, gene.attributes['product']],
    axis=1,
)

In [None]:
# One row per mapped SNP with its closest CDS. result_type="expand" gives one column per
# value; wrapping a Series of lists in pd.DataFrame produced a single column of list reprs.
kk = (
    mapped_genes.reset_index()
    .apply(
        lambda x: [x.SNP[0], x.SNP[1], x.start, x.end, x.strand, x.attributes['product']],
        axis=1,
        result_type="expand",
    )
    .set_axis(["contig", "position", "start", "end", "strand", "product"], axis=1)
)

In [None]:
# Persist the per-SNP closest-gene table.
kk.to_csv(path_or_buf="gene_annotations.csv", index=False)

In [None]:
# Closest CDS feature for each SNP (indexed by SNP), as chosen by get_closest_gene.
mapped_genes

In [None]:
# Inspect the first mapped gene.
mapped_genes.iloc[0]

In [None]:
# Inspect one mapped gene by position (hard-coded index 6; adjust as needed).
mapped_genes.iloc[6]

In [None]:
# Manhattan plot of all associations; the static variant is kept for reference.
# manhattan_static(gwas_results)
manhattan_interactive(gwas_results)

In [None]:
# results_df.phenotype.unique()
# results_per_pheno_df = results_df.query("phenotype == 'LR_Density'")
# results_per_pheno_df = results_per_pheno_df.apply(lambda x: str(x))

In [None]:
from Bio import SeqIO

def _slice_sequence(sequence, start=None, end=None):
    """Return ``sequence`` limited to the 1-based inclusive range [start, end] as a str."""
    # Clamp start to 1: start <= 0 would otherwise become a negative (wrap-around) slice.
    lo = None if start is None else max(start, 1) - 1
    return str(sequence[lo:end])


def get_contig_range(input_fasta, contig_name, start=None, end=None):
    """
    Retrieves a specific contig (and optionally a range within it) as a string.

    Parameters:
    - input_fasta: str, path to the input FASTA file.
    - contig_name: str, the name of the contig to query.
    - start: int or None, start position (1-based, inclusive). None means from the
      beginning of the contig; values below 1 are treated as 1.
    - end: int or None, end position (1-based, inclusive). None means up to the end
      of the contig.

    Returns:
    - str: The extracted sequence as a string.
    - None: If the contig is not found (a message is printed).
    """
    for record in SeqIO.parse(input_fasta, "fasta"):
        if record.id == contig_name:
            # Each bound is honoured on its own (previously one bound alone returned the full contig).
            return _slice_sequence(record.seq, start, end)
    print(f"Contig '{contig_name}' not found in {input_fasta}")
    return None

In [None]:
# Example: extract the reference sequence +/-200 bp around a SNV of interest.
input_file = "data/genomes/reference_2/full_sequence.fasta"
contig = "PRKV01000004.1"
start_position, end_position = 394240 - 200, 394240 + 200

get_contig_range(input_file, contig, start=start_position, end=end_position)