🐛 Bug
Training time increases exponentially with increase in number of layers.
To Reproduce
The following script takes around 16 seconds per step at layers=1000, versus 0.38 seconds per step at layers=500.
import os
os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import time

import torch
import torch_xla.core.xla_model as xm


class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(100, 1)

    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x


def train():
    device = xm.xla_device()
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)
        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()
        step_times = []
        for i in range(10):
            start_time = time.time()
            x = torch.randn(10).to(device)  # avoid shadowing the builtin `input`
            y = torch.randn(1).to(device)
            output = model(x)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            xm.mark_step()
            xm.wait_device_ops()
            step_time = time.time() - start_time
            step_times.append(step_time)
        median_time = sorted(step_times)[len(step_times) // 2]
        print(f"Median step time: {median_time:.4f}s")
        print(step_times)


if __name__ == "__main__":
    train()
Expected behavior
Is it expected that the training step time increases exponentially (rather than linearly) with the number of layers?
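For reference, the two measurements above already let one estimate how far off linear scaling this is: fitting the two reported points (0.38 s at 500 layers, 16 s at 1000 layers) to a power law t ≈ c·n^k gives the effective scaling exponent k, where k ≈ 1 would mean linear. A minimal stdlib-only sketch (the function name is just for illustration):

```python
import math

def scaling_exponent(n1, t1, n2, t2):
    # Fit t = c * n**k through two (layers, step_time) points.
    # k ~ 1 means linear scaling; k > 1 means superlinear.
    return math.log(t2 / t1) / math.log(n2 / n1)

# Data points reported above: 0.38 s/step at 500 layers, 16 s/step at 1000 layers.
k = scaling_exponent(500, 0.38, 1000, 16.0)
print(f"effective scaling exponent k ~ {k:.1f}")  # well above 1, so clearly superlinear
```

With these two points k comes out around 5, i.e. doubling the depth multiplies the step time by roughly 2^5 rather than 2.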
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
- torch_xla version: 2.5.1
Additional context