In [7]:
import os
import re
import shutil
from functools import reduce
from pathlib import Path
from collections import OrderedDict
from typing import Tuple, Dict, List

from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, MapType, StructType, ArrayType, FloatType, IntegerType, StringType

import hail
import pandas as pd
import numpy as np
from pyarrow import csv, parquet, array
from google.cloud import storage

from data_pipeline.datasets.tob import helpers

# Service-account key for the Google Cloud client libraries. An existing
# GOOGLE_APPLICATION_CREDENTIALS in the environment takes precedence; the path below is
# only a fallback for the original developer machine.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/daniel/Library/Mobile Documents/com~apple~CloudDocs/Documents/Work/Garvan/keys/tob-wgs-browser-browser-dev-sa.json",
)


In [3]:
# Create (or reuse) a local PySpark session: one local worker thread ("local[1]"),
# 16G of driver memory, and case-sensitive column name resolution.
spark = (
    SparkSession.builder
    .master("local[1]")
    .appName("tob-wgs")
    .config("spark.driver.memory", "16G")
    .config("spark.sql.caseSensitive", True)
    .getOrCreate()
)

22/04/21 15:48:22 WARN Utils: Your hostname, macbook.local resolves to a loopback address: 127.0.0.1; using 192.168.0.3 instead (on interface en0)
22/04/21 15:48:22 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
22/04/21 15:48:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/21 15:48:23 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


# Helper functions

In [4]:
def columns_from_file(path, bucket):
    """Return the column names from the header (first) line of a tab-separated file.

    ``path`` is first mapped through ``resolve_path(path, bucket)`` and the result is
    opened as a local file. Each name is whitespace-stripped (which also removes the
    trailing newline) and empty names, e.g. from a trailing tab, are dropped.
    """
    with open(resolve_path(path, bucket), 'r') as handle:
        header_line = handle.readline()

    stripped_names = (name.strip() for name in header_line.split("\t"))
    return [name for name in stripped_names if name]


def resolve_path(path, bucket=None, local_root="/Users/daniel/buckets/cpg/cpg-tob-wgs-browser-dev/"):
    """Map a ``gs://`` URI to either a bucket-relative object name or a local mirror path.

    Args:
        path: A ``gs://<bucket>/...`` URI, or an already relative/local path.
        bucket: Optional GCS bucket object (only its ``name`` attribute is read). When
            given, the ``gs://<bucket.name>/`` prefix is stripped and the remainder returned.
        local_root: Directory (with trailing slash) that mirrors the project bucket
            locally. Used only when ``bucket`` is falsy; defaults to the previously
            hard-coded location so existing callers behave the same.

    Returns:
        The resolved path as a string.
    """
    if bucket:
        return path.replace(f"gs://{bucket.name}/", "")

    # Strip the project bucket's URI prefix, then anchor the remainder under the local
    # mirror unless the path already points inside it.
    path = path.replace(f"gs://{helpers.get_gcp_bucket_name()}/", "")
    if local_root not in path:
        return local_root + path
    return path

# Table functions

In [6]:
def read_table(path, bucket, row_keys, verbose=False, annotations=None) -> hail.Table:
    """Import a tab-delimited file as a Hail Table keyed by ``row_keys``.

    Columns named in ``row_keys`` are imported as strings; every other header column is
    imported as ``hail.tfloat``.

    Args:
        path: File location, resolved with ``resolve_path(path, bucket)``.
        bucket: Optional GCS bucket passed through to ``resolve_path``.
        row_keys: Column names used as string fields and as the table key.
        verbose: If true, print the path being loaded.
        annotations: Optional mapping of new field name -> Hail expression, applied with
            ``Table.annotate`` before keying.

    Returns:
        The imported table, keyed by ``row_keys``.
    """
    full_path = resolve_path(path, bucket)
    # Reuse the shared header parser instead of re-implementing it here.
    columns = columns_from_file(path, bucket)
    row_fields = OrderedDict(
        (col, hail.tstr if col in row_keys else hail.tfloat) for col in columns
    )

    if verbose:
        print(f"Loading from path '{full_path}'")

    table = hail.import_table(full_path, types=row_fields, delimiter="\t")
    if annotations:
        table = table.annotate(**annotations)

    return table.key_by(*row_keys)


def merge_tables(tables, unify=True) -> hail.Table:
    """Row-wise union of ``tables``, folded left to right with ``Table.union``.

    ``unify`` is forwarded to every ``union`` call. An empty ``tables`` raises IndexError.
    """
    merged = tables[0]
    for other in tables[1:]:
        merged = merged.union(other, unify=unify)
    return merged

