In [None]:
import argparse
import gc
import numpy as np
import logging
import pickle
from datetime import datetime
from os import makedirs
from os.path import join
from pathlib import Path

import yaml
from scipy.sparse import load_npz
import scipy
from sklearn.decomposition import TruncatedSVD

# Random seed handed to TruncatedSVD (random_state) so the decomposition is repeatable.
seed = 42
# Directory the pickled, dimensionality-reduced train/test arrays are written to.
SAVE_DIR = "/scratch/st-jiaruid-1/shenoy/svd-comp/"

In [None]:
# Locations of the sparse input/target matrices, keyed by modality name.
# Each modality maps to its training inputs ('x'), training targets ('y')
# and test inputs ('x_test').  All files live in one directory and differ
# only by the short modality tag embedded in the file name.
SPARSE_DATA_DIR = '/arc/project/st-jiaruid-1/yinian/multiome/sparse-data'
paths = {
    modality_name: {
        'x': f'{SPARSE_DATA_DIR}/train_{file_tag}_inputs_values.sparse.npz',
        'y': f'{SPARSE_DATA_DIR}/train_{file_tag}_targets_values.sparse.npz',
        'x_test': f'{SPARSE_DATA_DIR}/test_{file_tag}_inputs_values.sparse.npz',
    }
    for modality_name, file_tag in (('multiome', 'multi'), ('cite', 'cite'))
}

In [None]:
# Which dataset to process; used as a key into `paths` and embedded in the output file names.
modality = 'multiome'

In [None]:
# Read the sparse training and test input matrices for the selected modality.
x, x_test = (
    load_npz(paths[modality][split_key]) for split_key in ("x", "x_test")
)

### Keep only high-variance features (`indices`)

In [None]:
# Stack train rows on top of test rows so both are projected by the same
# decomposition.  CSR is requested explicitly: without `format`, vstack only
# returns CSR when every input is already CSR and otherwise yields COO,
# which cannot be column-indexed later (`x_stacked[:, indices]`).
x_stacked = scipy.sparse.vstack([x, x_test], format="csr")

In [None]:
# Sanity check: rows = train + test, columns = features.
print(x_stacked.shape)

In [None]:
# Per-feature (column) standard deviation of the training matrix, computed
# without densifying it.  Uses Var[X] = E[X^2] - E[X]^2 with both means taken
# directly on the sparse matrix (implicit zeros included), which equals
# np.std(x.toarray(), axis=0) (population std, ddof=0) up to floating-point
# rounding but never allocates a dense (rows x features) array.
# Means are accumulated in float64 to limit cancellation error.
_col_mean = np.asarray(x.mean(axis=0, dtype=np.float64)).ravel()
_col_mean_sq = np.asarray(x.multiply(x).mean(axis=0, dtype=np.float64)).ravel()
# Clip tiny negative variances produced by rounding before taking the root.
feature_std = np.sqrt(np.clip(_col_mean_sq - np.square(_col_mean), 0.0, None))

In [None]:
# Select the column positions whose standard deviation is strictly above
# the cut-off; the result is a plain list of Python ints in ascending order.
threshold = 0.2
indices = np.flatnonzero(np.squeeze(feature_std) > threshold).tolist()

In [None]:
# Number of features that passed the standard-deviation filter.
len(indices)

#### Perform truncated SVD (PCA-style dimensionality reduction, without centering)

In [None]:
# Number of components kept by the truncated SVD; also embedded in the output file names.
comp = 200

In [None]:
# Dimensionality reduction: fit a truncated SVD on the stacked (train + test)
# matrix restricted to the selected feature columns, and project every row
# onto the `comp` leading components in the same call.
pca_x = TruncatedSVD(n_components=comp, random_state=seed)
x_transformed = pca_x.fit_transform(x_stacked[:, indices])

In [None]:
# The stacked matrix was built as [train; test], so the first x.shape[0]
# rows of the projection are the training rows and the rest are test rows.
n_train_rows = x.shape[0]
x_train_transformed = x_transformed[:n_train_rows]
x_test_transformed = x_transformed[n_train_rows:]

# Drop the raw sparse inputs and ask the garbage collector to reclaim them.
del x, x_test
gc.collect()

In [None]:
# Save the processed arrays as pickles under SAVE_DIR.
input_dim = comp
input_type = modality

# Create the output directory if needed so the writes below cannot fail
# merely because it is missing.
makedirs(SAVE_DIR, exist_ok=True)

# NOTE(review): the train file name has no "svd" before the dimension while
# the test file name does.  Both names are kept exactly as they were because
# downstream loaders may depend on them — confirm and align if unintended.
train_out_path = join(SAVE_DIR, f"train_input_{input_type}_{input_dim}_mod.pkl")
test_out_path = join(SAVE_DIR, f"test_input_{input_type}_svd{input_dim}_mod.pkl")

# Context managers guarantee each file is flushed and closed, even if
# pickling raises part-way through.
with open(train_out_path, "wb") as train_file:
    pickle.dump(x_train_transformed, train_file)
with open(test_out_path, "wb") as test_file:
    pickle.dump(x_test_transformed, test_file)