In [1]:
import argparse
import gc
import numpy as np
import pandas as pd
import logging
import pickle
from datetime import datetime
from os import makedirs
from os.path import join
from pathlib import Path

import yaml
from scipy.sparse import load_npz
import scipy
from sklearn.decomposition import TruncatedSVD

# Seed intended for the notebook's stochastic steps.
# NOTE(review): confirm it is actually passed to every randomized estimator used below.
seed = 42
# Directory the augmented train/test feature pickles are written to at the end of the notebook.
SAVE_DIR = "/scratch/st-jiaruid-1/shenoy/svd-comp/"

In [2]:
def _modality_paths(svd_tag, sparse_tag):
    """Return the input/target/index file locations for one modality.

    svd_tag    -- modality name as it appears in the SVD-compressed input files.
    sparse_tag -- modality name as it appears in the sparse-data files.
    """
    svd_dir = '/arc/project/st-jiaruid-1/yinian/multiome/svd-comp'
    sparse_dir = '/arc/project/st-jiaruid-1/yinian/multiome/sparse-data'
    return {
        'x': f'{svd_dir}/train_input_{svd_tag}_svd128.pkl',
        'y': f'{sparse_dir}/train_{sparse_tag}_targets_values.sparse.npz',
        'x_test': f'{svd_dir}/test_input_{svd_tag}_svd128.pkl',
        'x_cols': f'{sparse_dir}/train_{sparse_tag}_inputs_idxcol.npz',
        'y_cols': f'{sparse_dir}/train_{sparse_tag}_targets_idxcol.npz',
        'x_test_cols': f'{sparse_dir}/test_{sparse_tag}_inputs_idxcol.npz',
    }


# File locations per modality; the sparse-data files abbreviate 'multiome' to 'multi'.
paths = {
    'multiome': _modality_paths('multiome', 'multi'),
    'cite': _modality_paths('cite', 'cite'),
}

In [51]:
# Which entry of `paths` this run processes; must be one of its keys ('multiome' or 'cite').
# It is also embedded in the output file names written at the end of the notebook.
modality = 'cite'

In [52]:
# Load Data
# The inputs are .pkl files, so np.load only succeeds through its pickle
# fallback (allow_pickle=True) -- point this at trusted files only.
x, x_test = (
    np.load(paths[modality][split_key], allow_pickle=True)
    for split_key in ("x", "x_test")
)

In [53]:
# Row/column label archives that accompany the train and test input matrices.
x_cols, x_test_cols = (
    np.load(paths[modality][cols_key], allow_pickle=True)
    for cols_key in ("x_cols", "x_test_cols")
)

In [54]:
# Each feature matrix must have exactly one row per cell id in its index archive.
assert x.shape[0] == len(x_cols['index'])
assert x_test.shape[0] == len(x_test_cols['index'])

In [55]:
metadata = pd.read_csv('/arc/project/st-jiaruid-1/yinian/multiome/metadata.csv')

# Keep only the metadata rows whose cell appears in the train or test inputs.
# NOTE(review): the surviving rows keep the CSV's order, which is not necessarily
# the train-then-test order of the feature matrices -- verify before relying on
# positional alignment between `metadata_` and the features.
wanted_cell_ids = list(x_cols['index']) + list(x_test_cols['index'])
is_wanted = metadata['cell_id'].isin(wanted_cell_ids)
metadata_ = metadata.loc[is_wanted].reset_index(drop=True)

In [80]:
# Stack train rows on top of test rows and label every row with its cell id.
stacked_features = np.vstack([x, x_test])
stacked_cell_ids = [*x_cols['index'], *x_test_cols['index']]
x_df = pd.DataFrame(stacked_features, index=stacked_cell_ids)

In [81]:
# Display (rows, columns): rows are the train cells followed by the test cells,
# columns are the input feature dimensions.
x_df.shape

(119651, 128)

In [82]:
# Number of extra context columns appended to every row; also used as
# n_components for the per-(day, donor) TruncatedSVD fitted below.
add_components = 10

In [83]:
# Context vector per (day, donor) group: fit a TruncatedSVD on that group's
# rows of `x_df` and keep the mean of the projected rows.
# Keys are f'{day}_{donor}', values are 1-D arrays of length `add_components`.
add_df_dict = {}

for day in metadata_['day'].unique():
    for donor in metadata_['donor'].unique():
        # cell ids belonging to this (day, donor) group
        in_group = (metadata_.day == day) & (metadata_.donor == donor)
        cell_ids = metadata_.loc[in_group, 'cell_id'].tolist()
        if len(cell_ids) == 0: continue
        print (day, donor, len(cell_ids))

        # TruncatedSVD's default solver is randomized; pin it to the notebook
        # seed so the context vectors are reproducible across runs.
        svd = TruncatedSVD(n_components=add_components, random_state=seed)
        # Label-based lookup, so the result does not depend on row order.
        group_projection = svd.fit_transform(x_df.loc[cell_ids])
        add_df_dict[f'{day}_{donor}'] = group_projection.mean(0)

