In [None]:
import numpy as np
import hail as hl
from hail.methods.pca import _make_tsm_from_call, _pca_and_moments

# Initialise Hail, pointing its temporary directory at a GCS location.
hl.init(tmp_dir='gs://ukb-data/tmp/ukb-grm')

# On dataproc cluster

## Randomly sample 10k/30k samples from UKB GT MatrixTable:

In [None]:
# GCS directory holding the full genotype MatrixTable and the downsampled outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples drawn for each downsampled set.
N = 10000
# Forwarded to checkpoint(overwrite=...) in the sampling cells.
overwrite = False
# Forwarded to checkpoint(_read_if_exists=...) in the sampling cells.
read = True

Create first set of samples and write to Hail Table:

In [None]:
# Load the full genotype MatrixTable and report its dimensions.
mt = hl.read_matrix_table(f'{gcs_prefix}/gt_147604_406696.mt')
print(mt.count())

# Draw N sample IDs uniformly at random, without replacement.
# A fixed seed makes the draw reproducible should the checkpoint below ever
# need to be regenerated (the original draw was unseeded).
rng = np.random.default_rng(1)
samples_list = mt.s.collect()
# Cast back to plain `str` so the Hail literal is built from Python strings
# rather than NumPy string scalars.
samples_to_keep_list = [str(s) for s in rng.permutation(samples_list)[:N]]

# Persist the chosen sample IDs as a single-partition Hail Table keyed by `s`.
sample_set_01_ht = hl.Table.parallelize(
    hl.literal([{'s': s} for s in samples_to_keep_list], 'array<struct{s: str}>'),
    n_partitions=1
)
sample_set_01_ht = sample_set_01_ht.key_by('s')
# NOTE(review): with `_read_if_exists=read` an existing table at this path is
# presumably reused instead of the fresh draw above -- confirm against the
# Hail version in use.
sample_set_01_ht = sample_set_01_ht.checkpoint(f'{gcs_prefix}/downsampled-{N}/set-01/samples.ht',
                                               overwrite=overwrite, _read_if_exists=read)

# Keep only the columns (samples) in set 01, shrink to 8 partitions and persist.
downsampled_mt = mt.semi_join_cols(sample_set_01_ht)
downsampled_mt = downsampled_mt.repartition(8)
downsampled_mt = downsampled_mt.checkpoint(f'{gcs_prefix}/downsampled-{N}/set-01/gt.mt',
                                           overwrite=overwrite, _read_if_exists=read)
print(downsampled_mt.count())

Create second, disjoint set of samples and write to Hail Table:

In [None]:
# Re-read sample set 01 so that this cell does not depend on kernel state.
sample_set_01_ht = hl.read_table(f'{gcs_prefix}/downsampled-{N}/set-01/samples.ht')

# Remove every sample already in set 01 so the second draw is disjoint from it.
mt = hl.read_matrix_table(f'{gcs_prefix}/gt_147604_406696.mt')
mt = mt.anti_join_cols(sample_set_01_ht)
print(mt.count())

# Draw N of the remaining sample IDs uniformly at random, without replacement.
# A fixed seed makes the draw reproducible should the checkpoint below ever
# need to be regenerated (the original draw was unseeded).
rng = np.random.default_rng(2)
samples_list = mt.s.collect()
# Cast back to plain `str` so the Hail literal is built from Python strings
# rather than NumPy string scalars.
samples_to_keep_list = [str(s) for s in rng.permutation(samples_list)[:N]]

# Persist the chosen sample IDs as a single-partition Hail Table keyed by `s`.
sample_set_02_ht = hl.Table.parallelize(
    hl.literal([{'s': s} for s in samples_to_keep_list], 'array<struct{s: str}>'),
    n_partitions=1
)
sample_set_02_ht = sample_set_02_ht.key_by('s')
# NOTE(review): with `_read_if_exists=read` an existing table at this path is
# presumably reused instead of the fresh draw above -- confirm against the
# Hail version in use.
sample_set_02_ht = sample_set_02_ht.checkpoint(f'{gcs_prefix}/downsampled-{N}/set-02/samples.ht',
                                               overwrite=overwrite, _read_if_exists=read)

# Keep only the columns (samples) in set 02, shrink to 8 partitions and persist.
downsampled_mt = mt.semi_join_cols(sample_set_02_ht)
downsampled_mt = downsampled_mt.repartition(8)
downsampled_mt = downsampled_mt.checkpoint(f'{gcs_prefix}/downsampled-{N}/set-02/gt.mt',
                                           overwrite=overwrite, _read_if_exists=read)
print(downsampled_mt.count())

Check that the two sample sets are disjoint:

In [None]:
# Load both persisted sample tables and compare their sample IDs.
sample_set_01_ht = hl.read_table(f'{gcs_prefix}/downsampled-{N}/set-01/samples.ht')
sample_set_02_ht = hl.read_table(f'{gcs_prefix}/downsampled-{N}/set-02/samples.ht')

