New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reconstruction loss dependence on expression values #729
Comments
I tried doing an MMD version by swapping the def gaussian_kernel(x, y):
xs = x.sample()
ys = y.sample()
kernel = torch.exp(-((xs - ys)**2))
return kernel
def compute_mmd(x, y):
x_kernel = gaussian_kernel(x, x)
y_kernel = gaussian_kernel(y, y)
xy_kernel = gaussian_kernel(x, y)
mmd = x_kernel + y_kernel - 2*xy_kernel
return mmd resulting in a similar loss (and trend with mean expression values): Any thoughts on this? |
Hi, Thanks for being in touch. Indeed, this behavior makes sense to me. There are several possibilities if this is pathological in your precise analysis.
My two cents here. The sensitivity each gene in the latent space is dependent on the parametrization of the generative model, and in particular on how the mean of the negative binomial is defined as a function of the latent space. In another model with @PierreBoyeau, we have looked into alternate parameterizations, in particular for differential expression, in which we noticed that our method was having a false discovery rate dependent on the mean of the gene. In scVI, the mean of the negative binomial is defined as: \mu_NB = l * softmax( nn(z) ) where nn(z) is a multi layer neural network, with ReLu activations everywhere except at the end, where there is no activation. What you might want to try is the following: log \mu_NB = log l + log g + nn(z), where g is a gene specific scaling factor, that can be either random or fixed. g can be set for example to the mean of the gene in the dataset. In this case, we expect that nn(z) will contain less dependent to gene specific intensity. It is susceptible to solve your problem. |
Please let us know if we can help in any way. Also, this should be only a minor change in the VAE class. |
Thanks for the prompt and insightful response @romain-lopez! Few things that I missed out in my details earlier, my apologies: I started with PBMC3k with only highly variable features as the input. It's an old dataset but also the one where a lot of ground truth (and subtle) differences well characterized. My input dataset was 2.6k cells x 1.8k genes. Thanks for clarifying the role of MMD loss, I extrapolated its relevance to sequencing depth. Your suggestion is very interesting and aligned with what I had in mind to relax the loss dependence on highly expressed genes (which in other words is similar to adding an offset in the GLM). This would essentially involve writing an extended Encoder and Decoder class (like |
Hi @saketkc that would be one way to do it, though a cleaner version might add options to the VAE class. I could see a new arg specifying the per gene per batch offset, and maybe the cell offset could be the library size computed on the fly. As we are going through major changes I'm not sure the best way. I could see a PR being helpful after our 1.0 release perhaps if you find this model useful in your tests. |
Thanks @adamgayoso and @romain-lopez for your comments! I took a first pass at it by modifying the reconstruction error (setting both def get_reconstruction_loss(
self, x, px_rate, px_r, px_dropout, mean_g, **kwargs
) -> torch.Tensor:
# Reconstruction Loss
px_rate_ = px_rate
if self.adjust_gene_counts:
gene_mean = torch.mean(x, dim=0)
px_rate_ = torch.exp(torch.log(px_rate_+1e-8) - torch.log(gene_mean+1e-8))
if self.adjust_cell_counts:
cell_mean = torch.mean(x, dim=1)
px_rate_ = torch.exp(torch.log(px_rate_+1e-8) - torch.log(cell_mean+1e-8)[:, None])
if self.reconstruction_loss == "zinb":
reconst_loss = (
-ZeroInflatedNegativeBinomial(
mu=px_rate_, theta=px_r, zi_logits=px_dropout
)
.log_prob(x)
.sum(dim=-1)
)
elif self.reconstruction_loss == "nb":
reconst_loss = (
-NegativeBinomial(mu=px_rate_, theta=px_r).log_prob(x).sum(dim=-1)
)
elif self.reconstruction_loss == "poisson":
reconst_loss = -Poisson(px_rate_).log_prob(x).sum(dim=-1)
return reconst_loss It seems to work(?) at the gene level, though the separation is not great with 1 layer: or even 3 layers (I didn't try more): That said (and assuming my code changes make sense for the offset case), there is still "some" dependence on the overall expression of the cell which I find a bit surprising given that it was already being modeled anyway: For some reason, the dependence on both gene and cell expression values fails to go away when using LDVAE, though I get slightly better clusters: I am not quite sure I understand why the dependence of reconstruction error still remains unaffected for the LDVAE approach. Probably my approach is incorrect (?) On a second thought about factoring in expression/cell counts, this approach, even if correct, is relying on the assumption that the slope between NB log_mean and log_total_UMI is 1 (equivalent to an offset in GLM "y/umi ~ ..."), which is not necessarily true for all genes (and hence is often modeled as a covariate "y ~ log_umi + ..".). That said, I imagined that the NN(z) part as @romain-lopez wrote, should be able to capture this anyway, but doesn't seem to. Would love to know your thoughts. Thanks once again for finding time to respond, I understand that 1.0 must be keeping everyone busy. |
So the implementation you have isn't quite what I was envisioning. Basically,
Now with the way you have the gene offset implemented, it's computed only on a minibatch of cells, so it's a noisy estimate of the offset you'd want. I would compute the offsets and feed them into the init of the VAE class. This also isn't relevant with this dataset but this offset should probably depend on the batch. Now regarding your experiments, it's hard for me to understand why we should expect a good result to be contingent on separation of the CD8 effector population. Did these labels come from Seurat/scTransform pipeline? Could you simulate some data where genes are IID with some different gene offset and show that the latent space does or does not have dependence on the mean expression? This is interesting and we appreciate any findings you make and share here! |
Thanks @adamgayoso! I will first answer your penultimate question first and get back to you with the other two a bit later. The cell types did come from Seurat/SCTransform workflow, but can be verified to be biologically "true" either with the markers or other datasets. Here is a comparison of standard normalization vs SCT normalization: Detailed notebook here: https://nbviewer.jupyter.org/gist/saketkc/18cb0b435eec82d9927f1c1c052c3ce8 |
sure, at least it's more stable than "true". Romain and I had a chat, we'd also like to note that
|
Thanks for your comments @adamgayoso and @romain-lopez. I think I did manage to implement what you had in mind. Notebook here. The gene names in red are marker genes between at least one of the two clusters from "CD4 T Eff", "CD4 T Nai", "CD8 Mem", "CD4 T Mem", "CD8 Eff" which I will call COI (clusters of interest) here. After playing around with different modes (combination of cell and gene offset), I think the reconstruction error dependence on expression is not the constraining factor for preventing the separation in COI. To summarize, default VAE (npcs=50 throughout): LDVAE default (bias off): LDVAE with cell counts offset: Happy to know your thoughts/comments. If my implementation makes sense and you believe it might still be useful in |
Do you expect the dispersion estimated by VAE/LDVAE (cc @vals) to be generally squished to a narrow range? dropout, means, dispersions = scvi_posterior.generate_parameters()
for _ in range(99):
dropout, means, dispersions = scvi_posterior.generate_parameters()
dropout /= 100
means /= 100
dispersions /= 100 |
Yeah I think that is pretty expected for most genes. In addition to accounting for the count depth with the library size values the latent representation will account for a lot of the over-dispersion you would see in a mean-vs-variance plot of the genes. There probably isn't mush overdispersion "left" in the data after that. |
I actually encountered this phenomenon regarding library size being correlated with ELBO for a different project. I do believe that it is a function of the NB distribution. It's easier to give things higher probability near zero than spread the probability across more values, if that makes sense. Regarding population separation, I kind of wonder what would happen if you use ZINB instead of NB. As I think this is a 10x v1 dataset, perhaps giving the model more flexibility could help separate that subpopulation (dropout is a function of z). |
I've been thinking a bit about this too. I think it makes sense in general for count distributions have higher likelihood when counts are higher. In a way each count is a replicated observation. Counting 1,000 things should mean more evidence than when you're counting 100 things. The downside then is that it becomes hard to compare evidence from count observations with evidence from continuous measurements. Does anyone know if this has studied somewhere? I was thinking of a toy example. Imagine the first initial invention of PCA, when Pearson did the equivalent of regression but using two noisy observations in 1901. Now think of a case where you want to correlate count observations with continuous observations (which are both noisy). The model is easy to write out: x1 ~ N(z * b1, sigma), x2 ~ Poisson(exp(z * b2)). But whether the model is dominated by the continuous observations or the count observations will depend on the magnitudes of the counts. Again, has anyone seen this in literature? |
@vals has an interesting question! Some tangential observations: It is possible to relax this dependence in some sense by regressing out n_umis following VAE: (Sorry about the oversized plot, copying it directly from a notebook) I am still unsure if this dependence on the counts is really causing the continuous observations to not cluster as tightly as I would expect them to. For example LDVAE with nb loss on the same data as the one above seems to have tigher clusters even though its count loss has a heavy trend: |
In [1]: from scvi.core._distributions import NegativeBinomial
In [2]: import torch
In [3]: NegativeBinomial(mu = 0.2, theta=1).log_prob(torch.tensor(1.))
Out[3]: tensor(-1.9741)
In [4]: NegativeBinomial(mu = 100, theta=1).log_prob(torch.tensor(100.))
Out[4]: tensor(-5.6101) Not sure if this helps explain the phenomenon (code from our new api, not sure what the import for the NB is on stable at the moment, but this class exists). It could be worth looking into the "gene-cell" option for dispersion. |
On a different note, the library size effect in the latent space could be mitigated by
|
Thanks for your suggestions @adamgayoso. All your suggestions are great and I have tried most of these in the past except playing around with the Just for the sake of clarity, I will summarize the steps one by one. Please note that I will stick to NB loss, as the ZINB argument is not very strong. I use 12k genes to start off with and these are chosen so that SCTransform and scVI get same # of genes as inputs (even if being fed SCT corrected counts). I have also set a smaller point size for the umap plots as compared to all the previous ones I posted for visual clarity. 1. PBMC3k raw input, dispersion="gene", NB 2. PBMC3k raw input, dispersion="gene-cell", NB 3. SCTCorrected input, dispersion="gene", NB Notice that the scales in the left figure is now squished as compared to corresponding figures in 1. and 2. 4. SCTCorrected input, dispersion="gene-cell", NB There is still some dependence of the loss on the total UMIs as seen in the left (though as in 3, the scales are much narrower). 5. SCTCorrected input, dispersion="gene-cell" + library size offset, NB 6. PBMC3k raw input, dispersion="gene-cell" + library size offset, NB The libsize plots highlight how the model is essentially "flattening" it out given that it is now being inferred as a fixed variable using the offset. It does seem that the 'offset' term (similar to the UMI term with a 'non fixed slope' in SCtransform) is helpful (at least in this setting). |
Sorry for the delayed response, I'll have to look over this when I have more time. Quick thought -- how you are giving the model sctransform corrected input but using the counts for the decoder? Edit: @saketkc , it's hard for me to follow all the experiments going on here. When you say SCTcorrected, does that mean you trained the vae like normal and passed SCTcorrected after training? How exactly is the cell count offset implemented? Would you be able to share code? |
No worries @adamgayoso! The notebook and code is here: https://gist.github.com/saketkc/e376435e297fe4b34de61956d17abd44 I created a About the SCTcorrected input: It involves no change in the two examples I show in the above notebook. The only difference is the input counts are corrected output from SCTransform. I also implemented a pseudo-regularized version for restricting the dispersion (in this case the |
Why should we assume the Pearson residuals are NB distributed? It seems to make sense to me to use SCT as input to the encoder only, counts at decoder Re: cell offset px_rate = (
(torch.exp(library) * (cell_offset)) * px_scale * gene_offset
) # torch.clamp( , max=12) I don't see the benefit in using both our latent library size and the cell offset. It should be the case that the latent library is meaningless when using a cell offset, but it's complicated because it has a prior distribution on it, so it's hard to interpret. So to summarize this issue, we have two problems:
Regarding (2), it also has not been assessed whether the genes used in scVI are the same in the standard Seurat workflow. Gene selection could be critical in separating subpopulations. Regarding (1), I believe we have demonstrated elsewhere (and Valentine has) that scVI's posterior predictive distribution does indeed capture the data characteristics well. It's also difficult to map the relationship between log likelihood and performance in downstream tasks. That said, a |
I apologize for not being clear in my earlier response. First of all, I think scVI is pretty cool and extremely useful! In no way I meant to downplay its usefulness. Sorry if my response indicated otherwise. I am just thinking of ways to enhance some of scVI's capabilities for my downstream tasks.
My input is not pearson residuals, but corrected counts where the correct is applied by replacing the logUMI with median values throughout and reversing the regression.
Very interesting point, something I have thought of before but never tried but it's in my todo list.
I agree. One thing you would notice when I use the
Regarding (2), I totally agree. I focused only on the top 3000 hvg genes (
Of course, and there is no doubt about that. I am not so concerned about the absolute log likelihood. I am trying to still wrap my head around the loss-count dependence and the implication it has on downstream clustering. Maybe the dependence does not affect downstream tasks (as evaluated on some continuous measurement) for example the Zhengmix4equal set in the default mode: or using cell offset: In one of my test runs, I got the following relation working with raw counts and 3000 hvg genes which was different from what I have seen so far. Note that the number of cells are too few (246, SMARTER protocol - Kumar dataset) so likely scVI is underfitting (as mentioned in the paper), but the effect of total UMIs is not as pronounced (notice the three libsize clusters in the left plot): the reconstruction loss is better with cell offset (and leads to slightly tighter clusters though the effect is not pronounced): Similar results with other dataset with fewer cells (531, Koh - SMARTER protocol): An underfit here still produces expected clusters though in both cases. In my experiments with multiple datasets 1) |
I would like to make a correction to one of my earlier comments about scVI's dispersion being restricted to small range. This arose because of a division error (10 instead of 100). Sorry for the confusion. It is actually the other way round: I was curious if this high level of dispersion is causing the clustering difference between say SCT and scVI, so I clamped the dispersions: as compared to the default scVI + raw count offset: as compared to default scVI: In order to promote some "anti-clustering", I downweight the KL divergence by a beta value, but that was not helpful (I guess you have already explored this): last subplot is SCT for reference. Maybe this is already known but was definitely not my impression earlier: It seems the reconstruction error (and its dependence) is informative on its own. UMAP reconstruction errors reflects the latent space: |
It's not clear that the scvi dispersions should be equal (or comparable) to those inferred by SCT as overdispersion in the scvi model is established also through integration over z. This becomes a tricky problem comparing marginal and conditional distributions. Would you be able to share the data you're working with along with annotations, perhaps in h5ad format? Does this one population separate in the "standard" workflow? Scran workflow? I'm asking because if it's only apparent in the SCT workflow, it seems to me like it could be an artifact (i.e., ideally biologically "true" clusters should be somewhat stable across methods despite the interest in the field to outperform each other). It sounds like you're also interested in improving cell type resolution in the latent space, which may or may not be related to the dependence of the log likelihood on expression values, dispersion, etc. Would looking at, e.g., the silhouette with respect to the cell type labels be more informative than the UMAPs? |
That's correct. I was approaching it from a regularization point of view, mostly to figure out what goes on in the latent space if SCT and scVI dispersion are equal. I think a lot of confusion in my posts is arising because I have probably not done a good job at explaining the rationale behind all this, which I shall try to do now. I am asking a simple question: can the normalization that helps SCTransform sharpen biological cell type resolution (see Hafemeister & Satija, 2019) be incorporated in scVI (which has all the advantages already, particularly mini batches that are my favorite). I specifically started working with PBMC3k, though it is now an old dataset, because a) the cell types are well characterized and b) SCT's normalization is indeed able to resolve them. This is in comparison to the standard log normalization which fails to do so. See for example (or a better version in the vignette linked below): (based on this vignette). So, this is one dataset where we know this is definitely not an artifact (and hence the choice for this case study). The original publication used PBMC33k and 10k.
This is PBMC3k with manual annotations. I have put both raw counts with metadata and SCT corrected counts with metadata here.
I agree that looking at UMAPs might not be the best strategy here. But I chose UMAPs because here we have an expectation of the resulting UMAPs to look more structured. Yes, the motivation is indeed better resolution and a more quantitative measure would definitely help, but given the grounds set by log normalization vs SCT figure above, I believe a UMAP suffices. Just to elaborate on why I forced the dispersion estimates to be similar between scVI and SCT - this was a thought experiment to see if clamping dispersion can be somehow absorbed by the weights of NN or the latent libsize implicitly. Once the dispersion estimates are similar, I expect the NN to act as a vanilla GLM. I can now ask, what are those genes/cells that are driving this difference. To do so, I extracted the cell x gene reconstruction matrix and asked what genes are hard to reconstruct. Turns out it is the highly expressed genes: For example MALAT1, that is ubiquitously and highly expressed has the highest reconstruction error of all genes. Similarly, the ribosomal genes are the ones that are at the top of list. This to me is still counter-intuitive. I know the basic assumption that controlling the dispersion makes scVI a simple GLM is most likely not correct, but why is it still hard for scVI to reconstruct genes that are highly (and ubiquitously) expressed. I also compare scVI's imputed to SCT's denoised values (note: denoising is performed by performing a PCA on corrected counts) in the bottom plot (which is simply saying that where scVI's reconstruction error is high, the difference between SCT corrected and scVI imputed is also higher). Here's the plot to summarize my points: |
Here's an example on the Zhengmix8eq dataset (where the truth is known): on another simulated dataset: PBMC3k with scVI dispersions clamped to that of SCT: |
Hi @saketkc, it's a bit hard to track this issue on github. I'd be happy to move the discussion over to our discourse. |
Sure @adamgayoso. I will close this one and open a discussion on discourse with some new insights I have. Also curious if totalVI now allows a count offset? |
Using
LDVAE
/VAE
, it appears as if the reconstruction loss is dependent on the expression level of the geneor the cell

while at the same time the latent space with 5 encoder layers seems to blend in the cell types which can otherwise be separated by regressing out the total depth and using pearson residuals. It appears as if the VAE is emphasizing too much on highly expressed genes which seems to affect the latent space (at least in this case). The paper mentions using an MMD loss might be helpful here, but I believe it is currently not part of scVI?
The text was updated successfully, but these errors were encountered: