In [None]:
"""Compute UMAP embedding for some input+wgbs data in epiatlas and chip-atlas datasets."""
# pylint: disable=import-error, redefined-outer-name, use-dict-literal, too-many-lines, unused-import, unused-argument, too-many-branches

In [None]:
# Reload edited modules (e.g. epi_ml) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from __future__ import annotations

import os
import pickle
import subprocess
from importlib import metadata
from pathlib import Path

import numpy as np
import umap
from umap.umap_ import nearest_neighbors

from epi_ml.core.hdf5_loader import Hdf5Loader

In [None]:
# Root of the project's input data on the Narval cluster. Set EPILAP_INPUT_DIR to
# point elsewhere (e.g. a local mount of the same directory) instead of editing
# this cell.
input_basedir = Path(
    os.environ.get("EPILAP_INPUT_DIR", "/lustre06/project/6007017/rabyj/epilap/input")
)

In [None]:
# HDF5 loader configured with the hg38 chromosome-size file ("noy" presumably
# means chrY excluded — per the file name) and with normalization enabled.
chromsize_path = input_basedir.joinpath("chromsizes", "hg38.noy.chrom.sizes")
hdf5_loader = Hdf5Loader(normalization=True, chrom_file=chromsize_path)

In [None]:
# Merge the EpiAtlas and ChIP-Atlas file lists into a single list of local HDF5
# paths, written to hdf5_paths.list. The files live under SLURM_TMPDIR when
# running as a SLURM job (otherwise /tmp), one sub-folder per dataset
# (presumably matching each tarball's top-level folder); they are unpacked there
# by the next cell.
hdf5_input_dir = Path(os.environ.get("SLURM_TMPDIR", "/tmp"))

hdf5_lists_dir = input_basedir / "hdf5_list"
epiatlas_filename_list_path = (
    hdf5_lists_dir / "hg38_epiatlas-freeze-v2/100kb_all_none_input_wgbs.list"
)
chip_atlas_filename_list_path = hdf5_lists_dir / "C-A_100kb_all_none_input.list"

# Fail fast, before reading anything, if either list file is missing.
for list_path in (epiatlas_filename_list_path, chip_atlas_filename_list_path):
    if not list_path.exists():
        raise FileNotFoundError(f"Could not find {list_path}")

epiatlas_files = hdf5_loader.read_list(epiatlas_filename_list_path)
chip_atlas_files = hdf5_loader.read_list(chip_atlas_filename_list_path)

# Only the values (file names) of each read_list result are used here.
epiatlas_filepaths = [
    hdf5_input_dir.joinpath("epiatlas_dfreeze_100kb_all_none", file_name)
    for file_name in epiatlas_files.values()
]
chip_atlas_filepaths = [
    hdf5_input_dir.joinpath("100kb_all_none", file_name)
    for file_name in chip_atlas_files.values()
]
all_paths = [*epiatlas_filepaths, *chip_atlas_filepaths]

# One path per line; this file is what load_hdf5s reads later.
hdf5_paths_list_path = hdf5_input_dir / "hdf5_paths.list"
with open(hdf5_paths_list_path, "w", encoding="utf8") as f:
    f.write("".join(f"{hdf5_path}\n" for hdf5_path in all_paths))

In [None]:
# Extract both HDF5 tarballs into hdf5_input_dir (node-local when SLURM_TMPDIR is
# set). The per-dataset sub-folders used when building hdf5_paths.list in the
# previous cell are assumed to match each tarball's internal layout.
# Both tarballs are checked for existence before any extraction starts.
chip_atlas_tar_path = Path(
    "/lustre07/scratch/rabyj/other_data/C-A/hdf5/100kb_all_none.tar"
)
epiatlas_tar_path = Path(
    "/lustre06/project/6007515/ihec_share/local_ihec_data/epiatlas/hg38/hdf5/epiatlas_dfreeze_100kb_all_none.tar"
)
tar_paths = (chip_atlas_tar_path, epiatlas_tar_path)

for tar_path in tar_paths:
    if not tar_path.exists():
        raise FileNotFoundError(f"Could not find {tar_path}")

for tar_path in tar_paths:
    tar_cmd = ["tar", "-xf", str(tar_path), "-C", str(hdf5_input_dir)]
    subprocess.run(tar_cmd, check=True)

In [None]:
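# Local testing: uncomment to load a small HDF5 list (50 samples, per its file name) instead of the list built above.
# hdf5_paths_list_path = "/home/local/USHERBROOKE/rabj2301/Projects/epiclass/input/hdf5_list/100kb_all_none_50samples.list"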

In [None]:
# Read every file named in hdf5_paths_list_path and keep only the loaded signals.
# strict=True is forwarded to the loader (presumably it turns missing/unreadable
# files into errors — confirm in Hdf5Loader).
hdf5_dict = hdf5_loader.load_hdf5s(strict=True, data_file=hdf5_paths_list_path).signals

In [None]:
# UMAP hyper-parameters for each embedding variant ("standard", "cluster",
# "densmap"). Every variant is fitted on the raw signal matrix `data` together
# with the shared k-NN graph `precomputed_knn`, which is built with
# metric="correlation" in a later cell.
#
# "metric" must therefore name that same distance, not "precomputed":
#   - UMAP interprets metric="precomputed" as "X is a square pairwise-distance
#     matrix", which `data` (one row per file) is not.
#   - In current umap-learn, inputs with fewer than ~4096 samples ignore
#     precomputed_knn and compute distances directly with `metric`, so a wrong
#     metric there would break the fit or silently change the embedding.
# "n_neighbors" must not exceed the number of neighbours stored in
# precomputed_knn (UMAP rejects a precomputed_knn with fewer neighbours).
embedding_params = {
    "standard": {
        "n_neighbors": 15,
        "min_dist": 0.1,
        "n_components": 3,
        "metric": "correlation",
        "low_memory": False,
    },
    # Tighter, higher-dimensional embedding (min_dist=0, 10 components),
    # presumably intended as input for clustering.
    "cluster": {
        "n_neighbors": 30,
        "min_dist": 0,
        "n_components": 10,
        "metric": "correlation",
        "low_memory": False,
    },
    "densmap": {
        "n_neighbors": 30,
        "min_dist": 0.1,
        "n_components": 3,
        "metric": "correlation",
        "low_memory": False,
        "densmap": True,
    },
}

In [None]:
# Save outputs in a "umap" folder next to the ChIP-Atlas tarball. If the cell that
# defines chip_atlas_tar_path was not run (e.g. local testing), fall back to the
# home directory instead.
if "chip_atlas_tar_path" in globals():
    output_dir = chip_atlas_tar_path.parent / "umap"
    output_dir.mkdir(exist_ok=True)
else:
    output_dir = Path.home()

In [None]:
# Stack the per-file signals, in hdf5_dict iteration order, into one float32 array
# (presumably n_files x n_bins, assuming every signal has the same length).
data = np.asarray([*hdf5_dict.values()], dtype=np.float32)

In [90]:
# Build the correlation-distance k-NN graph once and reuse it for every UMAP
# variant. Its neighbour count is the largest "n_neighbors" requested in
# embedding_params, because UMAP requires precomputed_knn to hold at least
# n_neighbors neighbours per point.
precomputed_knn = nearest_neighbors(
    X=data,
    n_neighbors=max(p["n_neighbors"] for p in embedding_params.values()),
    metric="correlation",
    random_state=42,
    low_memory=False,
    metric_kwds=None,
    angular=None,
)
with open(output_dir / "precomputed_knn.pkl", "wb") as f:
    pickle.dump(precomputed_knn, f)

# Save the installed package versions next to the pickle, so the environment
# that produced it can be recreated before unpickling. Lines are de-duplicated
# and sorted so the file is deterministic; distributions with no name are skipped
# (they would otherwise be written as "None==...").
requirement_lines = sorted(
    {
        f"{dist.metadata['Name']}=={dist.version}\n"
        for dist in metadata.distributions()
        if dist.metadata["Name"]
    },
    key=lambda line: (line.lower(), line),
)
with open(output_dir / "pickle_requirements.txt", "w", encoding="utf8") as f:
    f.writelines(requirement_lines)

In [None]:
# Fit one UMAP per parameter set, reusing the shared k-NN graph. Each embedding is
# pickled together with the file ids (same order as the rows of `data`, both taken
# from hdf5_dict) and the parameters that produced it.
file_names = [*hdf5_dict]
for name, params in embedding_params.items():
    reducer = umap.UMAP(**params, random_state=42, precomputed_knn=precomputed_knn)
    embedding = reducer.fit_transform(X=data)
    payload = {"ids": file_names, "embedding": embedding, "params": params}
    with open(output_dir / f"embedding_{name}.pkl", "wb") as f:
        pickle.dump(payload, f)
        print(f"Saved embedding_{name}.pkl")