# Compare on the key field `s` directly rather than on whole row structs.
set01 = set(sample_set_01_ht.s.collect())
print(f'Sample set 01, count: {len(set01)}.')

set02 = set(sample_set_02_ht.s.collect())
print(f'Sample set 02, count: {len(set02)}.')

shared_samples = set01.intersection(set02)
print(f'Intersection of set 01 and set 02, count: {len(shared_samples)}.')

# Fail loudly instead of relying on someone reading the printed count.
assert not shared_samples, f'sample sets 01 and 02 overlap in {len(shared_samples)} samples.'

## Create GRM for 10k samples and compute spectrum:

In [None]:
# GCS directory holding the downsampled MatrixTables and GRM outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 10000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '02'
# Forwarded to BlockMatrix.write(overwrite=...).
overwrite = True
# Label used in log messages and output file names.
parity = 'full'

In [None]:
# Keyword arguments passed to _make_tsm_from_call, keyed by whitening window size.
# Window size 0 passes no whitening arguments at all.
tsm_options_by_ws = {
    0: dict(block_size=1000, partition_size=1000, hwe_normalize=True),
    30: dict(block_size=30, partition_size=900, hwe_normalize=True,
             whiten_window_size=30, whiten_block_size=64),
    100: dict(block_size=100, partition_size=1000, hwe_normalize=True,
              whiten_window_size=100, whiten_block_size=64),
}

downsampled_mt = hl.read_matrix_table(f'{gcs_prefix}/downsampled-{N}/set-{set_n}/gt.mt')
m_variants, _ = downsampled_mt.count()

for whiten_ws, tsm_options in tsm_options_by_ws.items():
    print(f'w = {whiten_ws}, {parity}.')
    tsm = _make_tsm_from_call(call_expr=downsampled_mt.GT, **tsm_options)
    t = tsm.block_table
    block = tsm.block_expr

    # Both outputs of this iteration share the same path stem.
    grm_stem = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}'

    # GRM = sum over blocks of (block^T @ block), scaled by the variant count.
    print(f'w = {whiten_ws}, {parity}: Computing GRM...')
    gram_sum = t.aggregate(hl.agg.ndarray_sum(block.T @ block))
    grm = gram_sum / m_variants
    hl.linalg.BlockMatrix.from_numpy(grm).write(f'{grm_stem}.bm', overwrite=overwrite)

    # Read the GRM back and store all eigenvalues, largest first.
    print(f'w = {whiten_ws}, {parity}: Computing full spectrum...')
    grm_bm = hl.linalg.BlockMatrix.read(f'{grm_stem}.bm')
    ascending_eigvals = np.linalg.eigvalsh(grm_bm.to_numpy())
    eigvals_bm = hl.linalg.BlockMatrix.from_numpy(ascending_eigvals[::-1])
    eigvals_bm.write(f'{grm_stem}-eigenvalues.bm', overwrite=overwrite)

## Create GRM for 10k samples (odd/even split) and compute spectrum:

In [None]:
# GCS directory holding the downsampled MatrixTables and GRM outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 10000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '02'
# Forwarded to BlockMatrix.write(overwrite=...).
overwrite = True

In [None]:
# Keyword arguments passed to _make_tsm_from_call, keyed by whitening window size.
# Window size 0 passes no whitening arguments at all.
tsm_options_by_ws = {
    0: dict(block_size=1000, partition_size=1000, hwe_normalize=True),
    30: dict(block_size=30, partition_size=900, hwe_normalize=True,
             whiten_window_size=30, whiten_block_size=64),
    100: dict(block_size=100, partition_size=1000, hwe_normalize=True,
              whiten_window_size=100, whiten_block_size=64),
}

for parity in ['odd', 'even']:
    # Restrict rows to chromosomes whose number (contig with 'chr' stripped) has the requested parity.
    genotypes_mt = hl.read_matrix_table(f'{gcs_prefix}/downsampled-{N}/set-{set_n}/gt.mt')
    contig_number = hl.int(genotypes_mt.locus.contig.replace('chr', ''))
    keep_row = (contig_number % 2 != 0) if parity == 'odd' else (contig_number % 2 == 0)
    downsampled_mt = genotypes_mt.filter_rows(keep_row)

    m_variants, _ = downsampled_mt.count()
    for whiten_ws, tsm_options in tsm_options_by_ws.items():
        print(f'w = {whiten_ws}, {parity}.')
        tsm = _make_tsm_from_call(call_expr=downsampled_mt.GT, **tsm_options)
        t = tsm.block_table
        block = tsm.block_expr

        # Both outputs of this iteration share the same path stem.
        grm_stem = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}'

        # GRM = sum over blocks of (block^T @ block), scaled by the variant count.
        print(f'w = {whiten_ws}, {parity}: Computing GRM...')
        gram_sum = t.aggregate(hl.agg.ndarray_sum(block.T @ block))
        grm = gram_sum / m_variants
        hl.linalg.BlockMatrix.from_numpy(grm).write(f'{grm_stem}.bm', overwrite=overwrite)

        # Read the GRM back and store all eigenvalues, largest first.
        print(f'w = {whiten_ws}, {parity}: Computing full spectrum...')
        grm_bm = hl.linalg.BlockMatrix.read(f'{grm_stem}.bm')
        ascending_eigvals = np.linalg.eigvalsh(grm_bm.to_numpy())
        eigvals_bm = hl.linalg.BlockMatrix.from_numpy(ascending_eigvals[::-1])
        eigvals_bm.write(f'{grm_stem}-eigenvalues.bm', overwrite=overwrite)

