In [1]:
import pathlib

import pandas as pd
import polars as pl
import tqdm

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
direct_paths = list(pathlib.Path("data/gwas").glob("plink.b_*.glm.linear.zst"))

len(direct_paths)

1238

In [3]:
igwas_paths = list(pathlib.Path("data/igwas").glob("pca_*.tsv.zst"))

igwas_paths

[PosixPath('data/igwas/pca_0.25.tsv.zst'),
 PosixPath('data/igwas/pca_1.0.tsv.zst'),
 PosixPath('data/igwas/pca_0.75.tsv.zst'),
 PosixPath('data/igwas/pca_0.9.tsv.zst'),
 PosixPath('data/igwas/pca_0.1.tsv.zst'),
 PosixPath('data/igwas/pca_0.5.tsv.zst')]

In [4]:
sampled_variant_df = (
    pl.read_csv(direct_paths[0], separator="\t", columns=["ID", "P"])
    .rename({"ID": "variant_id", "P": "p_value"})
    .with_columns(quantile=pl.col("p_value").qcut(10))
    .filter(pl.int_range(0, pl.len()).shuffle(seed=0).over("quantile") < 1000)
)

sampled_variant_df.head(2)

variant_id,p_value,quantile
str,f64,cat
"""1:1467485""",0.949241,"""(0.8991141, in…"
"""1:1472047""",0.297659,"""(0.201307, 0.2…"


In [None]:
for igwas_path in igwas_paths:
    fit_quality = list()
    full_sampled_df = None
    
    fraction = float(igwas_path.stem.replace(".tsv", "").replace("pca_", ""))
    igwas_df = pl.read_csv(igwas_path, separator="\t", columns=["phenotype_id", "variant_id", "p_value"])

    for gwas_path in tqdm.tqdm(sorted(direct_paths)):
        phenotype_id = gwas_path.stem.replace(".glm.linear", "").replace("plink.", "")
        direct_df = (
            pl.read_csv(gwas_path, separator="\t", columns=["ID", "P"])
            .rename({"ID": "variant_id", "P": "p_value"})
        )
    
        merged_df = (
            igwas_df
            .filter(pl.col("phenotype_id").eq(phenotype_id))
            .join(direct_df, on=["variant_id"], suffix="_direct")
            .with_columns(p_value_direct=-pl.col("p_value_direct").log(10))
        )
        
        r2 = (
            merged_df
            .select(
                rss=(pl.col("p_value") - pl.col("p_value_direct")).pow(2).sum(),
                tss=(pl.col("p_value_direct") - pl.col("p_value_direct").mean()).pow(2).sum(),
            )
            .select(r2=1 - pl.col("rss") / pl.col("tss"))
            .item()
        )
        result_row = {
            "phenotype_id": phenotype_id,
            "fraction_pcs": fraction,
            "r2": r2,
        }
        fit_quality.append(result_row)

        sampled_df = merged_df.join(sampled_variant_df, on="variant_id").with_columns(fraction=fraction)
        if full_sampled_df is None:
            full_sampled_df = sampled_df
        else:
            full_sampled_df = pl.concat([full_sampled_df, sampled_df])

    fit_quality_df = pl.DataFrame(fit_quality)
    
    fit_quality_df.write_parquet(f"plot_data/fit_quality_{fraction}.parquet")
    full_sampled_df.write_parquet(f"plot_data/sampled_variants_pvalues_{fraction}.parquet")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 1238/1238 [3:37:51<00:00, 10.56s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 1238/1238 [3:36:35<00:00, 10.50s/it]
 51%|█████████████████████████████████████████████████▊                                               | 636/1238 [1:50:58<1:42:38, 10.23s/it]