# UK Biobank (UKB) - Plot PC scores on map of UK

In [None]:
import cartopy.crs as ccrs
import cartopy.io.shapereader as shpreader
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import seaborn as sns
import pathlib
import hail as hl

# Start Hail with a 12 GB Spark driver heap; the cells below collect per-sample
# coordinate and PC score arrays to the driver.
hl.init(
    spark_conf={
        'spark.driver.memory': '12g',
    },
)

In [None]:
# Lifted from seaborn source code
def _freedman_diaconis_bins(a):
    """Calculate number of hist bins using Freedman-Diaconis rule."""
    # From https://stats.stackexchange.com/questions/798/
    a = np.asarray(a)
    if len(a) < 2:
        return 1
    iqr = np.subtract.reduce(np.nanpercentile(a, [75, 25]))
    h = 2 * iqr / (len(a) ** (1 / 3))
    # fall back to sqrt(a) bins if iqr is 0
    if h == 0:
        return int(np.sqrt(a.size))
    else:
        return int(np.ceil((a.max() - a.min()) / h))

## White British UKB subset

### Scatter plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '337111-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 30
nrow, ncol = 6, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 30, 100, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())
        n_with_coords = len(x_coords)  # same as scores_ht.count(), without another Spark job

        plot_title = (
            f'\nUK Biobank PC Scores by Place of Birth in UK, {subset}\n'
            f'PCs computed on {n_samples} samples and {m_variants} SNPs, whitening window size = {ws}.\n'
            f'Place of birth in UK data available for {n_with_coords} samples.\n'
        )

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Zero-centred colour scale wide enough to cover mean +/- 2 SD; abs()
            # keeps the lower tail in range when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.CenteredNorm(vcenter=0, halfrange=halfrange)

            # Create the scatter plot for this PC
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            points = ax.scatter(x=x_coords, y=y_coords,
                                s=0.5, c=pc_score_k,
                                marker='.', alpha=0.25,
                                linewidths=1, edgecolors='face',
                                norm=norm, cmap='bwr',
                                transform=osgb)
            cb = fig.colorbar(points, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Score', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot.pdf', dpi=300, bbox_inches='tight')
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()

### Hexbin plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '337111-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 30
nrow, ncol = 6, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 30, 100, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())
        n_with_coords = len(x_coords)  # same as scores_ht.count(), without another Spark job

        # Hexagon grid size: mean of the Freedman-Diaconis bin counts of the two
        # coordinates, each capped at 75. It depends only on the coordinates, so
        # it is computed once per window size rather than once per PC.
        x_bins = min(_freedman_diaconis_bins(x_coords), 75)
        y_bins = min(_freedman_diaconis_bins(y_coords), 75)
        gridsize = int(np.mean([x_bins, y_bins]))

        plot_title = (
            f'\nUK Biobank PC Scores by Place of Birth in UK, {subset}\n'
            f'PCs computed on {n_samples} samples and {m_variants} SNPs, whitening window size = {ws}.\n'
            f'Place of birth in UK data available for {n_with_coords} samples.\n'
        )

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Symmetric colour range around 0 wide enough to cover mean +/- 2 SD;
            # abs() keeps it positive so vmin <= vmax even when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.Normalize(vmin=-halfrange, vmax=halfrange)

            # Create the hexbin plot of mean PC score per hexagon for this PC
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            hexes = ax.hexbin(x=x_coords, y=y_coords,
                              C=pc_score_k,
                              reduce_C_function=np.mean,
                              gridsize=gridsize,
                              linewidths=0.01,
                              norm=norm, cmap='bwr',
                              mincnt=None,
                              transform=osgb)
            cb = fig.colorbar(hexes, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Mean\nscore', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot.pdf', dpi=300, bbox_inches='tight')
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()

