Skip to content

Commit

Permalink
train larry multilineage
Browse files Browse the repository at this point in the history
  • Loading branch information
jiachen committed Apr 27, 2023
1 parent 16669e3 commit 9316c95
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 2 deletions.
21 changes: 21 additions & 0 deletions pyrovelocity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ def create_reports_config(model_name: str, model_number: int):
process_method="load_data",
process_args=dict(),
),
larry_multilineage=create_dataset_config(
"larry_multilineage",
dl_root="${data_external.root_path}",
data_file="larry_mono.h5ad",
rel_path="${data_external.root_path}/larry_mono.h5ad",
url="${data_external.pyrovelocity.sources.figshare_root_url}/37028572",
process_method="load_data",
process_args=dict(),
),
),
),
model_training=dict(
Expand All @@ -213,6 +222,7 @@ def create_reports_config(model_name: str, model_number: int):
"larry_model2",
"larry_mono_model2",
"larry_neu_model2",
"larry_multilineage_model2",
],
simulate_model1=create_model_config(
"simulate",
Expand Down Expand Up @@ -317,6 +327,17 @@ def create_reports_config(model_name: str, model_number: int):
offset=True,
max_epochs=1000,
),
larry_multilineage_model2=create_model_config(
"pyrovelocity",
"larry_multilineage",
2,
"emb",
svi_train=True,
batch_size=4000,
cell_state="state_info",
offset=True,
max_epochs=1000,
),
),
reports=dict(
model_summary=dict(
Expand Down
17 changes: 15 additions & 2 deletions pyrovelocity/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,22 @@ def load_data(
adata = scv.datasets.dentategyrus()
elif data == "larry":
adata = load_larry()
elif data in ['larry_mono', 'larry_neu']:
adata = load_unipotent_larry(data.split('-')[1])
elif data in ["larry_mono", "larry_neu"]:
adata = load_unipotent_larry(data.split("-")[1])
adata = adata[adata.obs.state_info != "Centroid", :]

Check warning on line 95 in pyrovelocity/data.py

View check run for this annotation

Codecov / codecov/patch

pyrovelocity/data.py#L94-L95

Added lines #L94 - L95 were not covered by tests
elif data == "larry_multilineage":
adata_mono = load_unipotent_larry("mono")
adata_mono_C = adata_mono[adata_mono.obs.state_info != "Centroid", :].copy()
adata_neu = load_unipotent_larry("neu")
adata_neu_C = adata_neu[adata_neu.obs.state_info != "Centroid", :].copy()
adata_multilineage = adata_mono.concatenate(adata_neu)
adata = adata_mono_C.concatenate(adata_neu_C)
adata.layers["raw_spliced"] = adata_multilineage[

Check warning on line 103 in pyrovelocity/data.py

View check run for this annotation

Codecov / codecov/patch

pyrovelocity/data.py#L97-L103

Added lines #L97 - L103 were not covered by tests
adata.obs_names, adata.var_names
].layers["spliced"]
adata.layers["raw_unspliced"] = adata_multilineage[

Check warning on line 106 in pyrovelocity/data.py

View check run for this annotation

Codecov / codecov/patch

pyrovelocity/data.py#L106

Added line #L106 was not covered by tests
adata.obs_names, adata.var_names
].layers["unspliced"]
else:
adata = sc.read(data)

Expand Down
46 changes: 46 additions & 0 deletions reproducibility/figures/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ data_external:
process_method: load_data
process_args: {}
rel_path: data/processed/larry_neu_processed.h5ad
# External-data entry for the larry_multilineage dataset; the derived
# (processed) file is written to data/processed/.
# NOTE(review): data_file, dl_path and rel_path all name larry_mono.h5ad only,
# while the preprocess_larry_multilineage stage in dvc.yaml also depends on the
# larry_neu file — confirm that pointing at the mono file alone is intended.
larry_multilineage:
data_file: larry_mono.h5ad
dl_root: data/external
dl_path: data/external/larry_mono.h5ad
rel_path: data/external/larry_mono.h5ad
url: https://ndownloader.figshare.com/files/37028572
derived:
process_method: load_data
process_args: {}
rel_path: data/processed/larry_multilineage_processed.h5ad
model_training:
train:
- simulate_model1
Expand All @@ -126,6 +136,7 @@ model_training:
- larry_model2
- larry_mono_model2
- larry_neu_model2
- larry_multilineage_model2
simulate_model1:
path: models/medium_model1
model_path: models/medium_model1/model
Expand Down Expand Up @@ -511,6 +522,41 @@ model_training:
cell_specific_kinetics: null
kinetics_num: 2
loss_plot_path: models/larry_neu_model2/loss_plot.png
# Training configuration for the larry_multilineage_model2 run. All artifacts
# are written under models/larry_multilineage_model2; the input is the
# processed larry_multilineage dataset.
larry_multilineage_model2:
path: models/larry_multilineage_model2
model_path: models/larry_multilineage_model2/model
input_data_path: data/processed/larry_multilineage_processed.h5ad
trained_data_path: models/larry_multilineage_model2/trained.h5ad
pyrovelocity_data_path: models/larry_multilineage_model2/pyrovelocity.pkl
metrics_path: models/larry_multilineage_model2/metrics.json
run_info_path: models/larry_multilineage_model2/run_info.json
vector_field_parameters:
basis: emb
# Arguments for pyrovelocity.api.train_model (named by _target_ below).
training_parameters:
_target_: pyrovelocity.api.train_model
_partial_: true
guide_type: auto
model_type: auto
svi_train: true
batch_size: 4000
train_size: 1.0
use_gpu: 0
likelihood: Poisson
num_samples: 30
log_every: 100
cell_state: state_info
patient_improve: 0.0001
patient_init: 45
seed: 99
lr: 0.01
max_epochs: 1000
include_prior: true
library_size: true
offset: true
input_type: raw
cell_specific_kinetics: null
kinetics_num: 2
loss_plot_path: models/larry_multilineage_model2/loss_plot.png
reports:
model_summary:
summarize:
Expand Down
96 changes: 96 additions & 0 deletions reproducibility/figures/dvc.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1858,3 +1858,99 @@ stages:
- path: models/larry_neu_model2/trained.h5ad
md5: 200056b02081a4560ceaceec98b138d7
size: 136435223
preprocess_larry_multilineage:
cmd: python preprocess.py data_external.sources=[pyrovelocity] data_external.pyrovelocity.process=[larry_multilineage]
deps:
- path: data/external/larry_mono.h5ad
md5: 01f4e084c37482e26800ba4dfa0202bd
size: 66173538
- path: data/external/larry_neu.h5ad
md5: 3192e2fe89d64f5d0d158c0e7d26c79d
size: 60008807
- path: preprocess.py
md5: bf09c86fc25b1d1a98b9aa1c5fa04361
size: 2757
params:
config.yaml:
base:
log_level: INFO
data_external.pyrovelocity.larry_multilineage:
data_file: larry_mono.h5ad
dl_root: data/external
dl_path: data/external/larry_mono.h5ad
rel_path: data/external/larry_mono.h5ad
url: https://ndownloader.figshare.com/files/37028572
derived:
process_method: load_data
process_args: {}
rel_path: data/processed/larry_multilineage_processed.h5ad
outs:
- path: data/processed/larry_multilineage_processed.h5ad
md5: ae00da2f12ab951745640e24a7a33045
size: 139326026
train_larry_multilineage_model2:
cmd: /usr/bin/time -v python train.py model_training.train=[larry_multilineage_model2]
deps:
- path: data/processed/larry_multilineage_processed.h5ad
md5: ae00da2f12ab951745640e24a7a33045
size: 139326026
- path: train.py
md5: 0e7c46a112eab9290b48ad4f3deecaaa
size: 8684
params:
config.yaml:
model_training.larry_multilineage_model2:
path: models/larry_multilineage_model2
model_path: models/larry_multilineage_model2/model
input_data_path: data/processed/larry_multilineage_processed.h5ad
trained_data_path: models/larry_multilineage_model2/trained.h5ad
pyrovelocity_data_path: models/larry_multilineage_model2/pyrovelocity.pkl
metrics_path: models/larry_multilineage_model2/metrics.json
run_info_path: models/larry_multilineage_model2/run_info.json
vector_field_parameters:
basis: emb
training_parameters:
_target_: pyrovelocity.api.train_model
_partial_: true
guide_type: auto
model_type: auto
svi_train: true
batch_size: 4000
train_size: 1.0
use_gpu: 0
likelihood: Poisson
num_samples: 30
log_every: 100
cell_state: state_info
patient_improve: 0.0001
patient_init: 45
seed: 99
lr: 0.01
max_epochs: 1000
include_prior: true
library_size: true
offset: true
input_type: raw
cell_specific_kinetics:
kinetics_num: 2
loss_plot_path: models/larry_multilineage_model2/loss_plot.png
outs:
- path: models/larry_multilineage_model2/loss_plot.png
md5: 906f8f1c5ea6df8d0c58bd7d7a78455c
size: 12683
- path: models/larry_multilineage_model2/metrics.json
md5: eac6e405df47e2a0132613b55d53541c
size: 160
- path: models/larry_multilineage_model2/model
md5: 1abc8d5bfae1c2872cfe11ef4f955a82.dir
size: 610922
nfiles: 1
- path: models/larry_multilineage_model2/pyrovelocity.pkl
md5: fc35521f8582e1f6496c19062af2f379
size: 21028511
- path: models/larry_multilineage_model2/run_info.json
md5: 6bf9cb75cd7f4a80e99ba6e00a58e228
size: 465
- path: models/larry_multilineage_model2/trained.h5ad
md5: 479152b0bfd2dcfd479ae9f44afc52fe
size: 158935226
37 changes: 37 additions & 0 deletions reproducibility/figures/dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,20 @@ stages:
- ${data_external.pyrovelocity.larry_neu.derived.rel_path}
# persist: true

# Stage: run preprocess.py restricted to the larry_multilineage dataset.
preprocess_larry_multilineage:
cmd: python preprocess.py data_external.sources=[pyrovelocity] data_external.pyrovelocity.process=[larry_multilineage]
deps:
- preprocess.py
# Both the larry_mono and larry_neu external files are declared as inputs, so
# a change to either one invalidates this stage.
- ${data_external.pyrovelocity.larry_mono.rel_path}
- ${data_external.pyrovelocity.larry_neu.rel_path}
params:
- config.yaml:
- base
- data_external.pyrovelocity.larry_multilineage
# Output: the processed dataset path taken from config.yaml.
outs:
- ${data_external.pyrovelocity.larry_multilineage.derived.rel_path}
# persist: true

train_simulate_model1:
cmd: /usr/bin/time -v python train.py model_training.train=[simulate_model1]
deps:
Expand Down Expand Up @@ -505,6 +519,29 @@ stages:
- ${model_training.larry_neu_model2.pyrovelocity_data_path}
# persist: true

# Stage: train larry_multilineage_model2 via train.py. The command is wrapped
# in `/usr/bin/time -v`, as in the other train_* stages of this file.
train_larry_multilineage_model2:
cmd: /usr/bin/time -v python train.py model_training.train=[larry_multilineage_model2]
deps:
- train.py
# Processed dataset produced by the preprocess_larry_multilineage stage.
- ${model_training.larry_multilineage_model2.input_data_path}
params:
- config.yaml:
- model_training.larry_multilineage_model2
# metrics.json and run_info.json are declared with cache: false, i.e. they are
# not stored in the DVC cache.
metrics:
- ${model_training.larry_multilineage_model2.metrics_path}:
cache: false
outs:
- ${model_training.larry_multilineage_model2.run_info_path}:
cache: false
- ${model_training.larry_multilineage_model2.training_parameters.loss_plot_path}
# persist: true
- ${model_training.larry_multilineage_model2.trained_data_path}
# persist: true
- ${model_training.larry_multilineage_model2.model_path}
# persist: true
- ${model_training.larry_multilineage_model2.pyrovelocity_data_path}
# persist: true

figure2:
cmd: python fig2/figure.py
deps:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/loss_plot.png
/model
/pyrovelocity.pkl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"-ELBO": -0.9092934969035932,
"MAE": 0.29600671132732154,
"FDR_HMP": 7.114115202502107e-10,
"FDR_sig_frac": 0.831,
"real_epochs": 1000.0
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"artifact_uri": "file:///home/jupyter/pyrovelocity/reproducibility/figures/mlruns/0/4343894a07214a58b2ef6bd29c8bd049/artifacts",
"end_time": 1682558487401,
"experiment_id": "0",
"lifecycle_stage": "active",
"run_id": "4343894a07214a58b2ef6bd29c8bd049",
"run_name": "larry_multilineage_model2-4343894",
"run_uuid": "4343894a07214a58b2ef6bd29c8bd049",
"start_time": 1682557974119,
"status": "FINISHED",
"user_id": "jupyter"
}

0 comments on commit 9316c95

Please sign in to comment.