In [1]:
import os
import re
import pickle as pkl

import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt

from collections import Counter

from ipywidgets import (
    interact,
    IntSlider,
    RadioButtons
)

from genome_helpers import (
    get_genome_metadata
)

In [2]:
def _process_bases(row, min_quality=20):
    """Tally quality-filtered allele calls for one samtools-mpileup row.

    The base string and the quality string are walked with separate cursors.
    Every read contributes exactly one base-call symbol and one quality
    character. The read-start marker (``^`` plus one following character), the
    read-end marker (``$``) and indel annotations (``+nSEQ`` / ``-nSEQ``) occur
    only in the base string, so they must not advance the quality cursor.

    Parameters
    ----------
    row : mapping
        Needs "bases", "qual", "ref" and "depth" entries (e.g. a DataFrame row).
    min_quality : int, default 20
        Reads whose decoded base quality is below this value are discarded.

    Returns
    -------
    dict
        Counts for 'A', 'C', 'T', 'G', 'DEL' ('*' placeholders and ``-n``
        deletions) and 'INS' (``+n`` insertions), plus "Avg_Qual" (mean quality
        of the kept reads; None when no read passes), "depth" (the row's raw
        depth) and "depth_high_q" (number of reads passing the filter). When at
        least one read passes, "bases", "original_bases", "ref_base" and
        "quality_scores" are included as well.
        Indels are counted in addition to the base of the read they follow, so
        the allele counts can add up to more than "depth_high_q".
    """
    base_calls = row["bases"]
    quality_scores = phred_quality(row["qual"])
    ref_base = row["ref"].upper()

    processed_bases = []
    passing_quality_scores = []

    i = 0                         # cursor into the base string
    read_idx = 0                  # cursor into the quality string (one entry per read)
    previous_read_passed = False  # an indel annotation belongs to the read just before it

    while i < len(base_calls):
        char = base_calls[i]

        # Read-start marker: skip '^' and the character after it. That character is
        # arbitrary (it can be '.', ',' or a letter) and must not be read as a base.
        if char == "^":
            i += 2
            continue

        # Read-end marker: not a base call.
        if char == "$":
            i += 1
            continue

        # Indel annotation: '+'/'-', a length, then that many bases. It has no quality
        # character of its own, so it is kept only if the preceding read passed.
        if char in "+-":
            match = re.match(r"[+-](\d+)", base_calls[i:])
            if match is None:
                i += 1
                continue
            if previous_read_passed:
                processed_bases.append("I" if char == "+" else "D")
            i += match.end() + int(match.group(1))
            continue

        # Any other symbol is one read's base call and consumes one quality value.
        if read_idx >= len(quality_scores):
            break  # malformed row: more base calls than quality characters
        quality = quality_scores[read_idx]
        read_idx += 1
        i += 1

        # Discard bases with quality below the threshold.
        previous_read_passed = quality >= min_quality
        if not previous_read_passed:
            continue
        passing_quality_scores.append(quality)

        if char in ".,":
            processed_bases.append(ref_base)       # same as the reference
        elif char.upper() in "ACTG":
            processed_bases.append(char.upper())   # non-reference base (either case)
        elif char == "*":
            processed_bases.append("D")            # placeholder for a deleted base
        # Other symbols (e.g. 'N') pass the filter but add no allele.

    if not passing_quality_scores:
        return {'A': 0, 'C': 0, 'T': 0, 'G': 0, 'DEL': 0, 'INS': 0,
                "Avg_Qual": None, "depth": row["depth"], "depth_high_q": 0}

    allele_counts = {
        'A': processed_bases.count('A'),
        'C': processed_bases.count('C'),
        'T': processed_bases.count('T'),
        'G': processed_bases.count('G'),
        'DEL': processed_bases.count('D'),
        'INS': processed_bases.count('I')
    }

    # Mean quality of the reads that passed the filter.
    allele_counts["Avg_Qual"] = sum(passing_quality_scores) / len(passing_quality_scores)
    allele_counts["depth"] = row["depth"]
    allele_counts["depth_high_q"] = len(passing_quality_scores)
    allele_counts["bases"] = processed_bases
    allele_counts["original_bases"] = row["bases"]
    allele_counts["ref_base"] = ref_base
    allele_counts["quality_scores"] = quality_scores

    return allele_counts


