In [None]:
import os
import pysam
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Load the per-sample metadata. The spreadsheet's own "Generation" and "rep" columns are
# dropped; equivalent fields are re-derived from the "samples.1" label below.
data = pd.read_excel("./data/genomes/metadata_whole_genome.xlsx", engine='openpyxl')
data = data.drop(columns=["Generation", "rep"])

# "RMF27" does not follow the "<generation>.<rep>_<treatment>" pattern, so it is relabelled
# as generation "0", rep "0", treatment "-".
is_rmf27 = data["samples.1"] == "RMF27"
data.loc[is_rmf27, "samples.1"] = "0.0_-"


def _split_sample_label(label):
    """Split a "<generation>.<rep>_<treatment>" label into [generation, rep, treatment]."""
    parts = label.split("_")
    return parts[0].split(".") + [parts[1]]


sample_details = pd.DataFrame(
    [_split_sample_label(label) for label in data["samples.1"]],
    columns=["generation", "rep", "treatment"],
)
# Both frames carry the same default positional index, so a column-wise concat lines them up.
data = pd.concat([data, sample_details], axis=1)
data.head()

In [None]:
# Directory scanned for per-sample "<BGI_ID>.vcf.gz" files.
VCF_DIR = "./results/vcf/"
# os.listdir returns entries in arbitrary order; sorting keeps the sample order (and hence
# the row/column order of every matrix built from it) reproducible between runs.
files = sorted(os.listdir(VCF_DIR))

In [None]:
# Cast BGI_ID to str so it can be compared with the IDs parsed from the VCF file names.
data['BGI_ID'] = data['BGI_ID'].astype(str)

In [None]:
def compute_af(record):
    """Return the alternate-allele frequency of a VCF record, derived from its ``DP4`` field.

    ``DP4`` is read as a sequence of four read counts; the frequency is
    ``(DP4[2] + DP4[3]) / sum(DP4)``. In bcftools output the four values are the
    ref-forward, ref-reverse, alt-forward and alt-reverse read counts.

    Args:
        record: object exposing a mapping-like ``info`` attribute
            (e.g. a ``pysam.VariantRecord``).

    Returns:
        The frequency as a float, or ``None`` when the record has no ``DP4`` entry
        or all four counts are zero.
    """
    try:
        dp4 = record.info['DP4']
    except KeyError:
        # Record carries no DP4 annotation: treat it like a zero-depth site
        # instead of aborting the whole scan.
        return None
    total_depth = sum(dp4)
    if total_depth == 0:
        return None
    return (dp4[2] + dp4[3]) / total_depth

In [None]:
# Per-sample results, keyed by the sample's (generation, rep, treatment) tuple.
variant_dict = {}       # sample -> set of (chrom, pos) variant positions
variant_freq_dict = {}  # sample -> {(chrom, pos): alternate-allele frequency or None}

for file in files:

    # Only "<BGI_ID>.vcf.gz" files are samples; skip index files and anything else.
    if not file.endswith(".vcf.gz"):
        continue
    bgi_id = file.replace(".vcf.gz", "")

    sample_rows = data.loc[data.BGI_ID == bgi_id, ["generation", "rep", "treatment"]]
    if len(sample_rows) != 1:
        # No metadata row (or an ambiguous match) for this file.
        continue
    sample_info = tuple(sample_rows.iloc[0].to_list())

    vcf_path = os.path.join(VCF_DIR, file)
    try:
        vcf_file = pysam.VariantFile(vcf_path)
    except (OSError, ValueError) as err:
        # Report which file was skipped and why, then carry on with the others.
        print(f"Skipping {vcf_path}: could not open VCF ({err})")
        continue

    # Single pass over the records; the context manager closes the file handle afterwards.
    with vcf_file:
        variant_freq = {(record.chrom, record.pos): compute_af(record) for record in vcf_file.fetch()}

    variant_dict[sample_info] = set(variant_freq)
    variant_freq_dict[sample_info] = variant_freq

Compute the Jaccard index (size of the intersection divided by the size of the union) of the variant sets for every pair of samples:

In [None]:
# Pairwise Jaccard index between the variant-position sets of all samples:
# |A ∩ B| / |A ∪ B|, stored as a square matrix labelled "generation_rep_treatment".
intersection = []

for sample_i, variants_i in tqdm(variant_dict.items()):
    row = []
    for sample_j, variants_j in variant_dict.items():
        union_size = len(variants_j | variants_i)
        # Two samples without any variant have an empty union; score them as identical (1.0)
        # instead of raising ZeroDivisionError.
        row.append(len(variants_j & variants_i) / union_size if union_size else 1.0)
    intersection.append(row)