## Create GRM for 30k samples with manual blocking:

### Compute GRM blocks (block size = 10k) on the 30k samples:

In [None]:
# GCS directory holding the downsampled MatrixTables and GRM outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 30000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '01'
# Forwarded to BlockMatrix.checkpoint(overwrite=...).
overwrite = True
# 'full' uses all rows; 'odd' / 'even' trigger the chromosome-parity filter below.
parity = 'full'

In [None]:
# Keyword arguments passed to _make_tsm_from_call, keyed by whitening window size.
# Window size 0 passes no whitening arguments at all.
tsm_options_by_ws = {
    0: dict(block_size=1000, partition_size=1000, hwe_normalize=True),
    30: dict(block_size=30, partition_size=900, hwe_normalize=True,
             whiten_window_size=30, whiten_block_size=64),
    100: dict(block_size=100, partition_size=1000, hwe_normalize=True,
              whiten_window_size=100, whiten_block_size=64),
}

# Fixed: the input path previously had the `set-{set_n}` and `downsampled-{N}`
# components swapped relative to where the sampling cells write gt.mt and where
# every other cell (including this cell's own outputs) reads and writes.
data_dir = f'{gcs_prefix}/downsampled-{N}/set-{set_n}'
downsampled_mt = hl.read_matrix_table(f'{data_dir}/gt.mt')

# Filter to odd/even chromosomes, if required
if parity == 'odd':
    downsampled_mt = downsampled_mt.filter_rows(hl.int(downsampled_mt.locus.contig.replace('chr', '')) % 2 != 0)
if parity == 'even':
    downsampled_mt = downsampled_mt.filter_rows(hl.int(downsampled_mt.locus.contig.replace('chr', '')) % 2 == 0)

for whiten_ws, tsm_options in tsm_options_by_ws.items():
    tsm = _make_tsm_from_call(call_expr=downsampled_mt.GT, **tsm_options)
    t = tsm.block_table

    # Only the lower triangle (i >= j) of the 3 x 3 grid of column-slice products
    # is computed; each block covers a 10000-column slice of the block expression.
    # The unnormalised sums are written here, one BlockMatrix per (i, j) pair.
    for i in range(3):
        for j in range(i + 1):
            print(f'i = {i}, j = {j}: block_i slice [:, {i*10000}:{(i+1)*10000}], block_j slice [:, {j*10000}:{(j+1)*10000}].')
            block_i = tsm.block_expr[:, i*10000:(i+1)*10000]
            block_j = tsm.block_expr[:, j*10000:(j+1)*10000]
            grm_ij = t.aggregate(hl.agg.ndarray_sum(block_i.T @ block_j))
            grm_ij_bm = hl.linalg.BlockMatrix.from_numpy(grm_ij)
            grm_ij_bm.checkpoint(
                f'{data_dir}/{parity}-grm-ws{whiten_ws}-block_{i}{j}.bm',
                overwrite=overwrite
            )

### Concatenate GRM blocks into the full GRM, write out:

In [None]:
# GCS directory holding the downsampled MatrixTables and GRM outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 30000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '01'
# Forwarded to BlockMatrix.write(overwrite=...).
overwrite = True
# 'full' uses all rows; 'odd' / 'even' trigger the chromosome-parity filter below.
parity = 'full'

In [None]:
# The variant count (after any parity filter) is the GRM normalisation constant.
downsampled_mt = hl.read_matrix_table(f'{gcs_prefix}/downsampled-{N}/set-{set_n}/gt.mt')
if parity in ('odd', 'even'):
    contig_number = hl.int(downsampled_mt.locus.contig.replace('chr', ''))
    keep_row = (contig_number % 2 != 0) if parity == 'odd' else (contig_number % 2 == 0)
    downsampled_mt = downsampled_mt.filter_rows(keep_row)
m_variants, _ = downsampled_mt.count()