### Scatter with bottom/top percentile scores plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '337111-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 30
nrow, ncol = 6, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Only samples in the bottom / top tail of each PC are plotted. These are
# quantile fractions (0.05 = 5th percentile), not percentages.
bottom_percentile = 0.05
top_percentile = 1 - bottom_percentile
# Percent labels for titles and file names; ':g' keeps fractional values such
# as 2.5 instead of truncating them the way int() does
bottom_label = f'{bottom_percentile * 100:g}'
top_label = f'{top_percentile * 100:g}'

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 30, 100, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())
        n_with_coords = len(x_coords)  # same as scores_ht.count(), without another Spark job

        # The PCs were computed on the whole sample set (n_samples); only the
        # samples with birth coordinates (n_with_coords) can be placed on the map
        plot_title = (f'\nUK Biobank PC Scores by Place of Birth in UK: {subset}, w = {ws}\n'
                      f'PCs computed on {n_samples} samples and {m_variants} SNPs.'
                      f'\nPlace of birth in UK data available for {n_with_coords} samples.'
                      f'\nOnly samples with PC scores below the {bottom_label}th percentile'
                      f' or above the {top_label}th percentile shown on map.\n')

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Zero-centred colour scale wide enough to cover mean +/- 2 SD of the
            # full PC; abs() keeps the lower tail in range when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.CenteredNorm(vcenter=0, halfrange=halfrange)

            # Keep samples at or beyond the bottom / top quantile. A boolean mask
            # keeps the original sample order and never selects a sample twice.
            lower, upper = np.quantile(pc_score_k, [bottom_percentile, top_percentile])
            keep = (pc_score_k <= lower) | (pc_score_k >= upper)
            filtered_pc_score_k = pc_score_k[keep]
            filtered_x_coords = x_coords[keep]
            filtered_y_coords = y_coords[keep]

            # Create the scatter plot for this PC; fixed axis limits (OSGB
            # coordinates) give every subplot the same map extent
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            ax.set_xlim(37050.0, 684950.0)
            ax.set_ylim(0.0, 1268350.0)
            points = ax.scatter(x=filtered_x_coords, y=filtered_y_coords,
                                s=1, c=filtered_pc_score_k,
                                marker='.', alpha=0.25,
                                linewidths=0.1, edgecolors='face',
                                norm=norm, cmap='bwr',
                                transform=osgb)
            cb = fig.colorbar(points, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Score', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot-{bottom_label}_percent.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot-{bottom_label}_percent.pdf', dpi=300, bbox_inches='tight')

        print()
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()

### Hexbin with bottom/top percentile scores plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '337111-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 30
nrow, ncol = 6, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Only samples in the bottom / top tail of each PC are plotted. These are
# quantile fractions (0.025 = 2.5th percentile), not percentages.
bottom_percentile = 0.025
top_percentile = 1 - bottom_percentile
# Percent labels for titles and file names. ':g' renders 0.025 as "2.5"; int()
# truncated it to "2", mislabelling the plot and the output file.
bottom_label = f'{bottom_percentile * 100:g}'
top_label = f'{top_percentile * 100:g}'

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())

        plot_title = (
            f'\nUK Biobank PC Scores by Place of Birth in UK, {subset}\n'
            f'PCs computed on {n_samples} samples and {m_variants} SNPs, whitening window size = {ws}.\n'
            f'Only samples with PC scores below the {bottom_label}th percentile '
            f'or above the {top_label}th percentile shown on map.\n'
        )

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Symmetric colour range around 0 wide enough to cover mean +/- 2 SD;
            # abs() keeps it positive so vmin <= vmax even when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.Normalize(vmin=-halfrange, vmax=halfrange)

            # Keep samples at or beyond the bottom / top quantile. A boolean mask
            # keeps the original sample order and never selects a sample twice.
            lower, upper = np.quantile(pc_score_k, [bottom_percentile, top_percentile])
            keep = (pc_score_k <= lower) | (pc_score_k >= upper)
            filtered_pc_score_k = pc_score_k[keep]
            filtered_x_coords = x_coords[keep]
            filtered_y_coords = y_coords[keep]

            # Hexagon grid size from the Freedman-Diaconis bin counts of the
            # retained samples' coordinates, each capped at 75
            x_bins = min(_freedman_diaconis_bins(filtered_x_coords), 75)
            y_bins = min(_freedman_diaconis_bins(filtered_y_coords), 75)
            gridsize = int(np.mean([x_bins, y_bins]))

            # Create the hexbin plot of mean PC score per hexagon for this PC
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            hexes = ax.hexbin(x=filtered_x_coords, y=filtered_y_coords,
                              C=filtered_pc_score_k,
                              reduce_C_function=np.mean,
                              gridsize=gridsize,
                              linewidths=0.01,
                              norm=norm, cmap='bwr',
                              mincnt=None,
                              transform=osgb)
            cb = fig.colorbar(hexes, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Mean\nscore', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot-{bottom_label}_percent.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot-{bottom_label}_percent.pdf', dpi=300, bbox_inches='tight')
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()

