❓ Questions and Help
Hello, all!
I am running a 3D CNN on a TPU v3-8, and the computation does not seem to be well optimized.
In short, most of the computation time appears to be wasted on excessive padding in my first convolution.
Background Information
- PyTorch-XLA version: 1.9 (installed from the prebuilt Docker image gcr.io/tpu-pytorch/xla:r1.9)
- PyTorch version: 1.9.0a0+git1a7c23c
- TPU version: v3-8 (software version: pytorch-1.9)
- I have profiled the computation with TensorBoard, following the official profiling guide (see the sketch after this list).
- The profile shows unusual padding annotations, which are a possible source of low utilization, as introduced here.
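For context, this is roughly how I capture the profile, following the guide; the port number and log directory below are placeholders I picked, not values prescribed anywhere:

```python
import torch_xla.debug.profiler as xp

# In the training script, before the training loop starts:
server = xp.start_server(9012)

# Then, from a separate process while training is running, capture a trace:
# xp.trace('localhost:9012', '/tmp/profile_logs', duration_ms=30000)
# and point TensorBoard at /tmp/profile_logs to view the op_profile page.
```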
Observation
Below is a screenshot of the TensorBoard profiling result (op_profile page).
- Note that the BATCH / FEATURE dimensions are padded, and the time wasted on this padding is 27% of the total.
Below is the PyTorch definition of the very first convolution:
```python
self.in_ch = 64
self.inc = nn.Sequential(
    nn.Conv3d(n_channels, self.in_ch, kernel_size=7, padding=3),  # 7x7x7 conv, padding 3
    nn.BatchNorm3d(64),
    nn.ReLU(inplace=True),
)
```
- The convolution takes a 12x1x96x96x96 (BCTHW) input, applies a 7x7x7 3D convolution with 64 output channels and padding of 3. A standalone sketch of this layer follows below.
- The overall TPU FLOPS utilization is 13%, and memory bandwidth utilization is 21% (reported at the top of the op_profile page).
- The maximum batch size I can fit is 18, but utilization is still very low.
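In case it helps reproduce the issue, here is a minimal standalone sketch of just this first layer with the input shape above; the surrounding model is omitted, and `n_channels = 1` is taken from the single-channel input:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

n_channels = 1  # single-channel volumes, matching the 12x1x96x96x96 input
in_ch = 64

inc = nn.Sequential(
    nn.Conv3d(n_channels, in_ch, kernel_size=7, padding=3),
    nn.BatchNorm3d(in_ch),
    nn.ReLU(inplace=True),
).to(device)

x = torch.randn(12, 1, 96, 96, 96, device=device)  # BCTHW input from above
out = inc(x)
xm.mark_step()  # flush the lazy graph so the convolution actually executes
print(out.shape)  # expected: torch.Size([12, 64, 96, 96, 96])
```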
Question
- Am I interpreting the result correctly? It seems there is a lot of room to optimize.
- Is there any best-practice for optimizing this low utilization issue?
- Is 3D convolution fully optimized in PyTorch-XLA? (I assume 2D convolution must already be well optimized.)
Please excuse my ignorance; I am just getting started with PyTorch-XLA / TPU.
Any help or suggestions would be appreciated. Thank you in advance!