Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] - Per sample gradients using function transforms not working for RNN #2566

Closed
bnuliujing opened this issue Sep 22, 2023 · 1 comment

@bnuliujing

Hello!
I'm working on an optimization algorithm that requires per-sample gradients. Assuming the batch size is $N$ and the number of model parameters is $M$, I want to calculate $\partial \log p(\mathbf{x}^{(i)};\theta)/\partial \theta_j$, which is an $N \times M$ matrix. I found the [PER-SAMPLE-GRADIENTS](https://pytorch.org/tutorials/intermediate/per_sample_grads.html) tutorial and began my own experiments. As a proof of concept, I defined generative models with tractable likelihoods, such as MADE (Masked Autoencoder for Distribution Estimation), PixelCNN, and RNNs, and specified their log_prob and sample methods. I used the function transforms described in the tutorial, but currently this only works for MADE. (I believe it would also work for NADE and PixelCNN, since these models need only one forward pass to calculate the log-likelihood of $\mathbf{x}$. An RNN, however, needs one forward step per variable for both sampling and inference, i.e. $n$ sequential steps, where $n$ is the number of variables, `n` in the code below.)
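For reference, the $N \times M$ matrix above can be assembled from the per-parameter dictionary that `vmap(grad(...))` returns: flatten each parameter's gradient to shape $(N, \cdot)$ and concatenate along the parameter axis. A minimal sketch, assuming a hypothetical 4-variable autoregressive Bernoulli model whose strictly lower-triangular weight matrix plays the role of MADE's mask (`N_VARS`, `params`, and `log_prob` are illustrative names, not taken from the tutorial):

```python
import torch
import torch.nn.functional as F
from torch.func import grad, vmap

torch.manual_seed(0)
N_VARS = 4  # hypothetical number of binary variables

# Hypothetical parameters of a toy autoregressive Bernoulli model
params = {
    'weight': torch.randn(N_VARS, N_VARS) / 2,
    'bias': torch.zeros(N_VARS),
}


def log_prob(params, x):
    # Strictly lower-triangular weights: logit_j depends only on x_{<j}, as in MADE
    logits = F.linear(x, torch.tril(params['weight'], diagonal=-1), params['bias'])
    # x is a single sample of shape (N_VARS,); return the scalar log p(x; theta)
    return -F.binary_cross_entropy_with_logits(logits, x, reduction='none').sum()


samples = torch.bernoulli(torch.full((8, N_VARS), 0.5))  # N = 8 samples
N = samples.shape[0]

# grad differentiates w.r.t. params; vmap maps over the leading dim of samples only
per_sample = vmap(grad(log_prob), in_dims=(None, 0))(params, samples)

# Flatten each parameter's gradient to (N, -1) and concatenate -> N x M matrix
jacobian = torch.cat([g.reshape(N, -1) for g in per_sample.values()], dim=1)
print(jacobian.shape)  # torch.Size([8, 20]): M = 16 weight entries + 4 biases
```

Weight entries on and above the diagonal are masked out, so their gradients are identically zero; they are kept here only so that $M$ counts every stored parameter.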
Below are my code snippets; I'd like to understand why this approach fails for the RNN. Getting it to work for an RNN would significantly reduce the number of parameters I need for my research.
Thank you!

Describe the bug

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class MADE(nn.Module):
    '''A simple one-layer MADE (Masked Autoencoder for Distribution Estimation)'''

    def __init__(self, n=10, device='cpu', *args, **kwargs):
        super().__init__()
        self.n = n
        self.device = device

        self.weight = nn.Parameter(torch.randn(self.n, self.n) / math.sqrt(self.n))
        self.bias = nn.Parameter(torch.zeros(self.n))
        mask = torch.tril(torch.ones(self.n, self.n), diagonal=-1)
        self.register_buffer('mask', mask)

    def pred_logits(self, x):
        return F.linear(x, self.mask * self.weight, self.bias)

    def forward(self, x):
        logits = self.pred_logits(x)
        log_probs = - F.binary_cross_entropy_with_logits(logits, x, reduction='none')
        return log_probs.sum(-1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n, dtype=torch.float, device=self.device)
        for i in range(self.n):
            logits = self.pred_logits(x)[:, i]
            x[:, i] = torch.bernoulli(torch.sigmoid(logits))
        return x


class GRUModel(nn.Module):
    '''GRU for density estimation'''

    def __init__(self, n=10, input_size=2, hidden_size=8, device='cpu'):
        super().__init__()
        self.n = n
        self.input_size = input_size  # input_size=2 when x is binary
        self.hidden_size = hidden_size
        self.device = device
        self.gru_cell = nn.GRUCell(self.input_size, self.hidden_size)
        self.fc_layer = nn.Linear(self.hidden_size, 1)

    def pred_logits(self, x, h=None):
        x = torch.stack([x, 1 - x], dim=1)  # 1 -> (1, 0), 0 -> (0, 1), (batch_size, 2)
        h_next = self.gru_cell(x, h)  # h_{i+1}
        logits = self.fc_layer(h_next).squeeze(1)
        return h_next, logits

    def forward(self, x):
        log_prob_list = []
        x = torch.cat([torch.zeros(x.shape[0], 1, dtype=torch.float, device=self.device), x], dim=1)  # cat x_0
        h = torch.zeros(x.shape[0], self.hidden_size, dtype=torch.float, device=self.device)  # h_0
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h)
            log_prob = - F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none')
            log_prob_list.append(log_prob)
        return torch.stack(log_prob_list, dim=1).sum(dim=1)

    @torch.no_grad()
    def sample(self, batch_size):
        x = torch.zeros(batch_size, self.n + 1, dtype=torch.float, device=self.device)
        for i in range(self.n):
            h, logits = self.pred_logits(x[:, i], h=None if i == 0 else h)
            x[:, i + 1] = torch.bernoulli(torch.sigmoid(logits))
        return x[:, 1:]


if __name__ == '__main__':
    model = MADE()
    # model = GRUModel()

    # Sample from the generative model
    samples = model.sample(128)

    # Then I use the function transforms mentioned in the tutorial
    # to calculate the per-sample gradients
    from torch.func import functional_call, grad, vmap
    params = {k: v.detach() for k, v in model.named_parameters()}

    def loss_fn(log_probs):
        return log_probs.mean(0)

    def compute_loss(params, sample):
        batch = sample.unsqueeze(0)
        log_prob = functional_call(model, (params,), (batch,))
        loss = loss_fn(log_prob)
        return loss

    ft_compute_grad = grad(compute_loss)
    ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0))
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)

    print(ft_per_sample_grads)
```

The above code works for MADE (I also checked the gradient values, and they are correct!).
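One way to run this kind of check is to build the same $N \times M$ matrix with ordinary autograd, one backward pass per sample. The helper name `loop_per_sample_grads` is hypothetical; it assumes `model(batch)` returns one log-probability per row, as the `forward` methods above do. It is slow, but it does not use `vmap`, so it also runs on `GRUModel`.

```python
import torch


def loop_per_sample_grads(model, samples):
    """Reference N x M matrix of d log p(x_i; theta) / d theta_j, built sample by sample."""
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for x in samples:
        # Batch of one; sum() turns the single log-probability into a scalar
        log_prob = model(x.unsqueeze(0)).sum()
        grads = torch.autograd.grad(log_prob, params)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)  # shape (N, M)
```

To compare it with the `torch.func` result, flatten that dictionary in the same parameter order, e.g. `torch.cat([g.reshape(len(samples), -1) for g in ft_per_sample_grads.values()], dim=1)` (assuming the dictionary keeps the `named_parameters()` order), and check the two with `torch.allclose`.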
However, when I use `model = GRUModel()`, the following error is raised:

```
Traceback (most recent call last):
  File "per_sample_grads.py", line 100, in <module>
    ft_per_sample_grads = ft_compute_sample_grad(params, samples)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 1245, in wrapper
    output = func(*args, **kwargs)
  File "per_sample_grads.py", line 94, in compute_loss
    log_prob = functional_call(model, (params,), (batch,))
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/utils/stateless.py", line 262, in _functional_call
    return module(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "per_sample_grads.py", line 63, in forward
    h, logits = self.pred_logits(x[:, i], h)
  File "per_sample_grads.py", line 54, in pred_logits
    h_next = self.gru_cell(x, h)  # h_{i+1}
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/liujing/anaconda3/envs/dev/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 1327, in forward
    ret = _VF.gru_cell(
RuntimeError: output with shape [1, 8] doesn't match the broadcast shape [128, 1, 8]
```
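A plausible explanation, not confirmed in this thread: inside the vmapped function, `torch.zeros(...)` in `GRUModel.forward` builds a plain, unbatched `h_0` of shape `[1, 8]`, while the input gates computed from the vmapped sample carry the hidden batch of 128. The error shape `[128, 1, 8]` versus `[1, 8]` is consistent with the GRU cell writing a batched result in place into a tensor derived from that unbatched hidden state. Assuming that diagnosis, one possible workaround is to create `h_0` with `x.new_zeros(...)`, which follows the batching of `x` under `vmap`. The sketch below uses a hypothetical trimmed model, `TinyGRU`, with the same recurrence as `GRUModel`; only the creation of `h_0` differs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)


class TinyGRU(nn.Module):
    """Hypothetical trimmed copy of GRUModel: same recurrence, h_0 created differently."""

    def __init__(self, n=5, hidden_size=8):
        super().__init__()
        self.n = n
        self.hidden_size = hidden_size
        self.gru_cell = nn.GRUCell(2, hidden_size)
        self.fc_layer = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Prepend x_0 = 0; concatenating with the vmapped x keeps the result batched
        x = torch.cat([x.new_zeros(x.shape[0], 1), x], dim=1)
        # Assumed fix: x.new_zeros follows x's vmap batch dimension, so h_0 is batched too
        # (torch.zeros(...) here would create a plain, unbatched h_0 as in GRUModel)
        h = x.new_zeros(x.shape[0], self.hidden_size)
        log_prob_list = []
        for i in range(self.n):
            inp = torch.stack([x[:, i], 1 - x[:, i]], dim=1)  # 1 -> (1, 0), 0 -> (0, 1)
            h = self.gru_cell(inp, h)
            logits = self.fc_layer(h).squeeze(1)
            log_prob_list.append(
                -F.binary_cross_entropy_with_logits(logits, x[:, i + 1], reduction='none'))
        return torch.stack(log_prob_list, dim=1).sum(dim=1)


model = TinyGRU()
params = {k: v.detach() for k, v in model.named_parameters()}
samples = torch.bernoulli(torch.full((16, model.n), 0.5))


def compute_log_prob(params, sample):
    # Scalar log p(x; theta) for one sample, evaluated as a batch of size 1
    return functional_call(model, (params,), (sample.unsqueeze(0),)).sum()


per_sample_grads = vmap(grad(compute_log_prob), in_dims=(None, 0))(params, samples)
print({k: tuple(v.shape) for k, v in per_sample_grads.items()})
```

If the diagnosis holds, the equivalent change in `GRUModel.forward` is replacing the `torch.zeros(...)` that creates `h` with `x.new_zeros(x.shape[0], self.hidden_size)`; `sample()` is not vmapped and needs no change. Comparing the result against a plain per-sample autograd loop is a cheap way to confirm the gradients.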

Describe your environment

- Platform: macOS
- No CUDA
- PyTorch version: 2.0.1

The above code was also tested on Ubuntu 18.04 with PyTorch 2.0.1 and CUDA 11.7/11.8.

@bnuliujing bnuliujing added the bug label Sep 22, 2023
@svekars svekars added question and removed bug labels Oct 26, 2023
@svekars
Contributor

svekars commented Oct 26, 2023

Hi @bnuliujing, can you please post in https://dev-discuss.pytorch.org?

@svekars svekars closed this as completed Oct 26, 2023
2 participants