## Pan-UKB subset

### Scatter plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '406696-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 40
nrow, ncol = 8, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 30, 100, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())
        n_with_coords = len(x_coords)  # same as scores_ht.count(), without another Spark job

        plot_title = (
            f'\nUK Biobank PC Scores by Place of Birth in UK, {subset}\n'
            f'PCs computed on {n_samples} samples and {m_variants} SNPs, whitening window size = {ws}.\n'
            f'Place of birth in UK data available for {n_with_coords} samples.\n'
        )

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Zero-centred colour scale wide enough to cover mean +/- 2 SD; abs()
            # keeps the lower tail in range when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.CenteredNorm(vcenter=0, halfrange=halfrange)

            # Create the scatter plot for this PC
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            points = ax.scatter(x=x_coords, y=y_coords,
                                s=0.5, c=pc_score_k,
                                marker='.', alpha=0.25,
                                linewidths=1, edgecolors='face',
                                norm=norm, cmap='bwr',
                                transform=osgb)
            cb = fig.colorbar(points, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Score', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-scatter_plot.pdf', dpi=300, bbox_inches='tight')
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()

### Hexbin plot:

In [None]:
plt.ioff()

bucket = 'ukb-data'
version = 'genotypes'
samples = '406696-samples'
gcs_prefix = f'gs://{bucket}/{version}/{samples}/pca-sm-whitened-02'

# Birth-coordinate Table path and plot label for each supported sample set.
# Resolving both up front makes an unsupported `samples` value fail here with a
# clear message instead of a NameError on `coord_ht` / `subset` further down.
sample_sets = {
    '337111-samples': ('gs://ukb-data/samples/wb_337111-uk_birth_coordinates.ht', 'White British'),
    '406696-samples': ('gs://ukb-data/samples/pan_ukb_406696-uk_birth_coordinates.ht', 'Pan-UKB'),
}
if samples not in sample_sets:
    raise ValueError(f'Unsupported sample set {samples!r}; expected one of {sorted(sample_sets)}')
coord_path, subset = sample_sets[samples]

# Load the birth location coordinates Hail Table and re-key it by the sample ID
# cast to a string (presumably to match the key type of the PC scores Tables)
coord_ht = hl.read_table(coord_path).key_by()
coord_ht = coord_ht.annotate(s=hl.str(coord_ht.s)).key_by('s')

# Number of PC score plots to make, subplot grid shape, and per-subplot size (inches)
n_pcs = 40
nrow, ncol = 8, 5
fig_width, fig_height = 3, 4.25
if n_pcs > nrow * ncol:
    raise ValueError(f'n_pcs = {n_pcs} does not fit in a {nrow} x {ncol} subplot grid')

# Settings that do not depend on the whitening window size
m_variants = 147604  # number of SNPs used for the PCA (hard-coded)
n_samples = int(samples.split('-')[0])  # PCA sample count, encoded in the sample-set name
output_dir = f'/Users/pcumming/pca/UKB/plots/{samples}'
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
osgb = ccrs.OSGB(approx=False)
sns.set_theme(style='ticks')