def mismatch(row):
    """Flag rows whose per-allele tallies do not add up to the high-quality depth.

    Parameters
    ----------
    row : pandas Series, or any object supporting both ``row[key]`` and
        ``row.depth_high_q``, holding 'A', 'C', 'T', 'G', 'DEL' and 'INS' counts.

    Returns
    -------
    bool
        True when ``A + C + T + G + DEL + INS != depth_high_q``.
    """
    allele_total = sum(row[allele] for allele in ('A', 'C', 'T', 'G', 'DEL', 'INS'))
    return allele_total != row.depth_high_q

In [3]:
def phred_quality(qual_string):
    """Decode an ASCII Phred+33 quality string into a list of integer scores.

    Each character maps to ``ord(char) - 33``; an empty string gives ``[]``.
    """
    offset = 33
    return list(map(lambda symbol: ord(symbol) - offset, qual_string))


def count_alleles(row, min_quality=20):
    """Count quality-filtered nucleotides and indel sequences at one mpileup position.

    The base string is walked with its own cursor and the quality string with a
    per-read cursor. Read-start markers (``^`` plus one character), read-end
    markers (``$``) and indel annotations (``+nSEQ`` / ``-nSEQ``) therefore do
    not shift which quality value belongs to which read.

    Parameters
    ----------
    row : mapping
        Needs "bases", "qual" and "ref" entries (e.g. a DataFrame row).
    min_quality : int, default 20
        Reads whose decoded base quality is below this value are discarded.

    Returns
    -------
    dict
        - 'A', 'C', 'T', 'G': counts, where '.'/',' count as the reference base
          and lowercase calls are upper-cased.
        - 'DEL': the number of '*' placeholders.
        - ``INS_<ref><SEQ>`` / ``DEL_<SEQ>``: one key per distinct indel
          following a passing read.
        - 'Avg_Qual': mean quality of the passing reads, or 0 when none pass.
    """
    base_calls = row["bases"]
    quality_scores = phred_quality(row["qual"])
    ref_base = row["ref"].upper()

    counts = {'A': 0, 'C': 0, 'T': 0, 'G': 0, 'DEL': 0}
    passing_quality_scores = []

    i = 0                         # cursor into the base string
    read_idx = 0                  # cursor into the quality string (one entry per read)
    previous_read_passed = False  # an indel annotation belongs to the read just before it

    while i < len(base_calls):
        char = base_calls[i]

        # Read-start marker: '^' plus one arbitrary character, neither is a base call.
        if char == "^":
            i += 2
            continue

        # Read-end marker.
        if char == "$":
            i += 1
            continue

        # Indel annotation: consume exactly the announced number of bases.
        if char in "+-":
            match = re.match(r"[+-](\d+)", base_calls[i:])
            if match is None:
                i += 1
                continue
            seq_start = i + match.end()
            seq_end = seq_start + int(match.group(1))
            indel_seq = base_calls[seq_start:seq_end].upper()
            if previous_read_passed:
                key = f"INS_{ref_base}{indel_seq}" if char == "+" else f"DEL_{indel_seq}"
                counts[key] = counts.get(key, 0) + 1
            i = seq_end
            continue

        # Any other symbol is one read's base call and consumes one quality value.
        if read_idx >= len(quality_scores):
            break  # malformed row: more base calls than quality characters
        quality = quality_scores[read_idx]
        read_idx += 1
        i += 1

        previous_read_passed = quality >= min_quality
        if not previous_read_passed:
            continue
        passing_quality_scores.append(quality)

        allele = ref_base if char in ".," else char.upper()
        if allele in ("A", "C", "T", "G"):
            counts[allele] += 1
        elif allele == "*":
            counts["DEL"] += 1  # deletion placeholder

    counts["Avg_Qual"] = (
        sum(passing_quality_scores) / len(passing_quality_scores)
        if passing_quality_scores else 0
    )

    return counts


