I am trying to write my own PyTorch module by following the documentation, as in "Crafting the Module in vanilla PyTorch". As the next step, I want to implement the model as in "Constructing a high-level model".
On the first linked page, the loss function in cell 6 contains the line
log_lik = NegativeBinomial(total_counts=theta, total=nb_logits).log_prob(x).sum(dim=-1)
whereas the PyTorch documentation gives the constructor signature as
torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)
So, in the next stage, when I replace the built-in VAE model with MyModule, I get an "unexpected keyword argument 'total_counts'" error.
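For reference, here is a minimal sketch of the mismatch, using the keyword names from the PyTorch signature quoted above. It assumes that nb_logits is meant to be passed as logits and that theta is the total_count; the tensor shapes and values are made-up stand-ins, not taken from the tutorial.

```python
import torch
from torch.distributions import NegativeBinomial

torch.manual_seed(0)

# Hypothetical stand-ins for the tutorial's tensors (shapes are assumptions).
n_cells, n_genes = 4, 3
x = torch.randint(0, 10, (n_cells, n_genes)).float()  # observed counts
theta = torch.rand(n_genes) + 0.5                      # strictly positive
nb_logits = torch.randn(n_cells, n_genes)              # unconstrained logits

# Call using the documented keyword names: total_count and logits.
log_lik = (
    NegativeBinomial(total_count=theta, logits=nb_logits)
    .log_prob(x)
    .sum(dim=-1)
)

# The call as written in the tutorial cell raises a TypeError,
# because total_counts and total are not parameters of the constructor.
try:
    NegativeBinomial(total_counts=theta, total=nb_logits)
    raised = False
except TypeError as err:
    raised = True
    print(err)
```

With these keyword names the first call runs and returns one log-likelihood value per row of x, while the second call fails in the same way as described above.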
I am attaching a file that reproduces the issue.
scvi_tut.py.txt
I am using PyTorch 1.10.0+cpu.