
Running batch transforms (e.g. torch.nn.functional.grid_sample) is slower on TPU vs CPU #2405

Open
butchland opened this issue Aug 5, 2020 · 17 comments

butchland commented Aug 5, 2020

🐛 Bug

Executing batch transforms that use the torch.nn.functional.grid_sample function
is slower on a single TPU core than on the CPU.

To Reproduce

We encountered this weird bug where the batch transforms run slower on a single TPU core compared to a GPU (which we somewhat expected), but we also found that they run even slower than on the CPU!

Here are some notebooks showing the results for a single transform (Flip):

GPU (fastest) - avg time: 0.021 secs
CPU (middle) - avg time: 1.227 secs
TPU (slowest) - avg time: 7.341 secs

For the torch.nn.functional.grid_sample call itself:

GPU - avg time: 0.000 secs (not measurable with a time.time() diff)
CPU - avg time: 0.821 secs
TPU - avg time: 4.247 secs

This is not even using gradients, just pure parallel tensor computations...

The notebooks below include a run-on-Colab link so you can validate the stats produced above.

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_CPU.ipynb

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_GPU.ipynb

https://github.com/butchland/fastai_xla_extensions/blob/master/archive_nbs/09_vision_augment_experiments_profile_TPU.ipynb

Expected behavior

We expect the TPU to run the transforms much faster than the CPU.

Environment

Colab

  • Reproducible on XLA backend [CPU/TPU]: TPU runtime - pytorch-dev20200707
  • torch_xla version: torch-xla==1.6+5430aca

Additional context

Lastly, we noticed that other data augmentations that run on the batch on the TPU (brightness and contrast) don't slow it down as much; they actually run faster on the TPU than on the CPU.

We (@butchland and @tyoc213) are building an extension library to enable the fastai library to run on TPUs.

If you have suggestions to speed it up (e.g. alternative algos for batch transforms
for data augmentations), we'd appreciate it!

@JackCaoG
Collaborator

JackCaoG commented Aug 5, 2020

Hi @butchland, thanks for reporting! Could you follow the instructions here to do a debug run? That way we can see exactly what happened. My guess is that XLA currently does not lower the grid_sampler_2d and grid_sampler_3d nodes, so they are being forwarded to the CPU, which causes the slowdown.

@tyoc213
Contributor

tyoc213 commented Aug 5, 2020

debug_run.tar.gz If you need anything else or extra parameters, we can send them.

(Deleted the log I had attached here because I just saw that it is already in the archive.)

The final Python code executed is this:

import fastai_xla_extensions.core
from fastai2.vision.all import *
from my_timesaver_utils.profiling import *

path = untar_data(URLs.PETS)/'images'
Path.BASE_PATH = path; path.ls()
print(f'running on default_device() & cuda is {torch.cuda.is_available()}')

# Build a single resized image and expand it into a batch of 768 copies
img = PILImage.create(path/'Abyssinian_1.jpg')
resize = Resize(size=200)
img2 = resize(img, split_idx=0)
timg2 = TensorImage(array(img2)).permute(2, 0, 1).float()/255.

def batch_ex(bs, device):
    return TensorImage(timg2[None].to(device).expand(bs, *timg2.shape))

b768_img = batch_ex(768, default_device()); (b768_img.shape, b768_img.device)

# Profile the Flip transform (which calls F.grid_sample under the hood)
flip_tfm = Flip(p=1.0)
run_with_profile = True  # set to False to run without profiling
F.grid_sample = profile_call(F.grid_sample) if run_with_profile else F.grid_sample

@profile_call
def mtest(b_img):
    # set_trace()
    new_b_img = flip_tfm(b_img)
    return new_b_img

clear_prof_data()
print("--- 10 image tensor loops:")
for i in range(10):
    print("--- ---------------------------------")
    new_b768_img = mtest(b768_img)
print("--- ")
print_prof_data()

@JackCaoG
Copy link
Collaborator

JackCaoG commented Aug 5, 2020

Oh, it looks like you are running a small code snippet that does not finish a full step, so the metrics report is not generated. Do you mind running it again with

import torch_xla.debug.metrics as met

print(met.metrics_report())

at the end? More details can be found here. This report will be super helpful in telling us where the slowness is coming from.
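
A minimal sketch of where this would go, assuming the mtest/b768_img snippet from the previous comment (xm.mark_step() is added here only to force a step boundary before printing):

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

for i in range(10):
    new_b768_img = mtest(b768_img)

xm.mark_step()               # materialize the pending XLA graph (one full step)
print(met.metrics_report())  # aten:: counters in this report indicate ops forwarded to the CPU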

@JackCaoG JackCaoG self-assigned this Aug 5, 2020
@tyoc213
Contributor

tyoc213 commented Aug 5, 2020

Oh yes, we tried to remove all the interference from extra code and limit it as much as possible to just what is causing the slowness.

debug_run_stats .tar.gz

By the way, I see I can add --hlo and generate something like what grab_graph.py produces?


Also, what was missing in our sample to generate this report automatically, so that we don't have to print it manually?


I found that the aten:: counters are the calls forwarded to the CPU because they are not implemented on TPU, so I'm pasting them from the tgz for easy access:

Counter: aten::_local_scalar_dense
  Value: 10
Counter: aten::affine_grid_generator
  Value: 10
Counter: aten::grid_sampler_2d
  Value: 10
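
A small helper (our own, not part of torch_xla) to list just the aten:: fallback counters instead of scanning the full report by hand:

import torch_xla.debug.metrics as met

# Print only the counters that track ops falling back to the CPU
for name in met.counter_names():
    if name.startswith('aten::'):
        print(name, met.counter_value(name))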

@JackCaoG
Collaborator

JackCaoG commented Aug 5, 2020

Yup, you are right. _local_scalar_dense most likely comes from a PyTorch item() call. For the other two it looks like we need to add a lowering. We are a bit busy with the upcoming release right now but will add this to our todo list.

For your other questions: yes, you can set XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT to dump the HLO text for the debug run as well. I think we dump the metrics report every time mark_step is called here. You can also manually call the API to print the metrics to the output.
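
As a rough sketch of that env-var setup (the file path is just an example; setting the variables before torch_xla is imported ensures they take effect):

import os
os.environ['XLA_SAVE_TENSORS_FILE'] = '/tmp/xla_graphs.txt'  # example output path
os.environ['XLA_SAVE_TENSORS_FMT'] = 'hlo'                   # dump the graphs as HLO text

import torch_xla  # import after setting the variables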

@JackCaoG JackCaoG added nostale Do not consider for staleness op lowering labels Aug 5, 2020
@tyoc213
Contributor

tyoc213 commented Aug 6, 2020

Good! I see, thanks.

So we should wait for the lowering of these two calls, but what about the item() call: is there a way we can optimize it, or will it always show up?

@JackCaoG
Collaborator

JackCaoG commented Aug 6, 2020

Yup, I will update this thread when I make any progress on lowering these two ops. We have a section here talking about the item() call; the takeaway is don't use it unless necessary.
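
To illustrate that takeaway (a toy sketch, not from the thread): each .item() forces the pending lazy graph to execute and copies the scalar back to the host, so accumulate on the device and read the value back once instead of once per iteration.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
total = torch.zeros((), device=device)
for i in range(10):
    x = torch.randn(768, 3, 200, 200, device=device)
    total += x.mean()        # stays on the device, no per-iteration sync
print(total.item())          # single device-to-host transfer at the end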

@tyoc213
Contributor

tyoc213 commented Apr 13, 2021

Hi there @JackCaoG, I'm back on this issue so I will give it a try!

@Pwang001

> Hi there @JackCaoG, I'm back on this issue so I will give it a try!

Hi, is there a solution now? I just hit the same issue when training a GPT-Neo model using TPUs on Colab.

@dhruvrnaik

I am using Resize inside a PyTorch Lightning training step and it makes my code terribly slow. Is there a solution for this?

@JackCaoG
Collaborator

affine_grid should be supported now. To get a better understanding of the problem, printing a metrics report with

import torch_xla.debug.metrics as met

print(met.metrics_report())

after a step will help.

@dhruvrnaik

@JackCaoG I am using Resize in my training step, more specifically transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=None).

I am using the metric report function, but the training seems to be stuck at the resize(tensor) operation, so the code doesn't reach that step.

@honglin-chen

> Hi there @JackCaoG, I'm back on this issue so I will give it a try!

Hi @JackCaoG @butchland @tyoc213, I'm wondering if you have found a solution for speeding up the F.grid_sample method. I'm also running into the same issue. Any help would be much appreciated.

@JackCaoG
Collaborator

@dhruvrnaik if you have a small repro I might be able to take a look. It depends on what ops transforms.Resize gets decomposed into by PyTorch and passed to us.

@JackCaoG
Collaborator

Taking another look at the CPU grid_sampler_2d implementation, it seems to play with the strides quite a bit. This kind of op is pretty difficult to lower for XLA since we can't manipulate strides (XLA is a functional compiler and does not expose its memory space to the user). We would have to use a bunch of reshapes and convs to fake the strides. Which models use these batch transforms? I don't have anyone who can immediately work on this lowering.

@honglin-chen

Thank you for looking into it, @JackCaoG. I'm working with a model for learning a pixel-conditioned Neural Radiance Field (paper, code). Many radiance field models rely heavily on nn.functional.grid_sample to sample features at pixel/voxel locations, so having XLA support for it would be tremendously beneficial to the community. If this function will not be supported soon, do you have recommendations for alternative approaches that serve the same functionality while running much faster on TPU?

@JackCaoG
Collaborator

I don't think we have anything similar to grid_sample. PyTorch/XLA supports upsample_nearest2d, but I feel like that's not what you want. I was trying to find a TensorFlow implementation of this op (TF ops usually have an XLA lowering) but only found this workaround. We might be able to use some of the techniques in this TF workaround when lowering this op.
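
For anyone blocked on this in the meantime, here is a rough sketch of the gather-based workaround idea, written only with ops that PyTorch/XLA already lowers (reshape, gather, clamp, arithmetic). The function bilinear_grid_sample below is not part of torch_xla or fastai; it only approximates F.grid_sample with bilinear interpolation and border-style padding, so verify its output against the CPU result before relying on it.

import torch

def bilinear_grid_sample(im, grid, align_corners=True):
    # im: (N, C, H, W); grid: (N, Hg, Wg, 2) with x/y in [-1, 1], like F.grid_sample.
    n, c, h, w = im.shape
    gx, gy = grid[..., 0], grid[..., 1]
    # Map normalized coordinates to pixel coordinates.
    if align_corners:
        x = (gx + 1) * (w - 1) / 2
        y = (gy + 1) * (h - 1) / 2
    else:
        x = ((gx + 1) * w - 1) / 2
        y = ((gy + 1) * h - 1) / 2
    x0, y0 = torch.floor(x), torch.floor(y)
    x1, y1 = x0 + 1, y0 + 1
    # Bilinear weights for the four neighbouring pixels.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)
    # Clamp indices to the image border (roughly padding_mode='border').
    x0 = x0.long().clamp(0, w - 1); x1 = x1.long().clamp(0, w - 1)
    y0 = y0.long().clamp(0, h - 1); y1 = y1.long().clamp(0, h - 1)
    im_flat = im.reshape(n, c, h * w)
    def gather(ix, iy):
        idx = (iy * w + ix).reshape(n, 1, -1).expand(-1, c, -1)
        return torch.gather(im_flat, 2, idx)
    out = (gather(x0, y0) * wa.reshape(n, 1, -1) +
           gather(x0, y1) * wb.reshape(n, 1, -1) +
           gather(x1, y0) * wc.reshape(n, 1, -1) +
           gather(x1, y1) * wd.reshape(n, 1, -1))
    return out.reshape(n, c, grid.shape[1], grid.shape[2])

A sanity check against F.grid_sample(im, grid, mode='bilinear', padding_mode='border', align_corners=True) on CPU with your own grids is advisable before swapping it in.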