2 27678 7476
2 32606 7476
2 13176 6071
2 31800 8395
3 27678 6488
3 32606 6999
3 13176 7643
3 31800 6259
4 27678 7832
4 32606 9511
4 13176 8485
4 31800 10149
7 27678 6247
7 32606 7254
7 13176 7195
7 31800 6171


In [84]:
# Append `add_components` zero-filled placeholder columns (col_0, col_1, ...)
# that will later hold the context vector for each row.
context_column_names = [f'col_{position}' for position in range(add_components)]
for column_name in context_column_names:
    x_df[column_name] = 0

In [85]:
# Row-count comparison only: it does not show that the two frames list the
# cells in the same order.
len(metadata_) == len(x_df)

True

In [86]:
# Dense copy of the features; the trailing `add_components` columns are the
# zero-filled context slots that get overwritten below.
x_df_numpy = x_df.to_numpy()
print (x_df_numpy.shape)

# First context column (the original feature columns come before it).
context_start = x_df_numpy.shape[1] - add_components

# Walk the days in chronological order so that "previous day" really is the
# preceding time point, independent of the order rows appear in `metadata_`.
# The earliest day has no predecessor and keeps its zero context.
days_sorted = np.sort(metadata_['day'].unique())
for day_used, day in zip(days_sorted[:-1], days_sorted[1:]):
    for donor in metadata_['donor'].unique():
        # do replacement
        in_group = (metadata_.day == day) & (metadata_.donor == donor)
        cell_ids = metadata_.loc[in_group, 'cell_id'].tolist()
        if len(cell_ids) == 0: continue

        # Translate cell ids into row positions of `x_df` / `x_df_numpy`.
        # `metadata_` keeps the CSV's row order, which is not guaranteed to
        # match the train-then-test order of the features, so its positional
        # index must not be used to address feature rows.
        row_positions = x_df.index.get_indexer(cell_ids)
        assert (row_positions >= 0).all(), 'metadata cell_id missing from x_df index'

        print (f'Using mean from day: {day_used} and donor : {donor}')
        # The 1-D context vector is broadcast across every row of the group.
        x_df_numpy[row_positions, context_start:] = add_df_dict[f'{day_used}_{donor}']

(119651, 138)
Using mean from day: 2 and donor : 27678
Using mean from day: 2 and donor : 32606
Using mean from day: 2 and donor : 13176
Using mean from day: 2 and donor : 31800
Using mean from day: 3 and donor : 27678
Using mean from day: 3 and donor : 32606
Using mean from day: 3 and donor : 13176
Using mean from day: 3 and donor : 31800
Using mean from day: 4 and donor : 27678
Using mean from day: 4 and donor : 32606
Using mean from day: 4 and donor : 13176
Using mean from day: 4 and donor : 31800


In [87]:
# Wrap the augmented array back into a frame labelled by cell id
# (train ids first, then test ids) so it can be inspected and split.
augmented_cell_ids = [*x_cols['index'], *x_test_cols['index']]
df_sanity_check = pd.DataFrame(x_df_numpy, index=augmented_cell_ids)

In [88]:
# Select day-2 cells by cell id (label-based). Using `metadata_`'s positional
# index here would assume both frames list the cells in the same order.
day2_cell_ids = metadata_.loc[metadata_.day == 2, 'cell_id'].tolist()

# everything should be 0
assert (df_sanity_check.loc[day2_cell_ids].iloc[:, 128:] == 0.0).all().all()

In [89]:
# Undo the train/test stacking: the first `n_train_rows` rows are train cells,
# the remainder are test cells.
n_train_rows = len(x_cols['index'])
x_new = df_sanity_check.iloc[:n_train_rows].to_numpy()
x_test_new = df_sanity_check.iloc[n_train_rows:].to_numpy()

In [90]:
# The augmented splits must keep the row counts of the original inputs.
assert x_new.shape[0] == x.shape[0]
assert x_test_new.shape[0] == x_test.shape[0]

In [91]:
# Persist the augmented train/test matrices as pickles under SAVE_DIR.
input_type = modality
input_dim = 128  # width of the original features; only used in the file names

# Create the output directory if needed so the writes below cannot fail on a fresh scratch space.
makedirs(SAVE_DIR, exist_ok=True)

train_out_path = join(SAVE_DIR, f"train_input_{input_type}_{input_dim}_ctxt_addn_comp_{add_components}.pkl")
test_out_path = join(SAVE_DIR, f"test_input_{input_type}_{input_dim}_ctxt_addn_comp_{add_components}.pkl")

# Context managers guarantee the handles are flushed and closed even if pickling raises.
with open(train_out_path, "wb") as train_file:
    pickle.dump(x_new, train_file)
with open(test_out_path, "wb") as test_file:
    pickle.dump(x_test_new, test_file)