for whiten_ws in [0, 30, 100]:
    block_stem = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}'

    # Load the six stored lower-triangular blocks (i >= j) as numpy arrays.
    lower_blocks = {
        (i, j): hl.linalg.BlockMatrix.read(f'{block_stem}-block_{i}{j}.bm').to_numpy()
        for i in range(3)
        for j in range(i + 1)
    }

    # Fill the upper triangle by transposing the mirrored lower block, then join
    # each row of three blocks side by side.
    block_rows = [
        [lower_blocks[(i, j)] if i >= j else lower_blocks[(j, i)].T for j in range(3)]
        for i in range(3)
    ]
    grm_0j_np, grm_1j_np, grm_2j_np = (np.hstack(row) for row in block_rows)
    print(f'grm_0j_np shape = {grm_0j_np.shape}.')
    print(f'grm_1j_np shape = {grm_1j_np.shape}.')
    print(f'grm_2j_np shape = {grm_2j_np.shape}.')

    # Stack the three block rows into the full matrix and normalise by the variant count.
    grm_np = np.vstack((grm_0j_np, grm_1j_np, grm_2j_np)) / m_variants
    print(f'grm_np shape = {grm_np.shape}.')

    # Persist the assembled GRM as a BlockMatrix on GCS.
    hl.linalg.BlockMatrix.from_numpy(grm_np).write(f'{block_stem}.bm', overwrite=overwrite)

### Read in the GRM BlockMatrix, compute full spectrum (all eigenvalues):

In [None]:
# GCS directory holding the GRM BlockMatrices.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 30000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '01'
# Forwarded to BlockMatrix.write(overwrite=...).
overwrite = True
# Label used in log messages and in the GRM / eigenvalue file names.
parity = 'full'

In [None]:
# For each whitening window size, load the stored GRM, compute all of its
# eigenvalues and write them back out largest-first.
for whiten_ws in [0, 30, 100]:
    print(f'{parity}, w = {whiten_ws}.')
    grm_stem = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}'
    grm_np = hl.linalg.BlockMatrix.read(f'{grm_stem}.bm').to_numpy()
    # eigvalsh returns eigenvalues in ascending order; reverse to descending.
    descending_eigvals = np.linalg.eigvalsh(grm_np)[::-1]
    hl.linalg.BlockMatrix.from_numpy(descending_eigvals).write(
        f'{grm_stem}-eigenvalues.bm', overwrite=overwrite
    )

## Run the PCA / spectral-moments (SM) estimator on the 10k/30k samples:

In [None]:
# GCS directory holding the downsampled MatrixTables and PCA outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 10000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '02'
# Label used in the scores / loadings output file names.
parity = 'full'
# Forwarded to Table.write(overwrite=...).
overwrite = False

# Passed as `k` to _pca_and_moments and embedded in the output file names.
k = 100
# Partition counts passed to naive_coalesce before writing each table.
n_parts_scores = 4
n_parts_loadings = 4

In [None]:
# Keyword arguments passed to _make_tsm_from_call, keyed by whitening window size.
# Window size 0 passes no whitening arguments at all.
tsm_options_by_ws = {
    0: dict(block_size=1000, partition_size=1000, hwe_normalize=True),
    30: dict(block_size=30, partition_size=900, hwe_normalize=True,
             whiten_window_size=30, whiten_block_size=64),
    100: dict(block_size=100, partition_size=1000, hwe_normalize=True,
              whiten_window_size=100, whiten_block_size=64),
}

downsampled_mt = hl.read_matrix_table(f'{gcs_prefix}/downsampled-{N}/set-{set_n}/gt.mt')
m_variants, n_samples = downsampled_mt.count()

for whiten_ws, tsm_options in tsm_options_by_ws.items():
    print(f'w = {whiten_ws}.')
    tsm = _make_tsm_from_call(call_expr=downsampled_mt.GT, **tsm_options)

    # Run PCA/SM on the TSM.
    eigvals, scores, loadings, moments, stderrs = _pca_and_moments(
        tsm,
        k=k,
        num_moments=10,
        compute_loadings=True,
        q_iterations=10,
        oversampling_param=10,
        moment_samples=100,
    )

    # Globals attached to both output tables. The 0th spectral moment is set to
    # n_samples and the 0th standard error is set to missing.
    shared_globals = dict(
        eigenvalues=list(eigvals),
        spectral_moments=[n_samples] + list(moments),
        standard_errors=hl.literal([None] + list(stderrs), 'array<float64>'),
        m_variants=m_variants,
        n_samples=n_samples,
    )

    # Scores are written first, then loadings; each table records its own path in `name`.
    outputs = (
        ('scores', scores, n_parts_scores),
        ('loadings', loadings, n_parts_loadings),
    )
    for table_kind, result_ht, n_parts in outputs:
        table_path = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-{table_kind}-ws{whiten_ws}-k{k}.ht'
        annotated_ht = result_ht.annotate_globals(name=table_path, **shared_globals)
        annotated_ht.naive_coalesce(n_parts).write(table_path, overwrite=overwrite)

In [None]:
# GCS directory holding the downsampled MatrixTables and PCA outputs.
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
# Number of samples in the downsampled set (selects the `downsampled-{N}` directory).
N = 10000
# Which sample set to use (selects the `set-{set_n}` directory).
set_n = '02'
# Forwarded to Table.write(overwrite=...).
overwrite = False