def parse_bases(base_string, ref_base):
    """Split an mpileup base string into one symbol per read.

    Parameters
    ----------
    base_string : str
        Base column of one mpileup row.
    ref_base : str
        Reference base substituted for '.' and ',' (used as given, not upper-cased).

    Returns
    -------
    list of str
        - One entry per read: 'A', 'C', 'T', 'G', ``ref_base``, or "DEL" for a
          '*' placeholder.
        - An insertion/deletion annotation (``+nSEQ`` / ``-nSEQ``) is appended
          to the preceding entry as ``<base>_INS_<SEQ>`` / ``<base>_DEL_<SEQ>``.
          It stands alone as ``INS_<SEQ>`` / ``DEL_<SEQ>`` when nothing precedes it.
        - Read-start markers ('^' plus the following character), '$', 'N' and
          other symbols produce no entry.
    """
    bases_list = []
    i = 0

    while i < len(base_string):
        char = base_string[i]

        # Read-start marker: the character after '^' is arbitrary (it can be '.', ','
        # or a nucleotide letter) and must not be parsed as a base call.
        if char == "^":
            i += 2
            continue

        if char in ".,":
            bases_list.append(ref_base)
        elif char.upper() in "ACTG":
            bases_list.append(char.upper())
        elif char == "*":
            bases_list.append("DEL")
        elif char in "+-":
            match = re.match(r"[+-](\d+)", base_string[i:])
            if match:
                num_bases = int(match.group(1))
                seq_start = i + match.end()
                indel_seq = base_string[seq_start:seq_start + num_bases].upper()
                tag = "INS" if char == "+" else "DEL"

                # Attach the indel to the read it follows (the last entry added).
                if bases_list:
                    bases_list[-1] = f"{bases_list[-1]}_{tag}_{indel_seq}"
                else:
                    bases_list.append(f"{tag}_{indel_seq}")  # edge case: no preceding read

                i = seq_start + num_bases  # skip the whole annotation
                continue

        # '$', 'N' and any other symbol are not tallied.
        i += 1

    return bases_list


def process_bases(row):
    """Parse one mpileup row into an allele Counter plus its decoded qualities.

    Parameters
    ----------
    row : mapping
        Needs "bases", "qual" and "ref" entries (e.g. a DataFrame row).

    Returns
    -------
    list
        ``[Counter(parse_bases(bases, ref)), quality_scores]``. No quality
        filtering is applied to the counts.
    """
    quality_scores = phred_quality(row["qual"])
    ref_base = row["ref"].upper()

    # Remove every read-start marker ('^' plus the single character after it, whatever
    # that character is) before removing '$' read-end markers. This way a '$' that
    # follows a '^' is consumed together with it.
    base_calls = re.sub(r"\^.", "", row["bases"]).replace("$", "")

    parsed_bases = parse_bases(base_calls, ref_base)

    return [Counter(parsed_bases), quality_scores]


