Small CPU model forward pass extremely slow #13757

Closed
hyparxis opened this issue Nov 9, 2018 · 6 comments
Labels
module: cpu - CPU specific problem (e.g., perf, algorithm)
module: mkl - Related to our MKL support
module: multithreading - Related to issues that occur when running on multiple CPU threads
module: openmp - Related to OpenMP (omp) support in PyTorch
module: performance - Issues related to performance, either of kernel code or framework glue
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

hyparxis commented Nov 9, 2018

Issue description

I have an 80x256x256x10 fully-connected network that I'm using for policy gradients, but a single forward pass through it takes 70-100 ms (!!) to execute. Puzzlingly, after some condition that I can't identify, the forward pass suddenly and permanently speeds up to 1-3 ms or less (I couldn't get a reproducible example of that, though). Interrupting the code during execution suggests that most of the time is spent in torch.addmm(bias, input, weight.t()). Adding torch.set_num_threads(1) seems to fix it, but I don't know why. Sorry if this is a duplicate (I suspect it may be), but I searched the issues and couldn't find anything.
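
For reference, the workaround that seems to fix it looks like this (I don't know why it helps):

import torch

torch.set_num_threads(1)  # forcing a single CPU thread makes the forward pass fast again for me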

Code example

Here's a (sort of) minimal example that reproduces the slowness:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time

class GaussianMLP(nn.Module):
    def __init__(self, num_inputs, action_dim, init_std=1, learn_std=True, nonlinearity="tanh"):
        super(GaussianMLP, self).__init__()

        actor_dims = (256, 256)
        critic_dims = (256, 256)

        # create actor network
        self.actor_layers = nn.ModuleList()
        self.actor_layers += [nn.Linear(num_inputs, actor_dims[0])]
        for l in range(len(actor_dims) - 1):
            in_dim = actor_dims[l]
            out_dim = actor_dims[l + 1]
            self.actor_layers += [nn.Linear(in_dim, out_dim)]
        
        self.mean = nn.Linear(actor_dims[-1], action_dim)

        # create critic network
        self.critic_layers = nn.ModuleList()
        self.critic_layers += [nn.Linear(num_inputs, critic_dims[0])]
        for l in range(len(critic_dims) - 1):
            in_dim = critic_dims[l]
            out_dim = critic_dims[l + 1]
            self.critic_layers += [nn.Linear(in_dim, out_dim)]

        self.vf = nn.Linear(critic_dims[-1], 1)

        if nonlinearity == "relu":
            self.nonlinearity = F.relu
        else:
            self.nonlinearity = torch.tanh
        
        self.train()

    def forward(self, inputs):
        start = time.time()
        x = inputs
        for l in self.critic_layers:
            x = self.nonlinearity(l(x))
        value = self.vf(x)

        x = inputs
        for l in self.actor_layers:
            x = self.nonlinearity(l(x))
        x = self.mean(x)

        x = torch.tanh(x)
        print(time.time() - start)

        return value, x

policy = GaussianMLP(80, 10, nonlinearity="relu", init_std=np.exp(-2), learn_std=False)

for _ in range(100):
    s = torch.rand(1, 80)
    policy(s)

And here's an IPython stack trace from interrupting that code while it's running:

    61 for _ in range(100):
     62     s = torch.rand(1, 80)
---> 63     policy(s)

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/p-morais/cassie-sim-to-real/gaussian_mlp.py in forward(self, inputs)
     49         x = inputs
     50         for l in self.actor_layers:
---> 51             x = self.nonlinearity(l(x))
     52         x = self.mean(x)
     53 

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     53 
     54     def forward(self, input):
---> 55         return F.linear(input, self.weight, self.bias)
     56 
     57     def extra_repr(self):

~/.local/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1022     if input.dim() == 2 and bias is not None:
   1023         # fused op is marginally faster
-> 1024         return torch.addmm(bias, input, weight.t())
   1025 
   1026     output = input.matmul(weight.t())

System Info

PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: TITAN X (Pascal)
Nvidia driver version: 390.77
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy (1.15.3)
[pip3] torch (1.0.0a0+c029c83)
[pip3] torchvision (0.2.1)
[conda] pytorch 0.4.0 py36_cuda8.0.61_cudnn7.1.2_1 pytorch
[conda] pytorch-cpu 0.4.1 py36_cpu_1 pytorch
[conda] torchfile 0.1.0
[conda] torchvision-cpu 0.2.1 py36_1 pytorch

colesbury (Member) commented

Thanks for reporting this, especially the repro code.

The issue is that the MKL library (which we use for matrix-multiplication) is creating and destroying threads for every call to addmm (in nn.Linear). The thread creation is more expensive than the actual matrix-multiplication.

For now, you can disable this by setting the environment variable MKL_DYNAMIC=false.
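
One way to do that from a script (just a sketch; exporting the variable in the shell before launching Python works equally well, and it has to happen before torch, and therefore MKL, is loaded):

import os

# Must be set before `import torch`, since MKL picks the variable up when it initializes.
os.environ["MKL_DYNAMIC"] = "false"

import torch  # MKL should now keep its worker threads alive between addmm calls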

@soumith perhaps we should set mkl_set_dynamic(false) by default during PyTorch initialization if the environment variable isn't set. MKL assumes it's true by default, but I think that's bad for us since we have operations outside of MKL that use OpenMP.

colesbury added the bug and module: performance labels on Nov 9, 2018
colesbury (Member) commented

Also, this looks like a problem particular to GNU's OpenMP. Clang/LLVM's implementation doesn't seem to suffer from it.

cpuhrsch removed the bug label on Apr 9, 2019
ezyang added the module: cpu, module: mkl, module: openmp, module: multithreading, and triaged labels on Apr 10, 2019
ezyang (Contributor) commented Apr 10, 2019

This was fixed in #13868

imaginary-person (Contributor) commented Mar 1, 2021

The issue is that the MKL library (which we use for matrix-multiplication) is creating and destroying threads for every call to addmm (in nn.Linear). The thread creation is more expensive than the actual matrix-multiplication.

@colesbury, even after mkl_set_dynamic was set to false in #13868, a similar issue was reported in PyTorch 1.4.0 (#32008). The OpenMP num_threads clause doesn't make unused threads of the thread pool exit, but as @CaoZhongZ reported, the OpenMP threads that were unused in at::parallel_for were exiting and MKL was creating them again, leading to a memory leak.

Disabling dynamic scaling of OpenMP threads was observed to fix the issue, so that workaround was used. However, the reason why unused OpenMP threads were exiting seems to be unknown, and might be of interest to you.

In PyTorch 1.4.0, the threads of the OpenMP thread pool were created on demand. For instance, only one additional OpenMP thread was created (so, a total of 2 threads in the OpenMP thread pool) for the following snippet of code:

import torch
torch.eye(252)

On the other hand, the current implementation creates the OpenMP thread pool when the first parallel operation is performed, so it'd create as many threads as there are physical cores but use only 2 for this snippet.
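
For reference, this is roughly how those thread counts can be observed on Linux, by counting the OS-level threads of the process via /proc (just an illustration, not precise instrumentation):

import os
import torch

def os_thread_count():
    # Count the OS-level threads of the current process (Linux-only).
    return len(os.listdir(f"/proc/{os.getpid()}/task"))

print("after import:", os_thread_count())
torch.eye(252)  # first parallel op; this is when the OpenMP thread pool gets created
print("after torch.eye(252):", os_thread_count())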

Do you happen to have any insights on why MKL could be causing an issue in #32008, and why disabling auto-scaling of OpenMP threads would've fixed the issue? I tried to reproduce the issue but couldn't. Thank you!

ezyang (Contributor) commented Mar 1, 2021

@imaginary-person if this is still affecting master I suggest putting this in a new issue.

imaginary-person (Contributor) commented Mar 1, 2021

@imaginary-person if this is still affecting master I suggest putting this in a new issue.

Thanks for your response, @ezyang!

@peterbell10 reported today in #52815 that if dynamic scaling of OpenMP threads (not assigning wasteful work to some threads of the OpenMP thread pool) is enabled in at::parallel_for by modifying the master branch's source code, memory still leaks due to rapid thread creation and destruction on AMD machines, even for a plain fill_ operation where MKL isn't used.

I'll try to get access to an AMD machine so that I can figure out why that's happening and submit a GNU bug report or patch, and, if necessary, file an issue here.

Thank you!