# Passed as `k` to _pca_and_moments and embedded in the output file names.
k = 100
# Partition counts passed to naive_coalesce before writing each table.
n_parts_scores = 4
n_parts_loadings = 4

In [None]:
# Keyword arguments passed to _make_tsm_from_call, keyed by whitening window size.
# Window size 0 passes no whitening arguments at all.
tsm_options_by_ws = {
    0: dict(block_size=1000, partition_size=1000, hwe_normalize=True),
    30: dict(block_size=30, partition_size=900, hwe_normalize=True,
             whiten_window_size=30, whiten_block_size=64),
    100: dict(block_size=100, partition_size=1000, hwe_normalize=True,
              whiten_window_size=100, whiten_block_size=64),
}

for parity in ['odd', 'even']:
    # Restrict rows to chromosomes whose number (contig with 'chr' stripped) has the requested parity.
    genotypes_mt = hl.read_matrix_table(f'{gcs_prefix}/downsampled-{N}/set-{set_n}/gt.mt')
    contig_number = hl.int(genotypes_mt.locus.contig.replace('chr', ''))
    keep_row = (contig_number % 2 != 0) if parity == 'odd' else (contig_number % 2 == 0)
    downsampled_mt = genotypes_mt.filter_rows(keep_row)

    m_variants, n_samples = downsampled_mt.count()
    for whiten_ws, tsm_options in tsm_options_by_ws.items():
        print(f'w = {whiten_ws}.')
        tsm = _make_tsm_from_call(call_expr=downsampled_mt.GT, **tsm_options)

        # Run PCA/SM on the TSM.
        eigvals, scores, loadings, moments, stderrs = _pca_and_moments(
            tsm,
            k=k,
            num_moments=10,
            compute_loadings=True,
            q_iterations=10,
            oversampling_param=10,
            moment_samples=100,
        )

        # Globals attached to both output tables. The 0th spectral moment is set
        # to n_samples and the 0th standard error is set to missing.
        shared_globals = dict(
            eigenvalues=list(eigvals),
            spectral_moments=[n_samples] + list(moments),
            standard_errors=hl.literal([None] + list(stderrs), 'array<float64>'),
            m_variants=m_variants,
            n_samples=n_samples,
        )

        # Scores are written first, then loadings; each table records its own path in `name`.
        outputs = (
            ('scores', scores, n_parts_scores),
            ('loadings', loadings, n_parts_loadings),
        )
        for table_kind, result_ht, n_parts in outputs:
            table_path = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-{table_kind}-ws{whiten_ws}-k{k}.ht'
            annotated_ht = result_ht.annotate_globals(name=table_path, **shared_globals)
            annotated_ht.naive_coalesce(n_parts).write(table_path, overwrite=overwrite)


# On local machine

## Read eigenvalues and write out to CSV:

In [None]:
import numpy as np
import pathlib
import hail as hl

# Initialise Hail with default settings for the local-machine section.
hl.init()

In [None]:
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
N = 10000
set_n = '02'

# Local directory that receives one CSV per (parity, window size) combination.
output_dir = '/Users/pcumming/pca/UKB/hdpca'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

# Visit every combination, parity-major, exactly as a pair of nested loops would.
combinations = [(p, ws) for p in ['full', 'odd', 'even'] for ws in [0, 30, 100]]
for parity, whiten_ws in combinations:
    print(f'{parity}, w = {whiten_ws}.')
    bm_path = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}-eigenvalues.bm'
    csv_path = f'{output_dir}/{parity}-{N}-set-{set_n}-grm-ws{whiten_ws}-eigenvalues.csv'

    # Take the first row of the stored matrix and write it out as CSV.
    first_row = hl.linalg.BlockMatrix.read(bm_path).to_numpy()[0]
    eigenvalues_np = first_row.T
    np.savetxt(csv_path, eigenvalues_np, delimiter=',')

In [None]:
gcs_prefix = 'gs://ukb-data/genotypes/406696-samples'
N = 30000
set_n = '01'
parity = 'full'

# Local directory that receives one CSV per window size.
output_dir = '/Users/pcumming/pca/UKB/hdpca'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

for whiten_ws in [0, 30, 100]:
    print(f'{parity}, w = {whiten_ws}.')
    bm_path = f'{gcs_prefix}/downsampled-{N}/set-{set_n}/{parity}-grm-ws{whiten_ws}-eigenvalues.bm'
    csv_path = f'{output_dir}/{parity}-{N}-set-{set_n}-grm-ws{whiten_ws}-eigenvalues.csv'

    # Take the first row of the stored matrix and write it out as CSV.
    first_row = hl.linalg.BlockMatrix.read(bm_path).to_numpy()[0]
    eigenvalues_np = first_row.T
    np.savetxt(csv_path, eigenvalues_np, delimiter=',')