def process_folder(folder="data/genomes/alignments_paired_end"):
    """Load every pileup "counts" file under ``folder`` into a site x sample table.

    Files are found recursively. A file is used when its name contains "counts"
    and ends with "txt". Each file is read as a headerless tab-separated table
    with columns contig, position, ref, depth, bases, qual. The sample label is
    the part of the file name before the first "__".

    Returns
    -------
    pd.DataFrame
        Indexed by (contig, position), with one column per sample. Each cell
        holds the Counter returned by ``process_bases`` for that row, or NaN where
        a sample has no row for that site.

    Raises
    ------
    FileNotFoundError
        If no matching file is found under ``folder``.
    """
    per_file_counts = []

    for root, _dirs, filenames in os.walk(folder):
        for filename in filenames:
            if not ("counts" in filename and filename.endswith("txt")):
                continue

            batch = filename.split("__")[0]

            pileup = pd.read_csv(os.path.join(root, filename), sep="\t", header=None)
            pileup.columns = ["contig", "position", "ref", "depth", "bases", "qual"]

            # process_bases returns [Counter, qualities]; parse each row only once.
            allele_counts = [process_bases(row)[0] for _, row in pileup.iterrows()]

            per_file_counts.append(
                pileup[["contig", "position"]].assign(
                    sample=batch,
                    count=pd.Series(allele_counts, index=pileup.index, dtype=object),
                )
            )

    if not per_file_counts:
        raise FileNotFoundError(f"No '*counts*txt' pileup files found under {folder!r}")

    return (
        pd.concat(per_file_counts)
        .pivot(index=["contig", "position"], columns="sample", values="count")
    )

In [4]:
def freq_from_counts(allele_dict):
    """Summarise an allele-count dict (or a NaN placeholder) as depth plus A/T/G/C counts.

    Parameters
    ----------
    allele_dict : dict or float NaN
        Mapping of allele label -> count. NaN appears where a pivoted table has
        no entry for a site/sample.

    Returns
    -------
    dict
        ``{"depth", "A", "T", "G", "C"}``. "depth" is the sum of *all* values in
        ``allele_dict``, including keys other than the four nucleotides. A NaN
        input yields all zeros.
    """
    assert isinstance(allele_dict, dict) or np.isnan(allele_dict)

    if isinstance(allele_dict, dict):
        nucleotide_counts = {base: allele_dict.get(base, 0) for base in "ATGC"}
        return {"depth": sum(allele_dict.values()), **nucleotide_counts}

    return {"depth": 0, "A": 0, "T": 0, "G": 0, "C": 0}
    

def plot_allele_freqs(snp_evolution_df):
    """Draw one line per nucleotide, plus coverage if present, against the frame's index.

    Parameters
    ----------
    snp_evolution_df : pd.DataFrame
        Must have "A", "C", "T" and "G" columns. An optional "depth" column is
        drawn as an extra "Cobertura" line. The index is used as the x axis.

    Note: the y label always reads "Frecuencia alélica", even when raw counts
    are passed in.
    """
    fig, ax = plt.subplots(figsize=(20, 8))

    for nucleotide in ("A", "C", "T", "G"):
        ax.plot(snp_evolution_df.index, snp_evolution_df[nucleotide], label=f'Alelo {nucleotide}')

    if "depth" in snp_evolution_df.columns:
        ax.plot(snp_evolution_df.index, snp_evolution_df["depth"], label='Cobertura')

    ax.set_xlabel('Generaciones')
    ax.set_ylabel('Frecuencia alélica')
    ax.set_title('Distribución de alelos a lo largo de generaciones')
    ax.legend()

    plt.show()

In [5]:
# Site x sample table of allele Counters (kept as `df`).
df = process_folder()

# Long format: one row per (contig, position, sample); "value" holds the
# depth/A/T/G/C dict produced by freq_from_counts.
freq_df = df.map(freq_from_counts).melt(ignore_index=False)

# Attach (treatment, replica, generation) to every sample. Samples missing
# from the metadata fall back to (sample_name, -1, -1).
batch_mapping = get_genome_metadata(as_dataframe=False)
samples = freq_df['sample'].map(lambda sample_name: batch_mapping.get(sample_name, (sample_name, -1, -1)))
samples_df = pd.DataFrame(samples.to_list(), columns=["treatment", "replica", "generation"])

# Both frames carry a fresh 0..n-1 index here, so the concat lines rows up positionally.
freq_df = pd.concat([samples_df, freq_df.reset_index()], axis=1)