# Matrix table functions

In [7]:
def read_matrix_table(path, bucket, row_keys, verbose=False, annotations=None) -> hail.MatrixTable:
    """Import a tab-delimited file as a Hail MatrixTable.

    Header columns named in ``row_keys`` become string row fields (and the row key); every
    remaining column is imported as an entry column holding ``hail.tfloat`` values.

    Args:
        path: File location, resolved with ``resolve_path(path, bucket)``.
        bucket: Optional GCS bucket passed through to ``resolve_path``.
        row_keys: Column names used as row fields and row key. NOTE(review): Hail requires
            row-field columns to precede all entry columns in the file; assumed true here.
        verbose: If true, print the path being loaded.
        annotations: Optional mapping of new row field name -> Hail expression.

    Returns:
        The imported matrix table.
    """
    full_path = resolve_path(path, bucket)
    columns = columns_from_file(path, bucket)

    # import_matrix_table treats only the listed row_fields as row fields and every other
    # column as an entry of ``entry_type``; listing all columns here would leave no columns
    # for entries.
    row_fields = OrderedDict((col, hail.tstr) for col in columns if col in row_keys)

    if verbose:
        print(f"Loading from path '{full_path}'")

    matrix_table = hail.import_matrix_table(
        full_path,
        row_fields=row_fields,
        row_key=row_keys,
        entry_type=hail.tfloat,
        delimiter="\t",
    )

    if annotations:
        # MatrixTable has no ``annotate`` method; the annotations are added as row fields,
        # which is what the original row_fields bookkeeping intended.
        matrix_table = matrix_table.annotate_rows(**annotations)

    return matrix_table


def merge_matrix_tables(tables, row_join_type="outer") -> hail.MatrixTable:
    """Column-wise union of ``tables``, folded left to right with ``MatrixTable.union_cols``.

    ``row_join_type`` is forwarded to every ``union_cols`` call. An empty ``tables``
    raises IndexError.
    """
    combined = tables[0]
    for next_table in tables[1:]:
        combined = combined.union_cols(next_table, row_join_type=row_join_type)
    return combined

# Data processing

In [5]:
client = storage.Client()
bucket = client.bucket(helpers.get_gcp_bucket_name())

# List the bucket once and filter locally; every list_blobs() call pages through all objects.
blob_names = [blob.name for blob in bucket.list_blobs()]

# Escape the path prefix so it is matched literally rather than as a regular expression.
input_prefix = re.escape(helpers.build_analaysis_input_path(absolute_path=False))

pattern = rf"{input_prefix}/Genotypes/genotype_(.*)\.tsv"
genotype_files = [name for name in blob_names if re.search(pattern, name)]
print(genotype_files)

# Data for a cell type is identical across the per-chromosome files, so only the chr1 file
# is loaded to save time and memory.
pattern = rf"{input_prefix}/Residuals/(.*)_chr1_log_residuals\.tsv"
expression_files = [name for name in blob_names if re.search(pattern, name)]
print(expression_files)

# Cell type = file-name prefix before the first '_', e.g. 'nk' from '.../nk_chr1_log_residuals.tsv'.
cell_types = [p.split('/')[-1].split('_')[0] for p in expression_files]