## Functions to load hdpca results from GCS:

In [None]:
from google.cloud import storage
import hail as hl
import json
import numpy as np
import pathlib

# Initialise Hail with a 12g Spark driver memory setting.
# NOTE(review): an earlier cell in this local-machine section also calls hl.init();
# confirm the two cells are meant for separate kernel sessions, since a second
# init in the same session may be rejected.
hl.init(spark_conf={'spark.driver.memory': '12g'})

In [None]:
# Valid argument values for the loaders in this cell:
#     n_samples   = 10000, 30000
#     window_size = 0, 30, 100
#     method      = 'dgsp', 'lgsp', 'osp'   (load_hdpc_est only)
#     k           = 100                     (PCA/SM and spectrum loaders)
#     sample_set  = interpolated into paths as `set-0{sample_set}`;
#                   sample_set=2 is only accepted together with n_samples=10000
#     parity      = 'odd' / 'even' are only accepted with n_samples=10000 and sample_set=2
# The default method used in hdpc_est is `dgsp`


def _check_n_samples_and_window_size(n_samples, window_size):
    """Assert that `n_samples` and `window_size` are among the supported values."""
    valid_n_samples = [10000, 30000]
    valid_window_sizes = [0, 30, 100]
    assert n_samples in valid_n_samples, f'valid n_samples values: {valid_n_samples}.'
    assert window_size in valid_window_sizes, f'valid window_size values: {valid_window_sizes}.'


def _check_k(k):
    """Assert that `k` is among the supported values."""
    valid_ks = [100]
    assert k in valid_ks, f'valid k values: {valid_ks}.'


def _check_sample_set_and_parity(n_samples, sample_set, parity):
    """Assert that sample set 2 and the odd/even split are only requested for the 10k, set-2 data."""
    if sample_set == 2 or parity in ['odd', 'even']:
        assert n_samples == 10000 and sample_set == 2, 'must set n_samples=10000 and sample_set=2 when parity=\'odd\' or parity=\'even\'.'


def _downsampled_prefix(n_samples, sample_set):
    """Return the GCS directory for the given downsampled sample set."""
    return f'gs://ukb-data/genotypes/406696-samples/downsampled-{n_samples}/set-0{sample_set}'


def load_hdpc_est(n_samples, sample_set, window_size, method, parity='full'):
    """Download and parse one hdpc_est JSON result from the `ukb-data` bucket.

    Args:
        n_samples: Size of the downsampled set (10000 or 30000).
        sample_set: Sample set number, used in the blob name as `set-0{sample_set}`.
        window_size: Whitening window size (0, 30 or 100).
        method: One of 'dgsp', 'lgsp', 'osp'.
        parity: Label used in the blob name; defaults to 'full'.

    Returns:
        The parsed JSON object with an extra 'method' entry set to `method`.

    Raises:
        AssertionError: If an argument is outside the supported values.
    """
    valid_methods = ['dgsp', 'lgsp', 'osp']
    _check_n_samples_and_window_size(n_samples, window_size)
    assert method in valid_methods, f'valid method values: {valid_methods}.'
    _check_sample_set_and_parity(n_samples, sample_set, parity)

    gcs_prefix = 'genotypes/406696-samples/hdpca'
    storage_client = storage.Client()
    bucket = storage_client.get_bucket('ukb-data')
    blob = bucket.blob(f'{gcs_prefix}/hdpc_est-{parity}-{n_samples}-set-0{sample_set}-ws{window_size}-{method}.json')
    hdpc_est_results = json.loads(blob.download_as_string())
    hdpc_est_results['method'] = method
    return hdpc_est_results


def load_downsampled_ukb_scores(n_samples, sample_set, window_size, k, parity='full'):
    """Load the PCA scores table and return its `scores` field as a transposed numpy array.

    Args:
        n_samples: Size of the downsampled set (10000 or 30000).
        sample_set: Sample set number, used in the path as `set-0{sample_set}`.
        window_size: Whitening window size (0, 30 or 100).
        k: Number of components (only 100 is supported).
        parity: Label used in the table name; defaults to 'full'.

    Returns:
        `np.array(ht.scores.collect()).T` for the selected table.

    Raises:
        AssertionError: If an argument is outside the supported values.
    """
    _check_n_samples_and_window_size(n_samples, window_size)
    _check_k(k)
    _check_sample_set_and_parity(n_samples, sample_set, parity)

    gcs_prefix = _downsampled_prefix(n_samples, sample_set)
    ht = hl.read_table(f'{gcs_prefix}/{parity}-scores-ws{window_size}-k{k}.ht')
    scores = np.array(ht.scores.collect()).T
    return scores


