In [None]:
import os
import subprocess
import pandas as pd
import numpy as np
from glob import glob
import secrets

import seaborn as sns
import matplotlib.pyplot as plt

from jacksonii_analyses import vcf_parser, clustering

In [None]:
# Make sure the ADMIXTURE working directory and the figure directory exist
# before any later cell tries to write into them.
for output_dir in ("../data/var/admixture", "../data/figs"):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
def run_admixture(
    input_bed: str,
    k: int = 2,
    threads: int = 4,
    seed: "int | None" = None,
):
    """Run ADMIXTURE in cross-validation mode and save its stdout as a log.

    The command is executed in the current working directory. The captured
    stdout is written to ``<bed basename>.<k>.log`` in that directory, which
    is where later cells look for the "CV error" line.

    Parameters
    ----------
    input_bed : str
        Path to the PLINK ``.bed`` file to analyse.
    k : int
        Number of ancestral populations (K) to fit.
    threads : int
        Value passed to ADMIXTURE's ``-j`` flag (defaults to the original 4).
    seed : int or None
        Integer random seed. When ``None`` a fresh random positive integer is
        drawn so that repeated runs start from different initial values.

    Returns
    -------
    None
        On a non-zero exit status the stderr is printed and no log file is
        written; the caller is not interrupted (best-effort behaviour).
    """
    if seed is None:
        # ADMIXTURE expects an integer seed; a hexadecimal token is not one.
        seed = secrets.randbelow(2**31 - 1) + 1
    # Single quotes inside the f-string keep this valid on Python < 3.12.
    output_prefix = f"{os.path.basename(input_bed).split('.')[0]}.{k}"
    result = subprocess.run(
        [
            "admixture",
            "--cv",
            input_bed,
            str(k),
            f"-j{threads}",
            f"--seed={seed}",
        ],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        print(f"Error running admixture: {result.stderr}")
        return
    with open(f"{output_prefix}.log", "w") as f:
        f.write(result.stdout)

# Each species group paired with the hex colour used for it in the plots.
_group_colors = (
    ('A. jacksonii', "#1f77b4"),  # blue
    ('A. sp. T31', "#ff7f0e"),  # orange
    ('A. sp. jack6', "#2ca02c"),  # green
    ('A. sp. jack5', "#d62728"),  # red
    ('A. sp. jack3', "#9467bd"),  # purple
    ('A. sp. jack2', "#8c564b"),  # brown
    ('A. sp. jack1', "#e377c2"),  # pink
    ('A. sp. F11', "#7f7f7f"),  # gray
)

# Parallel lists (same order as above) plus the group -> colour lookup.
groups = [pair[0] for pair in _group_colors]
palette = [pair[1] for pair in _group_colors]
map_colors = dict(_group_colors)

Create a chromosome mapping file for plink.

In [None]:
# Build a two-column lookup (original contig name -> 1-based numeric id) and
# write it as a headerless TSV for `bcftools annotate --rename-chrs`.
vcf_path = "../data/var/filtered_variants.vcf.gz"
chr_pos_data = vcf_parser.vcf_to_chr_pos_df(vcf_path)

# Contigs keep their order of first appearance in the parsed table.
unique_chroms = chr_pos_data["chrom"].unique()
chromosome_mapping = pd.DataFrame({
    "id": [str(number) for number in range(1, len(unique_chroms) + 1)],
    "chrom": unique_chroms,
})

chromosome_mapping.loc[:, ["chrom", "id"]].to_csv(
    "../data/var/admixture/chromosome_mapping.txt",
    sep="\t",
    header=False,
    index=False,
)

# Highest numeric id; passed to plink's --chr-set in a later cell.
max_chrom_id = chromosome_mapping["id"].astype(int).max()

In [None]:
# Rewrite the contig names in the VCF using the mapping file written above and
# store the result as a bgzip-compressed VCF; a non-zero exit raises.
bcftools_rename_cmd = [
    "bcftools", "annotate",
    "--rename-chrs", "../data/var/admixture/chromosome_mapping.txt",
    vcf_path,
    "-o", "../data/var/admixture/filtered_variants_chrom_renamed.vcf.gz",
    "-Oz",
]
subprocess.run(bcftools_rename_cmd, check=True)

Run plink to convert the VCF file to the binary PLINK file set that `admixture` reads (`.bed`, `.bim`, `.fam`).

In [None]:
# The relative paths below are resolved from the notebooks directory.
os.chdir("/workspace/notebooks")

# Convert the renamed VCF into a .bed/.bim/.fam set, keeping only biallelic
# SNPs; --chr-set tells plink how many numeric chromosome codes to accept.
plink_cmd = [
    "plink",
    "--vcf", "../data/var/admixture/filtered_variants_chrom_renamed.vcf.gz",
    "--make-bed",
    "--biallelic-only",
    "--snps-only",
    "--chr-set", f"{max_chrom_id}",
    "--out", "../data/var/admixture/filtered",
    "--allow-extra-chr",
]
subprocess.run(plink_cmd, check=True)

Run `admixture` in cross-validation mode (`n-folds=5`) for every K from 2 to 10, repeating the whole sweep in five independently seeded iterations.

In [None]:
# Run the full K sweep several times, each in its own iteration_<n> directory,
# because ADMIXTURE writes its output files into the current working directory.
#
# The working directory is set explicitly (same anchor as the plink cell)
# instead of guessing from a substring of os.getcwd(): after an interrupted
# run the kernel may still sit inside iteration_<n>, and the old guard would
# then create nested iteration directories.
notebook_dir = "/workspace/notebooks"
os.chdir(notebook_dir)
admixture_dir = os.path.abspath("../data/var/admixture")

try:
    for iteration in range(1, 6):
        iteration_dir = os.path.join(admixture_dir, f"iteration_{iteration}")
        os.makedirs(iteration_dir, exist_ok=True)
        os.chdir(iteration_dir)
        for k in range(2, 11):
            # Relative to iteration_dir, the .bed file lives one level up.
            run_admixture("../filtered.bed", k=k)
finally:
    # Always return to the notebooks directory, even if a run raised, so the
    # relative paths used by later cells keep resolving.
    os.chdir(notebook_dir)

In [None]:
# Load the per-sample population table through the project helper.
# NOTE(review): later cells merge this object on a `sample` column and read a
# `populations_clean` column from the result — confirm against
# clustering.read_populations.
samples_file = "../data/samples/populations.txt"
populations = clustering.read_populations(samples_file)
# Bare last expression so the notebook renders the loaded table.
populations

In [None]:
def parse_admixture_log(logfile):
    """Extract K, iteration and CV error from one ADMIXTURE log file.

    The log is expected at ``.../iteration_<n>/<prefix>.<k>.log``. K and the
    iteration number are read from the file and directory names, so the
    result does not depend on how the leading part of the path is written.

    Returns
    -------
    dict or None
        ``{"k": int, "CV error": float, "iteration": int}``, or ``None`` when
        the log contains no "CV error" line.
    """
    k = int(os.path.basename(logfile).split(".")[-2])
    iteration = int(os.path.basename(os.path.dirname(logfile)).split("_")[-1])
    with open(logfile, "r") as f:
        for line in f:
            if "CV error" in line:
                # The CV error is the last whitespace-separated token.
                return {"k": k, "CV error": float(line.split()[-1]), "iteration": iteration}
    return None


admixture_logfiles = sorted(glob("../data/var/admixture/iteration_*/*.log"))

cv_records = []
for logfile in admixture_logfiles:
    record = parse_admixture_log(logfile)
    if record is None:
        # Skipping keeps the columns aligned; previously a log without a CV
        # line produced lists of different lengths and a ValueError.
        print(f"No CV error found in {logfile}; skipping")
        continue
    cv_records.append(record)

# Sort by K, then iteration, so the row order does not depend on glob order.
cv_df = (
    pd.DataFrame(cv_records, columns=["k", "CV error", "iteration"])
    .sort_values(["k", "iteration"])
    .reset_index(drop=True)
)

# Parallel lists kept for any code that still uses them.
k_values = cv_df["k"].tolist()
cv_values = cv_df["CV error"].tolist()
iteration_values = cv_df["iteration"].tolist()

cv_df.head()

In [None]:
# Distribution of cross-validation error per K across the repeated runs:
# grey boxes show the spread, black points/line the mean for each K.
# The iteration count in the title is read from the data so it cannot drift
# from the number of runs actually performed.
n_iterations = cv_df["iteration"].nunique()

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(
    x="k",
    y="CV error",
    data=cv_df,
    dodge=True,
    width=0.4,
    color="lightgray",
    ax=ax,
)
sns.pointplot(
    x="k",
    y="CV error",
    data=cv_df,
    color="black",
    markers="o",
    errorbar=None,
    dodge=True,
    ax=ax,
)
ax.set_xlabel("Number of clusters (K)")
ax.set_ylabel("Cross-validation error")
ax.set_title(f"ADMIXTURE Cross-validation Error by K ({n_iterations} iterations)")
fig.savefig("../data/figs/admixture_cv_error.svg")
plt.show()

In [None]:
# Pick the single run (one K in one iteration) with the lowest CV error.
lowest_cv_label = cv_df["CV error"].idxmin()
best_k_df = cv_df.loc[lowest_cv_label]

# The row comes back as a Series, so cast the integer fields explicitly.
best_k, best_iter = int(best_k_df["k"]), int(best_k_df["iteration"])
print(f"Best K: {best_k} in iteration {best_iter} with CV error: {best_k_df['CV error']}")

In [None]:
# Locate the Q matrix of the best run and the .fam file with the sample order.
best_admixture_matrix = f"filtered.{best_k}.Q"
best_admixture_Q_file = os.path.join(f"../data/var/admixture/iteration_{best_iter}", best_admixture_matrix)
admixture_fam_file = os.path.join("../data/var/admixture", "filtered.fam")

q = pd.read_csv(best_admixture_Q_file, header=None, sep=" ")
fam = pd.read_csv(admixture_fam_file, header=None, sep=" ")

# Remember the ancestry columns before any annotation columns are appended.
ancestry_columns = list(q.columns)

# Q rows follow the .fam row order; the first .fam column is used as the id.
q["sample"] = fam[0]

# Dominant cluster of every sample and the proportion assigned to it.
q["max_cluster"] = q[ancestry_columns].idxmax(axis=1)
q["max_value"] = q[ancestry_columns].max(axis=1)

# Group samples by dominant cluster, strongest assignment first.
q_ordered = q.sort_values(by=["max_cluster", "max_value"], ascending=[True, False])

# Long format: one row per (sample, cluster) pair.
q_melt = pd.melt(
    q_ordered.drop(columns=["max_value"]),
    id_vars=["sample", "max_cluster"],
    var_name="Cluster",
    value_name="Ancestry",
)

# Keep the sample order of q_ordered, then attach the population labels.
q_melt["sample"] = pd.Categorical(q_melt["sample"], categories=q_ordered["sample"], ordered=True)
q_melt = q_melt.merge(populations, how="left", on="sample")

In [None]:
# Ancestry columns of the Q matrix (integer labels, or digit strings).
clusters = [col for col in q_ordered.columns if isinstance(col, int) or (isinstance(col, str) and col.isdigit())]
n_clusters = len(clusters)

# Choose a grayscale-friendly palette and swap cluster 5 and 6 colors
colors = sns.color_palette("colorblind", n_clusters)
greys = sns.color_palette("light:#000000", n_colors=n_clusters)
if n_clusters >= 6:
    colors[4], colors[5] = colors[5], colors[4]
    greys[4], greys[5] = greys[5], greys[4]

# Name each cluster after the most common `populations_clean` label among the
# samples that carry more than 50 % of their ancestry in that cluster.
cluster_names = []
for i in range(n_clusters):
    name = f"Cluster {i+1}"  # fallback when no population label is available
    cluster_rows = q_melt[q_melt["Cluster"] == i]
    if "populations_clean" in cluster_rows.columns:
        majority_labels = cluster_rows.loc[cluster_rows["Ancestry"] > 0.5, "populations_clean"].mode()
        # mode() ignores NaN, so it can be empty even when rows were selected;
        # indexing it unconditionally would raise IndexError.
        if not majority_labels.empty:
            name = majority_labels.iloc[0]
    cluster_names.append(name)

# When only jack6 received a cluster of its own, label it as the merged group.
if "A. sp. jack6" in cluster_names and "A. sp. jack5" not in cluster_names:
    cluster_names[cluster_names.index("A. sp. jack6")] = "A. sp. jack5/6"

# Two clusters dominated by the same population would otherwise produce
# duplicate column names and a wrongly shaped `data` matrix, so later
# repeats get a numeric suffix.
name_counts = {}
for position, name in enumerate(cluster_names):
    name_counts[name] = name_counts.get(name, 0) + 1
    if name_counts[name] > 1:
        cluster_names[position] = f"{name} ({name_counts[name]})"

q_named = q_ordered.rename(columns={i: cluster_names[i] for i in range(len(cluster_names))})
q_named = q_named.merge(populations, on="sample", how="left")
q_named = q_named.sort_values(by=["populations_clean", "sample"]).reset_index(drop=True)

# One row per sample, one column per cluster (columns in sorted-name order).
data = q_named[sorted(cluster_names)].to_numpy()

# Sample labels must follow the same row order as `data`; taking them from
# q_ordered (sorted by cluster) would attach the wrong name to each bar.
samples = q_named["sample"].tolist()


In [None]:
# The merged jack5/6 cluster reuses the colour assigned to jack6.
map_colors["A. sp. jack5/6"] = map_colors["A. sp. jack6"]

# Clusters without a known population fall back to light grey.
cluster_hex_colors = [map_colors.get(pop, "#cccccc") for pop in cluster_names]

# One row per cluster, ordered by name.
cluster_colors_df = (
    pd.DataFrame({
        "populations": cluster_names,
        "color_greys": greys,
        "color": cluster_hex_colors,
    })
    .sort_values(by="populations")
    .reset_index(drop=True)
)

In [None]:
# Stacked bar chart of ancestry proportions, one bar per individual (colour).
# Labels are taken from q_named, the frame `data` was built from, so each tick
# label is guaranteed to belong to the bar drawn above it.
bar_labels = q_named["sample"].tolist()
x_positions = np.arange(len(bar_labels))

fig, ax = plt.subplots(figsize=(12, 4))
bottom = np.zeros(len(bar_labels))
# cluster_colors_df rows and `data` columns are both in sorted-name order.
for column_idx, (_label, row) in enumerate(cluster_colors_df.iterrows()):
    ax.bar(
        x_positions,
        data[:, column_idx],
        bottom=bottom,
        color=row["color"],
        label=row["populations"],
        width=1.0,
        edgecolor="none"
    )
    bottom += data[:, column_idx]

ax.set_xlabel("Individuals (sorted by population)")
ax.set_ylabel("Ancestry Proportion")
ax.set_title(f"ADMIXTURE Q matrix (K={best_k})")
ax.set_xticks(x_positions)
ax.set_xticklabels(bar_labels, rotation=45, ha="right", fontsize=7)
ax.legend(title="Cluster", bbox_to_anchor=(1.01, 1), loc="upper left")
plt.tight_layout()
fig.savefig("../data/figs/admixture_qmatrix_color.svg")
fig.savefig("../data/figs/admixture_qmatrix_color.pdf")
plt.show()

In [None]:
# Stacked bar chart of ancestry proportions, one bar per individual (greys).
# Labels are taken from q_named, the frame `data` was built from, so each tick
# label is guaranteed to belong to the bar drawn above it.
bar_labels = q_named["sample"].tolist()
x_positions = np.arange(len(bar_labels))

fig, ax = plt.subplots(figsize=(12, 4))
bottom = np.zeros(len(bar_labels))
# cluster_colors_df rows and `data` columns are both in sorted-name order.
for column_idx, (_label, row) in enumerate(cluster_colors_df.iterrows()):
    ax.bar(
        x_positions,
        data[:, column_idx],
        bottom=bottom,
        color=row["color_greys"],
        label=row["populations"],
        width=1.0,
        edgecolor="none"
    )
    bottom += data[:, column_idx]

ax.set_xlabel("Individuals (sorted by population)")
ax.set_ylabel("Ancestry Proportion")
ax.set_title(f"ADMIXTURE Q matrix (K={best_k})")
ax.set_xticks(x_positions)
ax.set_xticklabels(bar_labels, rotation=45, ha="right", fontsize=7)
ax.legend(title="Cluster", bbox_to_anchor=(1.01, 1), loc="upper left")
plt.tight_layout()
fig.savefig("../data/figs/admixture_qmatrix_grey.svg")
fig.savefig("../data/figs/admixture_qmatrix_grey.pdf")
plt.show()

In [None]:
# Individuals counted as admixed: at least two clusters each contribute more
# than `admixture_threshold` of the ancestry.
admixture_threshold = 0.1

# Number of clusters above the threshold for every individual.
clusters_above_threshold = (q_named[cluster_names] > admixture_threshold).sum(axis=1)

admixture_df = (
    q_named.loc[clusters_above_threshold >= 2, ["sample", "populations_clean"] + cluster_names]
    .sort_values(by="populations_clean")
    .reset_index(drop=True)
)
admixture_df.to_csv("../data/var/admixture/admixed_individuals.csv", index=False)

print(f"Admixed individuals with at least two clusters above threshold: {admixture_threshold}")
admixture_df

In [None]:
# Look up the SRA run metadata for the admixed individuals found above.
samples_file = "../data/samples/SraRunTable.csv"
full_sample_data = pd.read_csv(samples_file)

# Rows of the run table whose Run accession is one of the admixed samples.
is_admixed_run = full_sample_data["Run"].isin(admixture_df["sample"])
admixed_sample_set = full_sample_data[is_admixed_run]

metadata_columns = ["Run", "strain", "collection_method", "Collection_Date", "geo_loc_name"]
admixed_sample_set[metadata_columns]

In [None]:
# Short codes for the "country: region" strings found in geo_loc_name.
# NOTE(review): the key "Canada: Qubec" presumably mirrors the spelling in the
# run table — confirm before correcting it.
abbreviations = {
    "USA: West Virginia": "US-WV",
    "Canada: Ontario": "CA-ON",
    "USA: Missouri": "US-MO",
    "USA: Texas": "US-TX",
    "Mexico: Jalisco": "MX-JAL",
    "Mexico: Michoacan": "MX-MICH",
    "Mexico: Nayarit": "MX-NAY",
    "Mexico: Chihuahua": "MX-CHIH",
    "Mexico: Oaxaca": "MX-OAX",
    "Mexico: Guerrero": "MX-GRO",
    "Mexico: Hidalgo": "MX-HGO",
    "Mexico: Chiapas": "MX-CHIS",
    "USA: Tennessee": "US-TN",
    "USA: South Carolina": "US-SC",
    "USA: Indiana": "US-IN",
    "USA: Arkansas": "US-AR",
    "USA: North Carolina": "US-NC",
    "Canada: Qubec": "CA-QC",
    "USA: Florida": "US-FL",
    "USA: Pennsylvania": "US-PA",
    "USA: Connecticut": "US-CT",
    "USA: New York": "US-NY",
    "USA: Massachusetts": "US-MA",
}


def _region_only(location):
    """Keep the text before the first escaped comma; pass non-strings through."""
    if isinstance(location, str):
        return location.split("\\,")[0]
    return location


def _abbreviate(location):
    """Return the short code for a known location, otherwise the value itself."""
    return abbreviations.get(location, location)


short_loc = full_sample_data[["Run", "geo_loc_name"]].copy()
short_loc["geo_loc_name"] = short_loc["geo_loc_name"].map(_region_only)
short_loc["abbreviated_loc"] = short_loc["geo_loc_name"].map(_abbreviate)

# Combined "<run> <location code>" label.
short_loc["sample_loc"] = short_loc["Run"] + " " + short_loc["abbreviated_loc"]
short_loc = short_loc.rename(columns={"Run": "sample"})

short_loc.to_csv("../data/samples/samples_short_loc.csv", header=True, index=False)
short_loc