for ws in [0, 30, 100, 300]:
    print(f'Window size = {ws}:')
    # Load PC scores Table and join with birth location coordinates Table. It is
    # persisted because it is collected from once per PC below, and unpersisted
    # in `finally` so cached data from earlier window sizes does not pile up.
    scores_ht = hl.read_table(f'{gcs_prefix}/full-scores-ws{ws}-k100.ht')
    scores_ht = scores_ht.join(coord_ht).persist()
    fig = None
    try:
        # Collect east/north coordinates to numpy arrays
        x_coords = np.array(scores_ht['east_coord_130'].collect())
        y_coords = np.array(scores_ht['north_coord_129'].collect())
        n_with_coords = len(x_coords)  # same as scores_ht.count(), without another Spark job

        # Hexagon grid size: mean of the Freedman-Diaconis bin counts of the two
        # coordinates, each capped at 75. It depends only on the coordinates, so
        # it is computed once per window size rather than once per PC.
        x_bins = min(_freedman_diaconis_bins(x_coords), 75)
        y_bins = min(_freedman_diaconis_bins(y_coords), 75)
        gridsize = int(np.mean([x_bins, y_bins]))

        plot_title = (
            f'\nUK Biobank PC Scores by Place of Birth in UK, {subset}\n'
            f'PCs computed on {n_samples} samples and {m_variants} SNPs, whitening window size = {ws}.\n'
            f'Place of birth in UK data available for {n_with_coords} samples.\n'
        )

        # Create figure and grid of map subplots (squeeze=False keeps `axs` 2-D)
        fig, axs = plt.subplots(nrow, ncol,
                                figsize=(fig_width * ncol, fig_height * nrow),
                                subplot_kw={'projection': osgb},
                                sharex=False, sharey=False,
                                squeeze=False,
                                constrained_layout=True)
        fig.suptitle(plot_title, fontsize=14, ha='center')

        # One subplot per PC in row-major order; k_idx runs from 0 to n_pcs - 1
        for k_idx, ax in enumerate(axs.flat):
            if k_idx >= n_pcs:
                ax.set_visible(False)  # grid cell left over when n_pcs < nrow * ncol
                continue
            i, j = divmod(k_idx, ncol)
            print(f'k_idx = {k_idx}, i = {i}, j = {j}.')

            # Collect the kth PC score to a numpy array
            pc_score_k = np.array(scores_ht.scores[k_idx].collect())

            # Symmetric colour range around 0 wide enough to cover mean +/- 2 SD;
            # abs() keeps it positive so vmin <= vmax even when the mean is negative
            halfrange = abs(np.mean(pc_score_k)) + 2 * np.std(pc_score_k)
            norm = mcolors.Normalize(vmin=-halfrange, vmax=halfrange)

            # Create the hexbin plot of mean PC score per hexagon for this PC
            ax.set_title(f'PC{k_idx + 1}, w = {ws}', loc='left')
            ax.coastlines(resolution='10m')
            hexes = ax.hexbin(x=x_coords, y=y_coords,
                              C=pc_score_k,
                              reduce_C_function=np.mean,
                              gridsize=gridsize,
                              linewidths=0.01,
                              norm=norm, cmap='bwr',
                              mincnt=None,
                              transform=osgb)
            cb = fig.colorbar(hexes, ax=ax, location='right', pad=0.01, fraction=0.1, aspect=30)
            cb.ax.set_title('Mean\nscore', loc='left')

        fname_out = f'uk_map-full-scores-ws{ws}-k100'

        print('Writing png file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot.png', dpi=300, bbox_inches='tight')

        print('Writing pdf file...')
        fig.savefig(f'{output_dir}/{fname_out}-hexbin_plot.pdf', dpi=300, bbox_inches='tight')
    finally:
        if fig is not None:
            plt.close(fig)
        scores_ht.unpersist()