def load_downsampled_ukb_loadings(n_samples, sample_set, window_size, k, parity='full'):
    """Load the PCA loadings table and return its `loadings` field as a numpy array.

    Args:
        n_samples: Size of the downsampled set (10000 or 30000).
        sample_set: Sample set number, used in the path as `set-0{sample_set}`.
        window_size: Whitening window size (0, 30 or 100).
        k: Number of components (only 100 is supported).
        parity: Label used in the table name; defaults to 'full'.

    Returns:
        `np.array(ht.loadings.collect())` for the selected table (not transposed).

    Raises:
        AssertionError: If an argument is outside the supported values.
    """
    _check_n_samples_and_window_size(n_samples, window_size)
    _check_k(k)
    _check_sample_set_and_parity(n_samples, sample_set, parity)

    gcs_prefix = _downsampled_prefix(n_samples, sample_set)
    ht = hl.read_table(f'{gcs_prefix}/{parity}-loadings-ws{window_size}-k{k}.ht')
    loadings = np.array(ht.loadings.collect())
    return loadings


def load_downsampled_ukb_globals(n_samples, sample_set, window_size, k, parity='full'):
    """Load the global annotations stored on the PCA scores table.

    Args:
        n_samples: Size of the downsampled set (10000 or 30000).
        sample_set: Sample set number, used in the path as `set-0{sample_set}`.
        window_size: Whitening window size (0, 30 or 100).
        k: Number of components (only 100 is supported).
        parity: Label used in the table name; defaults to 'full'.

    Returns:
        A tuple `(eigvals, spectral_moments, std_errs, m_variants)`; the first
        three are numpy arrays built from the table globals.

    Raises:
        AssertionError: If an argument is outside the supported values.
    """
    _check_n_samples_and_window_size(n_samples, window_size)
    _check_k(k)
    _check_sample_set_and_parity(n_samples, sample_set, parity)

    gcs_prefix = _downsampled_prefix(n_samples, sample_set)
    ht = hl.read_table(f'{gcs_prefix}/{parity}-scores-ws{window_size}-k{k}.ht')
    eigvals = np.array(hl.eval(ht.eigenvalues))
    spectral_moments = np.array(hl.eval(ht.spectral_moments))
    std_errs = np.array(hl.eval(ht.standard_errors))
    m_variants = hl.eval(ht.m_variants)
    return eigvals, spectral_moments, std_errs, m_variants


def load_downsampled_ukb_spectrum(n_samples, sample_set, window_size, k, parity='full'):
    """Load the stored GRM eigenvalues and return the first row of the matrix.

    Args:
        n_samples: Size of the downsampled set (10000 or 30000).
        sample_set: Sample set number, used in the path as `set-0{sample_set}`.
        window_size: Whitening window size (0, 30 or 100).
        k: Validated (only 100 is supported) but not used to locate the file.
        parity: Label used in the file name; defaults to 'full'.

    Returns:
        Row 0 of the eigenvalues BlockMatrix as a numpy array.

    Raises:
        AssertionError: If an argument is outside the supported values.
    """
    _check_n_samples_and_window_size(n_samples, window_size)
    _check_k(k)
    _check_sample_set_and_parity(n_samples, sample_set, parity)

    gcs_prefix = _downsampled_prefix(n_samples, sample_set)
    eigenvalues_bm = hl.linalg.BlockMatrix.read(f'{gcs_prefix}/{parity}-grm-ws{window_size}-eigenvalues.bm')
    eigenvalues_np = eigenvalues_bm.to_numpy()[0]
    return eigenvalues_np

### Quick check to verify functions work as intended:

In [None]:
w = 0  # 0, 30, 100

# Every loader is called once per parity, in this order, on the 10k / set-2 data
# (sample_set=2 is required for the odd/even split).
parities = ('full', 'even', 'odd')
selection = dict(n_samples=10000, sample_set=2, window_size=w)

# hdpc_est results
hdpc_est_full, hdpc_est_even, hdpc_est_odd = [
    load_hdpc_est(method='dgsp', parity=p, **selection) for p in parities
]

# PCA/SM scores
scores_full, scores_even, scores_odd = [
    load_downsampled_ukb_scores(k=100, parity=p, **selection) for p in parities
]

# PCA/SM loadings
loadings_full, loadings_even, loadings_odd = [
    load_downsampled_ukb_loadings(k=100, parity=p, **selection) for p in parities
]

# PCA/SM globals: eigenvalues, spectral moments, standard errors, variant count
(
    (evals_full, sm_full, stderr_full, m_full),
    (evals_even, sm_even, stderr_even, m_even),
    (evals_odd, sm_odd, stderr_odd, m_odd),
) = [load_downsampled_ukb_globals(k=100, parity=p, **selection) for p in parities]

# All GRM eigenvalues
spectrum_full, spectrum_even, spectrum_odd = [
    load_downsampled_ukb_spectrum(k=100, parity=p, **selection) for p in parities
]