colnames = ["_".join(k) for k in variant_dict]
# The explicit reshape keeps the matrix 2-D even when no sample was loaded.
intersection_matrix = np.array(intersection, dtype=float).reshape(len(colnames), len(colnames))
intersection_df = pd.DataFrame(intersection_matrix, index=colnames, columns=colnames)

In [None]:
# All samples ranked by their similarity to sample '71_1_MS'
# (labels are "generation_rep_treatment"), most similar first.
intersection_df['71_1_MS'].sort_values(ascending=False)

In [None]:
# Reorder the similarity matrix so that samples with similar profiles sit next to each other.
from scipy.cluster.hierarchy import linkage, leaves_list
# NOTE(review): linkage receives the square similarity matrix itself, not a condensed
# distance vector, so SciPy presumably treats each row as an observation vector and
# clusters on distances between rows — confirm this is the intended input.
corr_linkage = linkage(intersection_df, method='average')
# Leaf order of the resulting dendrogram.
idx = leaves_list(corr_linkage)
# Apply the same permutation to rows and columns.
ordered_matrix = intersection_df.iloc[idx, idx]

In [None]:
# All samples ranked by their similarity to sample "0_0_-"
# (the label assigned to RMF27 when the metadata is loaded), most similar first.
ordered_matrix["0_0_-"].sort_values(ascending=False)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Reference sample and the width of the generation window plotted around it.
rep = 1
gen_ref = 49
WINDOW_LEN = 60
treatment = "MS"
label_ref = f"{gen_ref}_{rep}_{treatment}"

