
Train loop takes exponentially longer as the number of layers increases #8824

@rplsbo


🐛 Bug

Per-step training time grows exponentially, rather than linearly, as the number of layers increases.

To Reproduce

The following code takes around 16 seconds per step at layers=1000, whereas it takes 0.38 seconds per step at layers=500.
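Taking the two quoted timings at face value, with $t(L)$ the per-step time at `layers = L`, the ratio is far from the factor of 2 that proportional scaling would give. Two simple growth models fitted to these two points imply the following parameters:

```math
\begin{aligned}
\frac{t(1000)}{t(500)} &= \frac{16\ \mathrm{s}}{0.38\ \mathrm{s}} \approx 42.1
  &&\text{(proportional scaling: } 2 \times 0.38\ \mathrm{s} = 0.76\ \mathrm{s}) \\
t \propto L^{k} &\;\Rightarrow\; k = \log_2 42.1 \approx 5.4 \\
t \propto e^{cL} &\;\Rightarrow\; c = \frac{\ln 42.1}{500} \approx 7.5 \times 10^{-3}\ \text{per layer}
\end{aligned}
```

Two points fit either model exactly, so on their own they cannot show that the growth is exponential rather than a steep polynomial. Extrapolating to `layers = 1500` gives roughly 140 s per step for the power law versus roughly 670 s for the exponential (extrapolations, not measurements), which the script's larger configurations would tell apart.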

import os

os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import torch_xla.core.xla_model as xm
import time

class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        
        self.output_layer = torch.nn.Linear(100, 1)
    
    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

def train():
    device = xm.xla_device()
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()

        step_times = []

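        for _ in range(10):
            start_time = time.perf_counter()
            x = torch.randn(10).to(device)  # named x to avoid shadowing the builtin input()
            y = torch.randn(1).to(device)
            output = model(x)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            xm.mark_step()        # cut and execute the lazily traced graph
            xm.wait_device_ops()  # block until the device finishes so the timing is end-to-end
            step_times.append(time.perf_counter() - start_time)

        # Early steps include graph compilation, so report the median rather than the mean.
        median_time = sorted(step_times)[len(step_times) // 2]
        print(f"layers={layers}: median step time {median_time:.4f}s")
        print(step_times)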
    

if __name__ == "__main__":
    train()
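Because torch_xla traces lazily, each measured step mixes three costs: Python-side graph construction (forward, backward, `optimizer.step()`), the work done inside `xm.mark_step()` (cutting the graph and dispatching it; graph hashing and, on a cache miss, compilation are assumed to happen here), and device execution awaited by `xm.wait_device_ops()`. The sketch below times these phases separately. `time_step_phases` is a hypothetical helper, not a torch_xla API, and it takes plain callables so it runs without an XLA device.

```python
import time
from typing import Callable, Dict


def time_step_phases(
    build_graph: Callable[[], None],
    launch: Callable[[], None],
    wait: Callable[[], None],
) -> Dict[str, float]:
    """Time one training step split into three phases (hypothetical helper).

    build_graph: runs forward, loss, zero_grad, backward and optimizer.step() (lazy tracing).
    launch:      e.g. xm.mark_step, which cuts the traced graph and dispatches it.
    wait:        e.g. xm.wait_device_ops, which blocks until the device is idle.
    """
    t0 = time.perf_counter()
    build_graph()
    t1 = time.perf_counter()
    launch()
    t2 = time.perf_counter()
    wait()
    t3 = time.perf_counter()
    return {"trace": t1 - t0, "mark_step": t2 - t1, "wait": t3 - t2}
```

In the repro, `build_graph` would be a closure over `model`, `loss_fn` and `optimizer` that performs one step up to `optimizer.step()`, while `launch` and `wait` would be `xm.mark_step` and `xm.wait_device_ops`. If `trace` or `mark_step` dominates and grows super-linearly with `layers`, that would point at host-side overhead rather than the compiled program itself.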

Expected behavior

I would expect the training step time to grow roughly linearly with the number of layers. Is it expected that it instead grows exponentially?
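To answer the linear-versus-exponential question from the full sweep rather than two points, here is a hedged sketch. `classify_scaling` and `_fit_line` are hypothetical helpers (not part of torch_xla) that fit straight lines to log step time, once against log(layers) for a power law and once against layers for an exponential, and report the slopes and residuals.

```python
import math
from typing import Dict, Sequence, Tuple


def _fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least squares y = a + b*x; returns (slope b, residual sum of squares)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sxx
    intercept = mean_y - slope * mean_x
    rss = sum((y - (intercept + slope * x)) ** 2 for x, y in zip(xs, ys))
    return slope, rss


def classify_scaling(layers: Sequence[int], times: Sequence[float]) -> Dict[str, float]:
    """Fit log(time) against log(layers) (power law) and against layers (exponential)."""
    if len(layers) != len(times) or len(layers) < 2:
        raise ValueError("need at least two (layers, time) pairs of equal length")
    log_t = [math.log(t) for t in times]
    k, power_rss = _fit_line([math.log(n) for n in layers], log_t)
    c, exp_rss = _fit_line([float(n) for n in layers], log_t)
    # k close to 1 is consistent with proportional (linear) growth;
    # exp_rss well below power_rss favours the exponential model.
    return {"power_exponent": k, "power_rss": power_rss,
            "exp_rate_per_layer": c, "exp_rss": exp_rss}
```

Feeding in the medians the script prints for layers 500 through 3000 would indicate which model fits better. With only the two timings quoted in this issue, both residuals are zero and `power_exponent` comes out at about 5.4.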

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 2.5.1

Additional context