['full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr1.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr10.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr11.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr12.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr13.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr14.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr15.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr16.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr17.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr18.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr19.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr2.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr20.tsv', 'full_data/grch37/analysis_output/kccg/Genotypes/genotype_chr21.tsv', 'full_data/grch37/ana

# Expression summary pre-compute
Pre-compute the expression histogram and summary statistics for each gene in each cell type.

In [5]:
def read_dataframe(path: str, index_col="sampleid") -> pd.DataFrame:
    """Read a tab-separated file with a header row and index it by ``index_col``.

    ``index_col=False`` stops pandas from inferring an index; the named column is then
    moved into the index explicitly.
    """
    return pd.read_table(path, header=0, index_col=False, sep="\t").set_index(index_col)


def generate_summary_statistics(df: pd.DataFrame, bins: int = 30, range: Tuple[float, float] = (-1, 1)) -> pd.DataFrame:
    """Compute a histogram and summary statistics independently for every row of ``df``.

    Args:
        df: Numeric frame; each row is summarised on its own (the caller passes genes as
            rows and samples as columns).
        bins: Number of histogram bins.
        range: (low, high) histogram range; values outside it are not counted. The name
            shadows the builtin but is kept for keyword-argument compatibility.

    Returns:
        DataFrame indexed like ``df`` with object columns ``histogram`` (dicts from
        ``compute_histogram``) and ``statistics`` (dicts with min, max, median, mean, q1,
        q3, iqr, iqr_min, iqr_max).
    """
    histograms = df.apply(lambda row: compute_histogram(row, bins, range), axis=1)
    stats = df.apply(_summary_statistics, axis=1)

    return pd.DataFrame(
        {
            "histogram": histograms,
            "statistics": stats,
        },
    )


def _summary_statistics(x) -> Dict[str, float]:
    """Min/max/median/mean, quartiles, IQR and 1.5*IQR fences of ``x`` as Python floats."""
    # Each quartile is computed once instead of being recomputed for every derived value.
    q1 = np.quantile(x, q=0.25)
    q3 = np.quantile(x, q=0.75)
    iqr = q3 - q1
    return {
        "min": float(np.min(x)),
        "max": float(np.max(x)),
        "median": float(np.median(x)),
        "mean": float(np.mean(x)),
        "q1": float(q1),
        "q3": float(q3),
        "iqr": float(iqr),
        "iqr_min": float(q1 - 1.5 * iqr),
        "iqr_max": float(q3 + 1.5 * iqr),
    }


def compute_histogram(data: pd.Series, bins: int = 30, range: Tuple[float, float] = (-1, 1)) -> Dict[str, List[int]]:
    """Bin ``data`` into a fixed-range histogram expressed with plain Python numbers.

    Values outside ``range`` are ignored (numpy.histogram semantics).

    Returns:
        ``{"counts": [...bins ints], "bin_edges": [...bins + 1 floats]}``.
    """
    bin_counts, edges = np.histogram(data, bins=bins, range=range)
    return {
        "counts": list(map(int, bin_counts)),
        "bin_edges": list(map(float, edges)),
    }


def merge_rows(group):
    """Collapse one gene's per-cell-type rows into a single row.

    Args:
        group: DataFrame with columns gene_id, gene_symbol, cell_type_id, histogram and
            statistics. gene_id/gene_symbol are taken from the first row.

    Returns:
        One-row DataFrame with gene_id, gene_symbol and ``cell_type_ids``: a dict mapping
        each cell_type_id to ``{"histogram": ..., "statistics": ...}`` (later rows win on
        duplicate cell types).
    """
    per_cell_type = {}
    for _, record in group.iterrows():
        per_cell_type[record["cell_type_id"]] = {
            "histogram": record["histogram"],
            "statistics": record["statistics"],
        }

    first_gene_id = group["gene_id"].values[0]
    first_gene_symbol = group["gene_symbol"].values[0]
    return pd.DataFrame({
        "gene_id": [first_gene_id],
        "gene_symbol": [first_gene_symbol],
        "cell_type_ids": [per_cell_type],
    })

In [None]:
schema = StructType([
    StructField(name="gene_id", dataType=StringType(), nullable=False),
    StructField(name="gene_symbol", dataType=StringType(), nullable=False),
    StructField(name="cell_type_id", dataType=StringType(), nullable=False),
    StructField(
        name="histogram",
        dataType=StructType([
            StructField(name="counts", dataType=ArrayType(IntegerType(), containsNull=False), nullable=False),
            StructField(name="bin_edges", dataType=ArrayType(FloatType(), containsNull=False), nullable=False),
        ]),
        nullable=False,
    ),
    StructField(
        name="statistics",
        dataType=StructType([
            StructField(name="min", dataType=FloatType(), nullable=False),
            StructField(name="max", dataType=FloatType(), nullable=False),
            StructField(name="mean", dataType=FloatType(), nullable=False),
            StructField(name="median", dataType=FloatType(), nullable=False),
            StructField(name="q1", dataType=FloatType(), nullable=False),
            StructField(name="q3", dataType=FloatType(), nullable=False),
            StructField(name="iqr", dataType=FloatType(), nullable=False),
            StructField(name="iqr_min", dataType=FloatType(), nullable=False),
            StructField(name="iqr_max", dataType=FloatType(), nullable=False),
        ]),
        nullable=False,
    )
])

# All cell types are written into one Parquet dataset, partitioned by cell_type_id.
expression_summary_path = "/Users/daniel/Desktop/expression_summary.parquet"
is_first_write = True

for cell_type in cell_types:
    print(f"Processing files for cell type '{cell_type}'")

    # Match the file-name prefix "<cell_type>_": a plain substring test would also select
    # other cell types whose names contain this one (e.g. 'nkr_...' files for 'nk').
    paths = [
        resolve_path(p, bucket=None)
        for p in expression_files
        if p.split('/')[-1].startswith(f"{cell_type}_")
    ]
    if not paths:
        continue

    # read_dataframe indexes rows by sample; transpose so rows are genes and columns samples.
    df: pd.DataFrame = pd.concat([read_dataframe(p) for p in paths], axis=0, join="outer").transpose()

    aggregate_df: pd.DataFrame = generate_summary_statistics(df)
    aggregate_df["gene_id"] = df.index  # TODO: set ensembl gene ids
    aggregate_df["gene_symbol"] = df.index
    aggregate_df["cell_type_id"] = cell_type
    aggregate_df = aggregate_df[["gene_id", "gene_symbol", "cell_type_id", "histogram", "statistics"]]

    data = spark.createDataFrame(data=aggregate_df, schema=schema)
    # Overwrite on the first write of this run (so re-running starts clean) and append
    # afterwards. A flag is used instead of the loop index because skipped cell types
    # would otherwise leave the first actual write in append mode on top of stale data.
    (
        data
        .write
        .mode("overwrite" if is_first_write else "append")
        .partitionBy("cell_type_id")
        .parquet(expression_summary_path)
    )
    is_first_write = False

# TSV to Parquet
Convert the genotype and expression `tsv` files to Parquet for more efficient column-based querying.

In [25]:
for chrom in range(1, 23):
    print(f"Reading data for chrom '{chrom}'")
    file = resolve_path(
        f"{helpers.build_analaysis_input_path()}/Genotypes/genotype_chr{chrom}.tsv",
        bucket=None
    )

    # Sample IDs stay strings; every other column is a per-variant genotype value stored as int8.
    types = {
        col: 'str' if col == "sampleid" else 'int8'
        for col in columns_from_file(file, bucket=None)
    }

    table = csv.read_csv(
        file,
        # block_size is a byte count and must be an integer.
        read_options=csv.ReadOptions(block_size=7_000_000),
        parse_options=csv.ParseOptions(delimiter="\t"),
        convert_options=csv.ConvertOptions(column_types=types)
    )
    # Variant IDs such as '22:17302763_C' become '22_17302763_C'.
    table = table.rename_columns([c.replace(":", "_") for c in table.column_names])

    print(f"Writing to parquet for chrom '{chrom}'")
    root_path = Path(f"/Users/daniel/Desktop/genotypes/genotypes_chr{chrom}.parquet")

    # By default write_to_dataset adds new, uniquely named part files next to existing ones,
    # so clear previous output first; otherwise re-running this cell duplicates every row.
    shutil.rmtree(root_path, ignore_errors=True)

    # Write a dataset and collect metadata information of all written files
    metadata_collector = []
    parquet.write_to_dataset(table, root_path, metadata_collector=metadata_collector)

    # Write the ``_common_metadata`` parquet file without row groups statistics
    parquet.write_metadata(table.schema, root_path / '_common_metadata')

    # Write the ``_metadata`` parquet file with row groups statistics of all files
    parquet.write_metadata(
        table.schema,
        root_path / '_metadata',
        metadata_collector=metadata_collector
    )

Reading data for chrom '1'
Writing to parquet for chrom '1'
Reading data for chrom '2'
Writing to parquet for chrom '2'
Reading data for chrom '3'
Writing to parquet for chrom '3'
Reading data for chrom '4'
Writing to parquet for chrom '4'
Reading data for chrom '5'
Writing to parquet for chrom '5'
Reading data for chrom '6'
Writing to parquet for chrom '6'
Reading data for chrom '7'
Writing to parquet for chrom '7'
Reading data for chrom '8'
Writing to parquet for chrom '8'
Reading data for chrom '9'
Writing to parquet for chrom '9'
Reading data for chrom '10'
Writing to parquet for chrom '10'
Reading data for chrom '11'
Writing to parquet for chrom '11'
Reading data for chrom '12'
Writing to parquet for chrom '12'
Reading data for chrom '13'
Writing to parquet for chrom '13'
Reading data for chrom '14'
Writing to parquet for chrom '14'
Reading data for chrom '15'
Writing to parquet for chrom '15'
Reading data for chrom '16'
Writing to parquet for chrom '16'
Reading data for chrom '17

In [26]:
for cell_type in cell_types:
    print(f"Reading data for cell type '{cell_type}'")
    file = resolve_path(
        f"{helpers.build_analaysis_input_path()}/Residuals/{cell_type}_chr1_log_residuals.tsv",
        bucket=None
    )

    # Sample IDs stay strings; every other (gene) column holds float64 residuals.
    types = {
        col: 'str' if col == "sampleid" else 'float64'
        for col in columns_from_file(file, bucket=None)
    }

    table = csv.read_csv(
        file,
        parse_options=csv.ParseOptions(delimiter="\t"),
        convert_options=csv.ConvertOptions(column_types=types)
    )
    # Gene columns are upper-cased with ':', '-' and '.' replaced by '_'; the "sampleid"
    # column keeps its name so it can still be selected by that name.
    table = table.rename_columns([
        c.upper().replace(":", "_").replace("-", "_").replace(".", "_")
        if c != "sampleid" else c
        for c in table.column_names
    ])

    print(f"Writing to parquet for cell type '{cell_type}'")
    root_path = Path(f"/Users/daniel/Desktop/expression/{cell_type}.parquet")

    # By default write_to_dataset adds new, uniquely named part files next to existing ones,
    # so clear previous output first; otherwise re-running this cell duplicates every row.
    shutil.rmtree(root_path, ignore_errors=True)

    # Write a dataset and collect metadata information of all written files
    metadata_collector = []
    parquet.write_to_dataset(table, root_path, metadata_collector=metadata_collector)

    # Write the ``_common_metadata`` parquet file without row groups statistics
    parquet.write_metadata(table.schema, root_path / '_common_metadata')

    # Write the ``_metadata`` parquet file with row groups statistics of all files
    parquet.write_metadata(
        table.schema,
        root_path / '_metadata',
        metadata_collector=metadata_collector
    )

Reading data for cell type 'bin'
Writing to parquet for cell type 'bin'
Reading data for cell type 'bmem'
Writing to parquet for cell type 'bmem'
Reading data for cell type 'cd4et'
Writing to parquet for cell type 'cd4et'
Reading data for cell type 'cd4nc'
Writing to parquet for cell type 'cd4nc'
Reading data for cell type 'cd4sox4'
Writing to parquet for cell type 'cd4sox4'
Reading data for cell type 'cd8et'
Writing to parquet for cell type 'cd8et'
Reading data for cell type 'cd8nc'
Writing to parquet for cell type 'cd8nc'
Reading data for cell type 'cd8s100b'
Writing to parquet for cell type 'cd8s100b'
Reading data for cell type 'dc'
Writing to parquet for cell type 'dc'
Reading data for cell type 'monoc'
Writing to parquet for cell type 'monoc'
Reading data for cell type 'mononc'
Writing to parquet for cell type 'mononc'
Reading data for cell type 'nk'
Writing to parquet for cell type 'nk'
Reading data for cell type 'nkr'
Writing to parquet for cell type 'nkr'
Reading data for cell 

In [27]:
# Variant column used for the spot-checks below (':' already replaced by '_').
column = '22_17302763_C'

# Spark view of the converted chr22 genotypes, kept under its own name so that the
# pandas `df` created in the next cell does not silently replace it.
genotypes_chr22_sdf = spark.read.parquet("/Users/daniel/Desktop/genotypes/genotypes_chr22.parquet")

In [31]:
# Spot-check the converted chr22 genotypes: load only the sample IDs and one variant column.
df = pd.read_parquet(
    "/Users/daniel/Desktop/genotypes/genotypes_chr22.parquet",
    columns=["sampleid", column],
)

df.head()

Unnamed: 0,sampleid,22_17302763_C
0,1_1,0
1,2_2,0
2,3_3,0
3,4_4,1
4,6_6,1


# Association effect pre-compute
Pre-compute the expression histogram and summary statistics for each genotype of every unique eQTL with an FDR below 0.05.

In [8]:
# Histogram of expression values for one genotype group.
_genotype_histogram_type = StructType([
    StructField(name="counts", dataType=ArrayType(IntegerType(), containsNull=False), nullable=False),
    StructField(name="bin_edges", dataType=ArrayType(FloatType(), containsNull=False), nullable=False),
])

# Summary statistics of expression values for one genotype group; all non-null floats.
_genotype_statistics_type = StructType([
    StructField(name=stat_name, dataType=FloatType(), nullable=False)
    for stat_name in ("min", "max", "mean", "median", "q1", "q3", "iqr", "iqr_min", "iqr_max")
])

# Genotype label (allele-pair string) -> histogram + statistics; map values may be null.
_genotype_summary_type = MapType(
    keyType=StringType(),
    valueType=StructType([
        StructField(name="histogram", dataType=_genotype_histogram_type, nullable=False),
        StructField(name="statistics", dataType=_genotype_statistics_type, nullable=False),
    ]),
    valueContainsNull=True,
)

schema = StructType([
    StructField(name="association_id", dataType=StringType(), nullable=False),
    StructField(name="cell_type_id", dataType=StringType(), nullable=False),
    StructField(name="genotypes", dataType=_genotype_summary_type, nullable=False),
])


def statistics(x):
    """Summary statistics of an array of values, or ``None`` when ``x`` is empty.

    ``x`` must expose ``.size`` (e.g. a numpy array). numpy reductions here run over all
    elements, so a 2-D array such as ``DataFrame.values`` is treated as flattened.

    Returns:
        Dict of Python floats: min, max, median, mean, q1, q3, iqr and the Tukey fences
        iqr_min = q1 - 1.5*iqr and iqr_max = q3 + 1.5*iqr.
    """
    if not x.size:
        return None

    # Compute each quartile once rather than once per derived statistic.
    q1 = np.quantile(x, q=0.25)
    q3 = np.quantile(x, q=0.75)
    iqr = q3 - q1

    return {
        "min": float(np.min(x)),
        "max": float(np.max(x)),
        "median": float(np.median(x)),
        "mean": float(np.mean(x)),
        "q1": float(q1),
        "q3": float(q3),
        "iqr": float(iqr),
        "iqr_min": float(q1 - 1.5 * iqr),
        "iqr_max": float(q3 + 1.5 * iqr),
    }


def histogram(x, bins=30, range=(-1,1)):
    """Fixed-range histogram of ``x``, or ``None`` when ``x`` is empty.

    Values outside ``range`` are ignored (numpy.histogram semantics); multi-dimensional
    input is flattened by numpy.

    Returns:
        ``{"counts": [...ints], "bin_edges": [...floats]}``. This matches the Spark schema's
        histogram struct and ``compute_histogram``'s output; the previous raw numpy
        ``(counts, edges)`` tuple of ndarrays is not accepted by PySpark's ArrayType
        schema verification.
    """
    if not x.size:
        return None
    counts, bin_edges = np.histogram(x, bins=bins, range=range)
    return {
        "counts": [int(c) for c in counts],
        "bin_edges": [float(e) for e in bin_edges],
    }


def compute_summary(row, genotypes_df, expression_cache):
    """Summarise one association's gene expression, split by sample genotype at its variant.

    Args:
        row: Association record (attribute and label access) with chrom, bp, a1, a2,
            gene_symbol, cell_type_id, association_id.
        genotypes_df: Indexed by variant id ``"<chrom>:<bp>_<a2>"``; each row maps sample
            id -> value compared against 0/1/2 (presumably the a2 allele count — confirm).
        expression_cache: Mapping cell_type_id -> DataFrame indexed by gene with samples as
            columns.

    Returns:
        ``row[["association_id", "cell_type_id", "chrom"]]`` plus a ``genotypes`` entry
        mapping each label (a1a1, a1a2, a2a2) to ``{"histogram", "statistics"}``, or to
        ``None`` when no sample has that genotype.
    """
    genotype_id = f"{row.chrom}:{row.bp}_{row.a2}"
    gene_id = row.gene_symbol
    expression_df = expression_cache[row.cell_type_id]

    # Loop-invariant lookups, done once instead of once (or twice) per genotype.
    variant_values = genotypes_df.loc[genotype_id]
    gene_expression = expression_df.loc[[gene_id]]

    genotypes = {}
    labels = [f"{row.a1}{row.a1}", f"{row.a1}{row.a2}", f"{row.a2}{row.a2}"]
    for alt_count, genotype in enumerate(labels):
        samples = variant_values[variant_values == alt_count].index.tolist()
        expression_values = gene_expression[expression_df.columns.intersection(samples)].values
        if not expression_values.size:
            # The schema allows null map values but declares histogram/statistics
            # non-nullable, so an empty group is stored as a null value, not null fields.
            genotypes[genotype] = None
            continue
        genotypes[genotype] = {
            "histogram": histogram(expression_values),
            "statistics": statistics(expression_values),
        }

    row = row[["association_id", "cell_type_id", "chrom"]]
    row["genotypes"] = genotypes
    return row