# Candidate labels are generated in increasing generation order and then filtered to the
# samples that exist. Sorting the label strings instead would order them lexicographically
# ("100_..." before "20_...") and scramble the x axis whenever the window spans generations
# with a different number of digits.
window_labels = [f"{gen}_{rep}_{treatment}" for gen in range(gen_ref - WINDOW_LEN // 2, gen_ref + WINDOW_LEN // 2)]
known_labels = set(ordered_matrix.index)
kk = [label for label in window_labels if label in known_labels]

fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(ordered_matrix.loc[kk, label_ref])
ax.tick_params(axis='x', labelrotation=60)
ax.set_ylabel(f"Jaccard similarity to {label_ref}")
# Mark the reference sample itself.
ax.axvline(x=label_ref, color='red', linestyle='dashed')

In [None]:
# Heat map of the cluster-ordered sample-by-sample similarity matrix.
heatmap_fig = plt.figure(figsize=(8, 6))
heatmap_ax = heatmap_fig.gca()
heatmap_ax.imshow(ordered_matrix.to_numpy())
heatmap_ax.set_title("Ordered Correlation Matrix by Similarity")
plt.show()

In [None]:
# for record in vcf_file.fetch():
#     print(f"Chromosome: {record.chrom}")
#     print(f"Position: {record.pos}")
#     print(f"Reference Allele: {record.ref}")
#     print(f"Alternative Alleles: {record.alts}")
#     print(f"Quality: {record.qual}")
#     print(f"Info: {dict(record.info)}")

# Analysis by frequency

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# Long-format table of allele frequencies: one row per (variant position, sample).
# The (chrom, pos) keys become the index and surface as level_0 / level_1 after reset_index;
# the three parts of each sample key become the generation / replica / treatment columns.
allele_freqs = (
    pd.DataFrame(variant_freq_dict)
    .melt(ignore_index=False)
    .set_axis(["generation", "replica", "treatment", "freq"], axis=1)
    .reset_index()
)
allele_freqs["generation"] = allele_freqs["generation"].astype(int)

# Keep only rows that actually have a frequency value.
kk = allele_freqs[allele_freqs["freq"].notna()]
# Preview the (level_0, level_1) key of every remaining row.
kk.apply(lambda r: (r.level_0, r.level_1), axis=1)

In [None]:
# Count rows per generation (pooled over replicas and treatments) and keep only the
# generations with at most 1000 of them.
variant_count = (
    kk.groupby('generation')[['level_0', 'level_1']]
    .count()
    .reset_index()
)
within_limit = variant_count['level_0'] <= 1000
valid_generations = variant_count.loc[within_limit, 'generation']
filtered_df = kk.loc[kk['generation'].isin(valid_generations)]
filtered_df

In [None]:
# Reshape to wide format: one row per (level_0, level_1, generation) and one frequency
# column per (replica, treatment) pair. The result has two-level column labels.
# Note that filtered_df is rebound here, so this cell presumably cannot be re-run on its
# own: the pivoted frame no longer has the 'replica', 'treatment' and 'freq' columns.
filtered_df = filtered_df.pivot(columns=['replica', 'treatment'], index=['level_0', 'level_1', 'generation'], values='freq').reset_index()

In [None]:
# Frequency columns selected by the replica labels '1', '2' and '3'
# (assumes all three replicas are present in the pivoted table).
filtered_df.loc[:, ['1', '2', '3']]

In [None]:
# Build the table used for plotting: a 'Variant' column holding (level_0, level_1) tuples,
# a 'Generation' column, and the per-(replica, treatment) frequency columns.
# NOTE: this rebinds `data`, which until here held the sample metadata.
#
# After pivot + reset_index the columns are two-level, so the key columns are addressed as
# ('level_0', '') and ('level_1', ''). Explicit keys avoid the positional `series[0]` lookup
# on a non-integer index, which recent pandas versions deprecate.
chrom_col = filtered_df[('level_0', '')]
pos_col = filtered_df[('level_1', '')]
variant_keys = pd.Series(list(zip(chrom_col, pos_col)), index=filtered_df.index)

data = [variant_keys, filtered_df.generation, filtered_df.loc[:, ['1', '2', '3']]]
data = pd.concat(data, axis=1).sort_values("generation")
data.columns = ['Variant', 'Generation'] + data.columns[2:].to_list()

In [None]:
# Inspect the assembled variant / generation frequency table.
data

In [None]:
# Number of distinct generations in which each variant appears.
generations_per_variant = data.groupby('Variant')['Generation'].nunique()
# Select the valid variants: those seen in at least 10 generations.
filtered_variants = generations_per_variant[generations_per_variant >= 10].index
data_by_variant = data.set_index("Variant")
filtered_data = data_by_variant.loc[filtered_variants].reset_index()

In [None]:
# Per-variant standard deviation of the frequency in the ('1', 'MS') column
# (replica '1', treatment 'MS') across the rows of each variant.
filtered_data.groupby('Variant')[[('1', 'MS')]].std()

In [None]:
# Rank variants by the standard deviation of their ('1', 'MS') frequency and keep the
# 100 with the largest spread.
std_per_variant = filtered_data.groupby("Variant")[[('1', 'MS')]].std().iloc[:, 0]
changing_variants = std_per_variant.sort_values(ascending=False).head(100).index
top_changing_variants_df = filtered_data.set_index("Variant").loc[changing_variants]
# Same rows with 'Variant' restored as a regular column.
freq_data = top_changing_variants_df.reset_index()

In [None]:
# Options for the variant selector: a display label per variant, plus a lookup from label
# back to the variant's components.
variants_lst = [list(x) for x in top_changing_variants_df.index.unique()]
labelled_variants = [(f'{item[0]} - {item[1]}', item) for item in sorted(variants_lst)]
display_options = [option_label for option_label, _ in labelled_variants]
value_dict = dict(labelled_variants)

In [None]:
# Every (replica, treatment) combination offered by the selector, all 'MS' entries first.
treatment_options = [
    (replica_id, treatment_name)
    for treatment_name in ('MS', 'K')
    for replica_id in ('1', '2', '3')
]
# Selector label "<replica>_<treatment>" -> (replica, treatment) key.
treatment_display_options = dict(
    (f"{replica_id}_{treatment_name}", (replica_id, treatment_name))
    for replica_id, treatment_name in treatment_options
)

In [None]:
@interact
def plot_freq_curves(
    variants=widgets.SelectMultiple(options=display_options, value=[display_options[0]]),
    treatments=widgets.SelectMultiple(options=treatment_display_options, value=[treatment_display_options['1_MS']])
):
    """Plot alternate-allele frequency against generation for the selected curves.

    One line is drawn per selected (variant, treatment) combination, using the rows of
    ``freq_data`` whose 'Variant' equals the selected variant.

    Args:
        variants: selected entries of ``display_options``; each is mapped back to its
            variant through ``value_dict``.
        treatments: selected (replica, treatment) keys, used as column labels of
            ``freq_data``.
    """
    plt.figure(figsize=(15, 6))

    # With several variants selected the treatment alone no longer identifies a curve,
    # so the variant is added to the legend label in that case.
    several_variants = len(variants) > 1

    for variant_label in variants:
        variant = tuple(value_dict[variant_label])
        # The row subset depends only on the variant, so compute it once per variant.
        subset = freq_data[freq_data['Variant'] == variant]
        generations = subset['Generation'].to_list()
        for treatment in treatments:
            curve_label = f'{variant_label} {treatment}' if several_variants else f'{treatment}'
            plt.plot(generations, subset[treatment].to_list(), marker='o', label=curve_label)

    plt.title(f'Evolución de la frecuencia del alelo alternativo {"("+variants[0]+")" if len(variants)==1 else ""}')

    plt.xlabel('Generación')
    plt.ylabel('Frecuencia del alelo alternativo')
    # Only build a legend when at least one curve was drawn.
    if variants and treatments:
        plt.legend(title='Tratamiento', loc='lower right')
    plt.grid(True)
    # Single positional argument: sets the left x limit only.
    plt.xlim(0.5)
    plt.ylim(0, 1)

    plt.show()