In [None]:
# Show the scores / loadings array shapes for each split, with a blank line
# between consecutive splits.
shape_pairs = [
    (scores_full, loadings_full),
    (scores_odd, loadings_odd),
    (scores_even, loadings_even),
]
for position, (scores_arr, loadings_arr) in enumerate(shape_pairs):
    if position > 0:
        print()
    print(scores_arr.shape)
    print(loadings_arr.shape)

In [None]:
# Fixed: load_hdpc_est requires `sample_set`; omitting it raised a TypeError.
# The sample set for each n follows the eigenvalue-export cells earlier in this
# section (set 02 for the 10k data, set 01 for the 30k data).
# NOTE(review): confirm these are the sets the hdpc_est JSON files were produced for.
sample_set_by_n = {10000: 2, 30000: 1}

for n, sample_set in sample_set_by_n.items():
    for ws in [0, 30, 100]:
        for method in ['dgsp', 'lgsp', 'osp']:
            print(load_hdpc_est(n_samples=n, sample_set=sample_set, window_size=ws, method=method))

## Compute cross-correlations:

In [None]:
def compute_crosscorr(nd1, nd2, k):
    """Return the squared singular values of the cross-correlation block of two arrays.

    The rows of `nd1` and `nd2` are stacked and correlated together with
    `np.corrcoef`; the block with row indices `[:k]` and column indices `[k:]`
    is extracted and its singular values are squared.

    Args:
        nd1: First array, passed to `np.corrcoef` as `x`.
        nd2: Second array, passed to `np.corrcoef` as `y`.
        k: Split index separating the rows of `nd1` from the rows of `nd2`
            in the stacked correlation matrix.

    Returns:
        1-D numpy array of squared singular values, in the order returned by
        `np.linalg.svd`.
    """
    stacked_corr = np.corrcoef(nd1, nd2)
    cross_block = stacked_corr[:k, k:]
    singular_values = np.linalg.svd(cross_block, compute_uv=False)
    return np.square(singular_values)

### Cross-correlations between two disjoint 10k sample sets:

In [None]:
n = 10000
k = 100

# Local directory for the saved correlation matrices.
output_path = f'/Users/pcumming/pca/UKB/npy/406696-samples/downsampled-{n}'
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

# For each window size, correlate the loadings of sample set 1 with those of
# sample set 2 and save the cross block.
for w in [0, 30, 100]:
    print(f'w = {w}.')
    loadings_01, loadings_02 = [
        load_downsampled_ukb_loadings(n_samples=n, sample_set=set_number, window_size=w, k=k, parity='full')
        for set_number in (1, 2)
    ]
    # Rows [:k] of the stacked matrix belong to set 01, columns [k:] to set 02.
    stacked_corr = np.corrcoef(loadings_01.T, loadings_02.T)
    R = stacked_corr[:k, k:]
    # The target already ends in '.npy', so np.save writes to exactly this path.
    np.save(f'{output_path}/set01_set02_cross_correlations-loadings-ws{w}-k{k}.npy', R)

In [None]:
# Read back in results from above
for w in [0, 30, 100]:
    print(f'w = {w}:')
    with open(f'{output_path}/set01_set02_cross_correlations-loadings-ws{w}-k{k}.npy', 'rb') as f:
        R = np.load(f)
    print(R.shape)
    print(R)

### Odd/even chromosome cross-correlations for a single 10k sample set:

In [None]:
n = 10000
k = 100

# Local directory for the saved correlation matrices.
output_path = f'/Users/pcumming/pca/UKB/npy/406696-samples/downsampled-{n}/set-02'
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

# Compute odd/even cross-correlations and write out results
for w in [0, 30, 100]:
    print(f'w = {w}.')
    loadings_odd = load_downsampled_ukb_loadings(n_samples=n, sample_set=2, window_size=w, k=k, parity='odd')
    loadings_even = load_downsampled_ukb_loadings(n_samples=n, sample_set=2, window_size=w, k=k, parity='even')

    # np.corrcoef stacks its two inputs row-wise, so both arrays must have the
    # same shape; fail with a clear message otherwise.
    # NOTE(review): the odd and even loadings come from different sets of
    # variants, so their lengths probably differ -- confirm whether the
    # odd/even comparison was meant to use scores instead of loadings.
    if loadings_odd.shape != loadings_even.shape:
        raise ValueError(
            f'odd/even loadings have different shapes: {loadings_odd.shape} vs {loadings_even.shape}.'
        )

    # Fixed: the second argument was `loadings_odd.T`, which correlated the odd
    # loadings with themselves and never used `loadings_even`.
    R = np.corrcoef(loadings_odd.T, loadings_even.T)[:k, k:]
    with open(f'{output_path}/odd_even_cross_correlations-loadings-ws{w}-k{k}.npy', 'wb') as f:
        np.save(f, R)

In [None]:
# Read back in results from above
for w in [0, 30, 100]:
    print(f'w = {w}:')
    with open(f'{output_path}/odd_even_cross_correlations-loadings-ws{w}-k{k}.npy', 'rb') as f:
        R = np.load(f)
    print(R)