From aeeb9c74fb7811fbea3bb189179f1ebe861d76ef Mon Sep 17 00:00:00 2001
From: Valeh Valiollah Pour Amiri <4193454+watiss@users.noreply.github.com>
Date: Wed, 29 Dec 2021 09:30:29 -0800
Subject: [PATCH 1/2] Fix `total_counts` to `total_count`

Fix `total_counts` to `total_count` in the module_user_guide
---
 module_user_guide.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/module_user_guide.ipynb b/module_user_guide.ipynb
index f31b51a2..6743ed72 100644
--- a/module_user_guide.ipynb
+++ b/module_user_guide.ipynb
@@ -411,7 +411,7 @@
 "    # the pytorch NB distribution uses a different parameterization\n",
 "    # so we must apply a quick transformation (included in scvi-tools, but here we use the pytorch code)\n",
 "    nb_logits = (px_rate + 1e-4).log() - (theta + 1e-4).log()\n",
-"    log_lik = NegativeBinomial(total_counts=theta, total=nb_logits).log_prob(x).sum(dim=-1) \n",
+"    log_lik = NegativeBinomial(total_count=theta, total=nb_logits).log_prob(x).sum(dim=-1) \n",
 "\n",
 "    # term 2\n",
 "    prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))\n",

From f729c111f37d3538bcac25a41e030d9b5123da62 Mon Sep 17 00:00:00 2001
From: Valeh Valiollah Pour Amiri <4193454+watiss@users.noreply.github.com>
Date: Sun, 2 Jan 2022 10:35:07 -0800
Subject: [PATCH 2/2] Replace `total` with `logits`

---
 module_user_guide.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/module_user_guide.ipynb b/module_user_guide.ipynb
index 6743ed72..62c9558e 100644
--- a/module_user_guide.ipynb
+++ b/module_user_guide.ipynb
@@ -411,7 +411,7 @@
 "    # the pytorch NB distribution uses a different parameterization\n",
 "    # so we must apply a quick transformation (included in scvi-tools, but here we use the pytorch code)\n",
 "    nb_logits = (px_rate + 1e-4).log() - (theta + 1e-4).log()\n",
-"    log_lik = NegativeBinomial(total_count=theta, total=nb_logits).log_prob(x).sum(dim=-1) \n",
+"    log_lik = NegativeBinomial(total_count=theta, logits=nb_logits).log_prob(x).sum(dim=-1) \n",
 "\n",
 "    # term 2\n",
 "    prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))\n",
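The corrected line can be checked in isolation. This is a minimal sketch of the patched likelihood term: `torch.distributions.NegativeBinomial` is parameterized by `total_count` and `logits`, whereas scvi-tools uses a mean (`px_rate`) and inverse dispersion (`theta`), hence the log-ratio transformation from the notebook. The tensor values for `px_rate`, `theta`, and `x` here are hypothetical placeholders, not values from the guide.

```python
import torch
from torch.distributions import NegativeBinomial

# Hypothetical stand-ins for the notebook's variables:
px_rate = torch.tensor([[5.0, 2.0]])  # NB mean per gene
theta = torch.tensor([[10.0, 4.0]])   # NB inverse dispersion per gene
x = torch.tensor([[3.0, 1.0]])        # observed counts

# Convert the (mean, inverse dispersion) parameterization to
# pytorch's (total_count, logits): logits = log(mu) - log(theta)
nb_logits = (px_rate + 1e-4).log() - (theta + 1e-4).log()

# The patched call: keyword names must be `total_count` and `logits`
log_lik = NegativeBinomial(total_count=theta, logits=nb_logits).log_prob(x).sum(dim=-1)
```

With the pre-patch keywords (`total_counts=` or `total=`), the constructor raises a `TypeError`, which is what the two commits above fix.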