Skip to content

Commit

Permalink
Backport PR #2689 on branch 1.1.x (fix(external): ensure solo soft pr…
Browse files Browse the repository at this point in the history
…edict returns probs) (#2690)

Backport PR #2689: fix(external): ensure solo soft predict returns probs

---------

Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 5, 2024
1 parent 5405332 commit af76ed3
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
7 changes: 7 additions & 0 deletions docs/release_notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/
{class}`scvi.autotune.TuneAnalysis` in favor of new experimental functional API with
{func}`scvi.autotune.run_autotune` and {class}`scvi.autotune.AutotuneExperiment` {pr}`2561`.

### 1.1.3 (2024-MM-DD)

#### Fixed

- Fix {meth}`scvi.external.SOLO.predict` to correctly return probabiities instead of logits
when passing in `soft=True` {pr}`2689`.

### 1.1.2 (2024-03-01)

#### Changed
Expand Down
9 changes: 3 additions & 6 deletions scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def predict(self, soft: bool = True, include_simulated_doublets: bool = False) -
Parameters
----------
soft
Return probabilities instead of class label
Return probabilities instead of class label.
include_simulated_doublets
Return probabilities for simulated doublets as well.
Expand All @@ -401,10 +401,7 @@ def predict(self, soft: bool = True, include_simulated_doublets: bool = False) -
DataFrame with prediction, index corresponding to cell barcode.
"""
adata = self._validate_anndata(None)

scdl = self._make_data_loader(
adata=adata,
)
scdl = self._make_data_loader(adata=adata)

@auto_move_data
def auto_forward(module, x):
Expand All @@ -413,7 +410,7 @@ def auto_forward(module, x):
y_pred = []
for _, tensors in enumerate(scdl):
x = tensors[REGISTRY_KEYS.X_KEY]
pred = auto_forward(self.module, x)
pred = torch.nn.functional.softmax(auto_forward(self.module, x), dim=-1)
y_pred.append(pred.cpu())

y_pred = torch.cat(y_pred).numpy()
Expand Down
1 change: 1 addition & 0 deletions scvi/module/_vae.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main module."""

import logging
from collections.abc import Iterable
from typing import Callable, Literal, Optional
Expand Down
15 changes: 12 additions & 3 deletions tests/external/test_solo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from scvi.model import SCVI


def test_solo():
@pytest.mark.parametrize("soft", [True, False])
def test_solo(soft: bool):
n_latent = 5
adata = synthetic_iid()
SCVI.setup_anndata(adata)
Expand All @@ -16,13 +17,21 @@ def test_solo():
solo = SOLO.from_scvi_model(model)
solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
assert "validation_loss" in solo.history.keys()
solo.predict()
predictions = solo.predict(soft=soft)
if soft:
preds = predictions.to_numpy()
assert preds.shape == (adata.n_obs, 2)
assert np.allclose(preds.sum(axis=-1), 1)

bdata = synthetic_iid()
solo = SOLO.from_scvi_model(model, bdata)
solo.train(1, check_val_every_n_epoch=1, train_size=0.9)
assert "validation_loss" in solo.history.keys()
solo.predict()
predictions = solo.predict(soft=soft)
if soft:
preds = predictions.to_numpy()
assert preds.shape == (adata.n_obs, 2)
assert np.allclose(preds.sum(axis=-1), 1)


def test_solo_multiple_batch():
Expand Down

0 comments on commit af76ed3

Please sign in to comment.