# Imports and utility functions

In [30]:
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import time, glob, gc, os, tifffile, cv2
import pandas as pd
import seaborn as sns
import xarray as xr
import tpae.data.ingest as ti

# ---- parameters ----
resolution = 10  # desired size of each downsampled pixel, in microns
k = 10  # handed to ti.pca_metapixels further down; also part of the output directory name
repname = f'pca_k={k}_harmony'

# ---- directory layout (everything lives under datadir) ----
datadir = f'../../UC/UC-data/{resolution}u'
rawpixelsdir = f'{datadir}/../raw/'
downsampledpixelsdir = f'{datadir}/counts'
normedpixelsdir = f'{datadir}/normalized'
masksdir = f'{datadir}/masks'
processeddir = f'{datadir}/{repname}'
for outdir in (datadir, downsampledpixelsdir, normedpixelsdir, masksdir, processeddir):
    os.makedirs(outdir, exist_ok=True)

# ---- marker names (one per line of channelNames.txt) and name-based subsets ----
with open(f'{datadir}/../channelNames.txt', 'r') as file:
    markers = file.read().splitlines()
blanks = [name for name in markers if 'blank' in name]
negctrls = [name for name in markers if 'empty' in name]
stain_tags = ('HOECHST', 'HOCHST', 'DRAQ5')  # substrings that identify stain channels
stains = [name for name in markers if any(tag in name for tag in stain_tags)]

# ---- raw images to process ----
# The character class matches a single digit in 1-3 or 6-8, i.e. runs 4 and 5 are
# skipped per the authors' instructions.
files = glob.glob(f'{rawpixelsdir}/run[1-36-8]/*.tif')

# downsample original images

In [32]:
# Per-sample metadata table. fullpath_to_sampleinfo() below matches rows on its
# 'Run#' and 'Region #' columns; the downsampling loop reads 'NEW Label' and 'Patient.ID'.
metadata = pd.read_csv(f'{datadir}/../2024_10_16_UC_Patient_Metadata.csv')
import re
def fullpath_to_sampleinfo(f):
    """Return the metadata row describing the raw image at path ``f``.

    The run number is parsed from the parent directory name (``run<N>``) and
    the region number from the part of the file name before the first
    underscore (``reg<NNN>``). Both are matched against the ``Run#`` and
    ``Region #`` columns of the module-level ``metadata`` table.

    Parameters
    ----------
    f : str
        Path to a raw image, e.g. ``.../run6/reg009_montage-005.tif``.

    Returns
    -------
    pandas.Series or None
        The first matching metadata row, or ``None`` (after printing a notice)
        when no row matches.
    """
    # Parse from the argument itself; relying on a global loop variable only
    # works when the caller happens to name its variable identically.
    run = int(os.path.basename(os.path.dirname(f))[3:])
    # Work on the file name only so underscores in directory names cannot
    # interfere. int() already accepts leading zeros ("009" -> 9), whereas
    # stripping "00" textually would turn "reg100" into region 1.
    reg = int(os.path.basename(f).split("_")[0][3:])
    cond = (metadata['Run#'] == run) & (metadata['Region #'] == reg)
    if not cond.any():
        print(f'skipping {f} because no metadata were found')
        return None
    return metadata[cond].iloc[0]