In [6]:
MINIMUM_DEPTH = 50  # generations whose depth is not above this are not plotted

# Unique (contig, position) sites in first-appearance order. Computed once so the
# slider range matches the data and each redraw does not rescan freq_df.
SNP_SITES = freq_df.drop_duplicates(subset=["contig", "position"])[["contig", "position"]]


@interact
def interactive_allele_evolution_plot(i=IntSlider(min=0, max=max(len(SNP_SITES) - 1, 0)), freq_or_count=RadioButtons(options=["cuentas", "frecuencia"])):
    """Plot allele counts or frequencies across generations for the ``i``-th site.

    Only rows with treatment "MS" and replica 3 are shown.

    Parameters
    ----------
    i : int
        Row of ``SNP_SITES`` to plot (driven by the slider).
    freq_or_count : {"cuentas", "frecuencia"}
        "frecuencia" plots A/C/T/G divided by depth. Any other value plots the
        raw counts plus the depth line.
    """
    contig, position = SNP_SITES.contig.iloc[i], SNP_SITES.position.iloc[i]

    TREATMENT, REPLICA = "MS", 3

    snp_evolution = ( freq_df
        .query("treatment == @TREATMENT and replica == @REPLICA")
        .query("contig == @contig and position == @position")
        .sort_values("generation") )

    # One row per generation. Keep the generation numbers as the index so the x
    # axis shows them; building the frame from a bare list would reset it to 0..n-1.
    snp_count_evol = pd.DataFrame(
        [[counts["A"], counts["C"], counts["T"], counts["G"], counts["depth"]]
         for counts in snp_evolution["value"]],
        columns=["A", "C", "T", "G", "depth"],
        index=pd.Index(snp_evolution["generation"].to_numpy(), name="generation"),
    )
    snp_count_evol = snp_count_evol.query("depth > @MINIMUM_DEPTH")

    # Per-nucleotide frequency = count / depth. Depth exceeds MINIMUM_DEPTH, so it is nonzero.
    snp_freq_evol = snp_count_evol[["A", "C", "T", "G"]].div(snp_count_evol["depth"], axis=0)

    plot_allele_freqs(snp_freq_evol if freq_or_count == "frecuencia" else snp_count_evol)