In [33]:
# Pixel size of the raw images; presumably in microns like `resolution` — TODO confirm.
orig_pixel_size = 0.75488
# Whole number of raw pixels per downsampled pixel edge (floor of the ratio, so the
# effective output pixel is slightly smaller than `resolution`).
downsample_factor = int(resolution//orig_pixel_size)

In [35]:
# Downsample every raw image that has a metadata row and write it to
# downsampledpixelsdir as <sid>.nc; images without metadata are skipped.
for file in ti.pb(files):
    sampleinfo = fullpath_to_sampleinfo(file)
    if sampleinfo is not None:
        # Read inside a context manager so the TIFF file handle is closed as soon
        # as the pixel data is in memory instead of lingering until garbage collection.
        with tifffile.TiffFile(file) as tif:
            sample = tif.asarray()
        # Collapse every axis before the last two into one and move it to the end.
        # Assumes the TIFF loads as a 4-D array whose last two axes are the image
        # plane — TODO confirm against the raw data layout.
        sample = sample.reshape((-1, sample.shape[2], sample.shape[3])).transpose(1, 2, 0)
        sid = sampleinfo['NEW Label']
        sample = ti.hiresarray_to_downsampledxarray(sample,
                                                    sid,
                                                    downsample_factor, orig_pixel_size, markers)
        sample.attrs['sid'] = sid
        sample.attrs['donor'] = sampleinfo['Patient.ID']
        sample.to_netcdf(f'{downsampledpixelsdir}/{sample.attrs["sid"]}.nc', encoding={sample.name: ti.compression}, engine="netcdf4")

    # Collect after every file (including skipped ones) to keep memory in check.
    gc.collect()

 67%|██████████████████████████████████████████▍                    | 29/43 [06:00<02:09,  9.26s/it]

skipping /n/data1/hms/dbmi/raychaudhuri/lab/lakshay/uc/uc-data/raw/run6/reg009_montage-005.tif because no metadata were found


100%|███████████████████████████████████████████████████████████████| 43/43 [09:04<00:00, 12.67s/it]


# create masks and normalize

In [None]:
def get_foreground(s):
    """Compute the mask for one sample by delegating to ti.foreground_mask_ihc.

    For the sample named '300-0529_Scan1', every value at x >= 12500 is
    replaced by 0 before the mask is computed; all other samples are passed
    through unchanged.
    """
    # NOTE(review): `real_markers` and `neg_ctrls` are not defined in any visible
    # cell (the setup cell defines `negctrls`) — confirm where they come from.
    needs_crop = s.name == '300-0529_Scan1'
    candidate = s.where(s.x < 12500, 0) if needs_crop else s
    return ti.foreground_mask_ihc(candidate, real_markers, neg_ctrls, 0.1, 12, blur_width=5)


# NOTE(review): `sids` is likewise not defined in any visible cell — confirm.
ti.write_masks(downsampledpixelsdir, masksdir, get_foreground, sids)

In [None]:
def transform(X):
    """Divide every column except the last by (1 + last column), row by row.

    X is indexed as a 2-D array; the result has one column fewer than X and
    X itself is not modified.
    """
    numerators = X[:, :-1]
    denominators = 1 + X[:, -1]
    return numerators / denominators[:, None]

def get_sumstats(pixels):
    """Compute the summary statistics later consumed by normalize().

    The pixels are passed through transform(), each row is rescaled to the
    median row total and log1p-transformed, and the per-column mean and
    standard deviation of the result are taken.

    Returns a dict with keys 'means', 'stds' and 'med_ntranscripts'.
    """
    ratios = transform(pixels)
    row_totals = ratios.sum(axis=1, dtype=np.float64)
    med_ntranscripts = np.median(row_totals)
    # The small constant in the denominator guards against all-zero pixels.
    rescaled = med_ntranscripts * ratios / (row_totals[:, None] + 1e-6)
    logged = np.log1p(rescaled)
    return {
        'means': logged.mean(axis=0, dtype=np.float64),
        'stds': logged.std(axis=0, dtype=np.float64),
        'med_ntranscripts': med_ntranscripts,
    }

def normalize(mask, s, med_ntranscripts=None, means=None, stds=None):
    """Normalize one sample with the statistics produced by get_sumstats().

    Values outside `mask` are set to 0, the masked pixels are extracted with
    ti.xr_to_pixellist, passed through transform(), rescaled to
    `med_ntranscripts`, log1p-transformed and standardized with `means` and
    `stds`. The result is written back with ti.set_pixels into a copy of the
    sample restricted to `markers[:-1]`, and the three statistics are stored
    in its attrs.

    The keyword arguments default to None but are required in practice: the
    arithmetic below fails if any of them is left as None.
    """
    masked = s.where(mask, other=0)
    pl = transform(ti.xr_to_pixellist(masked, mask))
    row_totals = pl.sum(axis=1)
    # The small constant in the denominator guards against all-zero pixels.
    pl = np.log1p(med_ntranscripts * pl / (row_totals[:, None] + 1e-6))
    # Standardize in place so the array keeps its dtype.
    pl -= means
    pl /= stds
    # transform() dropped the last column, so drop the last marker to match.
    out = masked.sel(marker=markers[:-1])
    ti.set_pixels(out, mask, pl)
    out.attrs.update(med_ntranscripts=med_ntranscripts, means=means, stds=stds)
    return out

# Read from downsampledpixelsdir/masksdir and write to normedpixelsdir using the two
# callbacks defined above. Presumably the dict returned by get_sumstats is passed to
# normalize as keyword arguments — confirm in ti.normalize_allsamples.
# NOTE(review): `sids` is not defined in any visible cell — confirm where it comes from.
ti.normalize_allsamples(downsampledpixelsdir, masksdir, normedpixelsdir, sids,
                               get_sumstats=get_sumstats,
                               normalize=normalize)

# reduce to k meta-markers (k = 10 above) using PCA

In [None]:
# create metapixels for more accurate PCA
# `metapixels` is used as a mapping in the next cell (its .values() are passed to the PCA).
# NOTE(review): `npixels` is not used by any later visible cell.
metapixels, npixels = ti.metapixels_allsamples(normedpixelsdir, masksdir, sids)

In [None]:
# Run the PCA on the metapixels and save the loadings for the next cell.
loadings, C, allmp = ti.pca_metapixels(metapixels.values(), k)
loadings.to_feather(f'{processeddir}/_pcloadings.feather')

# The metapixels and their pooled form are no longer needed; release them.
del metapixels
del allmp
gc.collect()

In [None]:
# Apply the saved PC loadings to the plain (non-meta) pixels of every sample
# and persist the combined table.
pcloadings = pd.read_feather(f'{processeddir}/_pcloadings.feather')
allpixels_pca = ti.pca_pixels(normedpixelsdir, masksdir, pcloadings, sids)
allpixels_pca.to_feather(f'{processeddir}/_allpixels_pca.feather')

# Run harmony on PCA'd pixels

In [None]:
# run harmony
# Reload the table written by the PCA cell from disk.
allpixels_pca = pd.read_feather(f'{processeddir}/_allpixels_pca.feather')
# Presumably writes `_allpixels_pca_harmony.feather` into processeddir (the next cell
# reads that file) — confirm in ti.harmonize.
ti.harmonize(allpixels_pca, processeddir)

In [None]:
# read in result and write individual samples
harmpixels = pd.read_feather(f'{processeddir}/_allpixels_pca_harmony.feather')
# NOTE(review): 50000 is presumably the number of pixels subsampled for the plot and
# ['sid'] the column(s) to color by — confirm in ti.visualize_pixels.
ti.visualize_pixels(harmpixels, 50000, ['sid'])
# Presumably writes one `<sid>.nc` per sample into processeddir (the sanity-check
# cells open files with that name) — confirm in ti.write_harmonized.
ti.write_harmonized(masksdir, processeddir, harmpixels, sids)

# Sanity checks

In [None]:
# all pcs of one sample
# Open the first sample's file from processeddir and cast it to float32.
s = xr.open_dataarray(f'{processeddir}/{sids[0]}.nc').astype(np.float32)
# One panel per value of the 'marker' dimension, wrapped 5 per row, all on the same
# fixed symmetric color range [-3, 3].
s.plot(col='marker', col_wrap=5, vmin=-3, vmax=3, cmap='seismic')

In [None]:
# histogram of each pc
# Plots the first k columns of the harmonized table, one histogram per column.
harmpixels = pd.read_feather(f'{processeddir}/_allpixels_pca_harmony.feather')
hist_ncols = 3
# Size the grid from k (ceil division) instead of a fixed 4x3 layout, which would
# fail for k > 12. For k = 10 this is still 4 rows x 3 columns at figsize (12, 9).
hist_nrows = max(1, -(-k // hist_ncols))
plt.figure(figsize=(12, 2.25 * hist_nrows))
for i in range(k):
    print(i, end='')  # lightweight progress indicator
    plt.subplot(hist_nrows, hist_ncols, i+1)
    # Select the single column positionally rather than converting the whole
    # frame to one array on every iteration.
    plt.hist(harmpixels.iloc[:, i].to_numpy(), bins=1000)
plt.tight_layout()
plt.show()

In [None]:
# PC1 of several samples
from IPython.display import display, clear_output
# Choose the samples once and size the grid from that same list. Previously the grid
# was sized from sids[::5] while the loop walked sids[::3], so zip() silently dropped
# every sample beyond the number of axes.
shown_sids = sids[::3]
pc1_ncols = 5
pc1_nrows = max(1, -(-len(shown_sids) // pc1_ncols))  # ceil division, at least one row
fig, axs = plt.subplots(pc1_nrows, pc1_ncols, figsize=(16, 4*pc1_nrows), squeeze=False)
pc1_axes = axs.flatten()
# Hide the panels of the last row that will not receive a sample.
for unused_ax in pc1_axes[len(shown_sids):]:
    unused_ax.set_axis_off()
for sid, ax in zip(shown_sids, pc1_axes):
    s = xr.open_dataarray(f'{processeddir}/{sid}.nc').astype(np.float32)
    pc1 = s.sel(marker='hPC1')
    # Symmetric color range at the 99th percentile of |PC1| for this sample.
    vmax = np.percentile(np.abs(pc1.data), 99)
    pc1.plot(ax=ax, cmap='seismic', vmin=-vmax, vmax=vmax, add_colorbar=False)
    ax.set_title(sid)
    # Redraw the whole figure after each sample so progress is visible.
    plt.tight_layout(); clear_output(wait=True); display(fig)
    gc.collect()
plt.close()