interactive(children=(IntSlider(value=0, description='i'), RadioButtons(description='freq_or_count', options=(…

In [7]:
def get_freq(count_dict, min_depth=20):
    """Return the major-allele frequency for one site/sample, or NaN when coverage is low.

    Parameters
    ----------
    count_dict : dict
        Must contain 'depth'. Every other value is treated as an allele count.
        The dict is not modified.
    min_depth : int, default 20
        Sites with ``depth < min_depth`` return NaN.

    Returns
    -------
    float
        ``max(allele counts) / depth``, or 0.0 if there are no allele keys;
        NaN below ``min_depth``.
    """
    assert 'depth' in count_dict

    depth = count_dict['depth']
    # Check coverage first: low-depth sites never need the allele scan.
    if depth < min_depth:
        return np.nan

    major_allele_count = max(
        (count for allele, count in count_dict.items() if allele != 'depth'),
        default=0,
    )
    return major_allele_count / depth

In [8]:
def merge_contig_and_position(df):
    """Return a copy of ``df`` with a "variant_id" column of (contig, position) tuples."""
    site_columns = df[["contig", "position"]]
    variant_ids = site_columns.apply(tuple, axis=1)
    return df.assign(variant_id=variant_ids)

# Wide table of major-allele frequencies: one row per variant (contig, position),
# one column per (treatment, replica, generation). Low-coverage cells (NaN from
# get_freq) are dropped before pivoting, so they reappear as NaN in the table.
major_allele_freq_df = (
    freq_df
    .assign(value=lambda frame: frame["value"].map(get_freq))
    .dropna(subset=["value"])
    .astype({"replica": int})
    .drop(columns="sample")
    .pivot(
        index=["contig", "position"],
        columns=["treatment", "replica", "generation"],
        values="value")
    .reset_index()
    .pipe(merge_contig_and_position)
    .set_index("variant_id")
    .drop(columns=["contig", "position"])
)

major_allele_freq_df

  .drop(["contig", "position"], axis=1) )


treatment,K,MS,MS,MS,K,K,K,K,MS,MS,...,K,MS,MS,K,K,MS,MS,K,K,K
replica,1,3,3,2,2,2,1,1,3,2,...,2,3,3,2,2,3,3,2,3,2
generation,39,67,34,27,43,14,38,63,8,4,...,57,59,17,27,46,43,50,36,7,77
variant_id,Unnamed: 1_level_3,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3
"(contig000001, 146)",1.000000,0.979592,0.795918,0.945312,1.000000,1.000000,0.725490,0.678571,0.851064,1.000000,...,1.000000,0.533632,1.000000,0.930233,1.000000,0.813953,0.888889,1.000000,1.000000,1.000000
"(contig000001, 227)",1.000000,0.989474,0.858824,0.985849,,1.000000,0.893204,0.840426,0.885714,1.000000,...,1.000000,0.672840,1.000000,0.937984,1.000000,0.857143,0.979592,1.000000,0.996000,1.000000
"(contig000001, 5306)",1.000000,0.651316,1.000000,0.642857,0.565217,0.958333,1.000000,0.535088,0.581818,1.000000,...,1.000000,1.000000,1.000000,0.523529,0.837696,0.847222,1.000000,1.000000,1.000000,0.854962
"(contig000001, 91539)",0.591176,0.516667,0.500000,0.550063,0.507463,0.522642,0.733607,0.503704,0.548209,0.545098,...,0.558763,0.684380,0.552486,0.516393,0.590840,0.519231,0.847826,0.541237,1.000000,0.507331
"(contig000001, 91557)",0.607955,0.531621,0.517073,0.531100,0.507937,0.543119,0.736842,0.523810,0.548052,0.558226,...,0.558577,0.670455,0.542328,0.523810,0.602056,0.511848,0.866667,0.522843,0.998088,0.512676
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(contig000029, 696)",1.000000,0.998117,1.000000,1.000000,1.000000,1.000000,1.000000,0.992620,1.000000,1.000000,...,0.997881,1.000000,1.000000,1.000000,0.998752,1.000000,1.000000,1.000000,0.998371,1.000000
"(contig000032, 8535)",0.994152,0.969027,0.176166,0.997727,1.000000,0.960000,1.000000,0.986755,0.400000,1.000000,...,1.000000,1.000000,1.000000,1.000000,1.000000,1.000000,0.988889,0.989796,0.998294,1.000000
"(contig000038, 58)",,0.863636,0.758621,0.918367,1.000000,1.000000,0.515152,0.681818,0.500000,1.000000,...,1.000000,0.644628,,,1.000000,,,,1.000000,
"(contig000038, 105)",,0.180000,0.842105,0.177419,0.000000,0.178571,0.659091,0.456790,0.629630,0.172414,...,0.142857,0.503205,,0.785714,0.086957,0.695652,,,0.095745,


In [9]:
FREQ_WIDE_PKL = "freq_dataframe_wide_ref1.pkl"
# Set to False to refuse overwriting an existing export.
OVERWRITE_FREQ_WIDE_PKL = True

if not OVERWRITE_FREQ_WIDE_PKL and os.path.exists(FREQ_WIDE_PKL):
    raise FileExistsError(f"{FREQ_WIDE_PKL} already exists, not overwriting it.")

# A `with` block flushes and closes the file deterministically; a bare open()
# leaves closing the handle to garbage collection.
# Only unpickle this file from a trusted location: loading a pickle can execute code.
with open(FREQ_WIDE_PKL, "wb") as pickle_file:
    pkl.dump(major_allele_freq_df, pickle_file)