Training fails (sometimes) when using several GPUs #58

Closed
CesarLeblanc opened this issue Sep 27, 2023 · 4 comments

@CesarLeblanc

Dear maintainers,

I've been using your package for a while now (especially the FTT model). I've never encountered any trouble, and it helped me boost performance on a tabular dataset.
Recently, I've been doing some ablation studies, i.e., I've removed some input features to check whether the performance of the model decreases (and if so, at what point). I've discovered that, when using several GPUs (2x RTX 2080 Ti), training fails for certain numbers of input features (but not always; it really depends on the number of input features).
I'm using torch.nn.DataParallel to implement data parallelism, and for reasons related to my framework I don't wish to use torch.nn.parallel.DistributedDataParallel.

Here's a minimal reproducible example to prove my point:

import scipy.sparse  # scipy is 1.10.1
import torch  # torch is 1.13.1
import rtdl  # rtdl is 0.0.13
import torchmetrics  # torchmetrics is 0.11.4
import multiprocessing
import numpy as np  # numpy is 1.24.3

class SparseDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = X  # Sparse input data
        self.y = y  # Target data

    def __getitem__(self, index):
        X = torch.from_numpy(self.X[index].toarray()[0]).float()  # Convert the sparse input data to a dense tensor
        y = self.y[index]  # Target data
        return X, y

    def __len__(self):
        return self.X.shape[0]  # Number of items in the dataset

class FTT(rtdl.FTTransformer):
    def __init__(self, n_num_features=None, cat_cardinalities=None, d_token=16, n_blocks=1, attention_n_heads=4, attention_dropout=0.3, attention_initialization='kaiming', attention_normalization='LayerNorm', ffn_d_hidden=16, ffn_dropout=0.1, ffn_activation='ReGLU', ffn_normalization='LayerNorm', residual_dropout=0.0, prenormalization=True, first_prenormalization=False, last_layer_query_idx=[-1], n_tokens=None, kv_compression_ratio=0.004, kv_compression_sharing='headwise', head_activation='ReLU', head_normalization='LayerNorm', d_out=None):
        feature_tokenizer = rtdl.FeatureTokenizer( 
            n_num_features=n_num_features,
            cat_cardinalities=cat_cardinalities,
            d_token=d_token
        )
        transformer = rtdl.Transformer(
            d_token=d_token,
            n_blocks=n_blocks,
            attention_n_heads=attention_n_heads,
            attention_dropout=attention_dropout,
            attention_initialization=attention_initialization,
            attention_normalization=attention_normalization,
            ffn_d_hidden=ffn_d_hidden,
            ffn_dropout=ffn_dropout,
            ffn_activation=ffn_activation,
            ffn_normalization=ffn_normalization,
            residual_dropout=residual_dropout,
            prenormalization=prenormalization,
            first_prenormalization=first_prenormalization,
            last_layer_query_idx=last_layer_query_idx,
            n_tokens=n_num_features+1,
            kv_compression_ratio=kv_compression_ratio,
            kv_compression_sharing=kv_compression_sharing,
            head_activation=head_activation,
            head_normalization=head_normalization,
            d_out=d_out
        )
        super(FTT, self).__init__(feature_tokenizer, transformer)

    def forward(self, x_num=None, x_cat=None):
        x = self.feature_tokenizer(x_num, x_cat)
        print(f"First forward step --> shape is {x.shape} & device is {x.device}")  # For debug 1/7
        x = self.cls_token(x)
        print(f"Second forward step --> shape is {x.shape} & device is {x.device}")  # For debug 2/7
        x = self.transformer(x)
        print(f"Third forward step --> shape is {x.shape} & device is {x.device}")  # For debug 3/7
        return x

num_train_samples = 803473  # Number of samples in the real training dataset
num_test_samples = 82787  # Number of samples in the real testing dataset
num_input_features = 152  # Number of input features in the real dataset
# num_input_features = 252  # With a different number of features (let's say 100 more), the code works (it depends on the number chosen here)
num_classes = 228  # Number of classes in the real dataset

X_train = scipy.sparse.random(num_train_samples, num_input_features, density=0.01, format='csr')
y_train = np.random.randint(0, num_classes, num_train_samples)
X_test = scipy.sparse.random(num_test_samples, num_input_features, density=0.01, format='csr')
y_test = np.random.randint(0, num_classes, num_test_samples)

train_dataset = SparseDataset(X_train, y_train)  # Create a train dataset
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=512, shuffle=True, num_workers=multiprocessing.cpu_count())  # Create a train DataLoader 
test_dataset = SparseDataset(X_test, y_test)  # Create a test dataset
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=512, shuffle=True, num_workers=multiprocessing.cpu_count())  # Create a test DataLoader 

model = FTT(n_num_features=num_input_features, d_out=num_classes)

model = torch.nn.DataParallel(model).cuda()  # Run the model parallelly and move it to GPU
criterion = torch.nn.CrossEntropyLoss().cuda()  # Instantiate loss class and move it to GPU
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)  # Instantiate optimizer class
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)  # Instantiate the scheduler (reduces the LR when accuracy plateaus)

print(model)  # For debug 4/7
for i in model.named_parameters():
    print(f"{i[0]} -> {i[1].device}")  # For debug 5/7

for epoch in range(100):
    for batch, labels in train_dataloader:  # Iterate through train dataset
        batch = batch.cuda()  # Move the batch to the GPU (gradients w.r.t. the inputs are not needed)
        print(f"Batch shape is {batch.shape} & device is {batch.device}")  # For debug 6/7
        labels = labels.cuda()  # Use GPU for tensors
        print(f"Labels shape is {labels.shape} & device is {labels.device}")  # For debug 7/7
        optimizer.zero_grad()  # Clear gradients w.r.t. parameters
        ########################################
        ### THE CODE FAILS ON THE LINE BELOW ###
        ########################################
        outputs = model(x_num=batch, x_cat=None)  # Forward pass to get output/logits
        ########################################
        ### THE CODE FAILS ON THE LINE ABOVE ###
        ########################################
        loss = criterion(outputs, labels)  # Calculate Loss: softmax --> cross entropy loss
        loss.backward()  # Getting gradients w.r.t. parameters
        optimizer.step()  # Updating parameters
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)  # Instantiate the accuracy metric
    metric.cuda()  # Move the metric to the device
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation
        for batch, labels in test_dataloader:  # Iterate through the batches in the test dataset
            batch = batch.cuda()  # Move the batch to the device
            labels = labels.cuda()  # Move the labels to the device
            outputs = model(x_num=batch, x_cat=None)  # Forward pass to get output/logits
            accuracy = metric(outputs, labels)  # Calculate the accuracy
    accuracy_epoch = metric.compute() * 100  # Compute the overall accuracy for the epoch
    model.train()

    scheduler.step(accuracy_epoch)  # Step on the epoch-level accuracy, not the last batch's value
    print(f'Epoch {epoch} completed.')
    print(f'Accuracy: {accuracy_epoch:.4f}%.')

When my training data (a CSR matrix) has a shape of (803473, 152) (i.e., 803473 samples with 152 features each), this code fails on multiple GPUs. However, if my training data has a shape of (803473, 252) (I just tried a random number), it runs smoothly.

Here are the logs:

/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 16, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
DataParallel(
  (module): FTT(
    (feature_tokenizer): FeatureTokenizer(
      (num_tokenizer): NumericalFeatureTokenizer()
    )
    (cls_token): CLSToken()
    (transformer): Transformer(
      (blocks): ModuleList(
        (0): ModuleDict(
          (attention): MultiheadAttention(
            (W_q): Linear(in_features=16, out_features=16, bias=True)
            (W_k): Linear(in_features=16, out_features=16, bias=True)
            (W_v): Linear(in_features=16, out_features=16, bias=True)
            (W_out): Linear(in_features=16, out_features=16, bias=True)
            (dropout): Dropout(p=0.3, inplace=False)
          )
          (ffn): FFN(
            (linear_first): Linear(in_features=16, out_features=32, bias=True)
            (activation): ReGLU()
            (dropout): Dropout(p=0.1, inplace=False)
            (linear_second): Linear(in_features=16, out_features=16, bias=True)
          )
          (attention_residual_dropout): Dropout(p=0.0, inplace=False)
          (ffn_residual_dropout): Dropout(p=0.0, inplace=False)
          (output): Identity()
          (ffn_normalization): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (key_compression): Linear(in_features=153, out_features=0, bias=False)
          (value_compression): Linear(in_features=153, out_features=0, bias=False)
        )
      )
      (head): Head(
        (normalization): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation): ReLU()
        (linear): Linear(in_features=16, out_features=228, bias=True)
      )
    )
  )
)
module.feature_tokenizer.num_tokenizer.weight -> cuda:0
module.feature_tokenizer.num_tokenizer.bias -> cuda:0
module.cls_token.weight -> cuda:0
module.transformer.blocks.0.attention.W_q.weight -> cuda:0
module.transformer.blocks.0.attention.W_q.bias -> cuda:0
module.transformer.blocks.0.attention.W_k.weight -> cuda:0
module.transformer.blocks.0.attention.W_k.bias -> cuda:0
module.transformer.blocks.0.attention.W_v.weight -> cuda:0
module.transformer.blocks.0.attention.W_v.bias -> cuda:0
module.transformer.blocks.0.attention.W_out.weight -> cuda:0
module.transformer.blocks.0.attention.W_out.bias -> cuda:0
module.transformer.blocks.0.ffn.linear_first.weight -> cuda:0
module.transformer.blocks.0.ffn.linear_first.bias -> cuda:0
module.transformer.blocks.0.ffn.linear_second.weight -> cuda:0
module.transformer.blocks.0.ffn.linear_second.bias -> cuda:0
module.transformer.blocks.0.ffn_normalization.weight -> cuda:0
module.transformer.blocks.0.ffn_normalization.bias -> cuda:0
module.transformer.blocks.0.key_compression.weight -> cuda:0
module.transformer.blocks.0.value_compression.weight -> cuda:0
module.transformer.head.normalization.weight -> cuda:0
module.transformer.head.normalization.bias -> cuda:0
module.transformer.head.linear.weight -> cuda:0
module.transformer.head.linear.bias -> cuda:0
Batch shape is torch.Size([512, 152]) & device is cuda:0
Labels shape is torch.Size([512]) & device is cuda:0
First forward step --> shape is torch.Size([256, 152, 16]) & device is cuda:1
First forward step --> shape is torch.Size([256, 152, 16]) & device is cuda:0
Second forward step --> shape is torch.Size([256, 153, 16]) & device is cuda:1
Second forward step --> shape is torch.Size([256, 153, 16]) & device is cuda:0
Third forward step --> shape is torch.Size([256, 228]) & device is cuda:0
Traceback (most recent call last):
  File "/home/my_user_name/my_framework_name/minimal_reproducible_example.py", line 98, in <module>
    outputs = model(x_num=batch, x_cat=None)  # Forward pass to get output/logits
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/my_framework_name/minimal_reproducible_example.py", line 57, in forward
    x = self.transformer(x)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/rtdl/modules.py", line 1150, in forward
    x_residual, _ = layer['attention'](
                    ^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/rtdl/modules.py", line 893, in forward
    k = key_compression(k.transpose(1, 2)).transpose(1, 2)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my_user_name/.conda/envs/my_environment_name/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: size mismatch, got 4096, 4096x153,0

A few more details:

  • This error only happens when I'm training on several (two) GPUs. It works fine with only one.
  • This error only happens with certain numbers of input features (here, 152). My full dataset has 10 633 input features, and I never had a problem training this model on it. During my ablation studies, I also tried training the model with 10 481 input features (it worked), 10 483 (it worked), 10 631 (it worked), 2 (it failed) and 150 (it failed). These particular numbers come from the structure of my data: of the 10 633 input features, the first 10 481 form a first group, the next 2 form a second group and the last 150 form a third group, so I'm trying every possible combination.

Thanks for your help!

@Yura52
Collaborator

Yura52 commented Sep 27, 2023

@CesarLeblanc

Hi! Thank you for such a detailed report.

What I see from the error traceback:

  1. the error happens when key_compression is called
  2. key_compression is a linear layer
  3. the problem is that the output size of key_compression is zero

This happened because kv_compression_ratio=0.004 is too low for this number of features. In a nutshell, key_compression is created as follows:

key_compression = nn.Linear(
    n_tokens,  # 152 features + 1 [CLS] token = 153
    int(kv_compression_ratio * n_tokens),  # int(0.004 * 153) = 0
    bias=False,
)
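A quick arithmetic check (added here, not part of the original reply), assuming only numerical features so that n_tokens = n_num_features + 1, as in the FTT class from the report. With kv_compression_ratio=0.004 the compressed size is zero whenever n_tokens < 250, i.e. for fewer than 249 features, which lines up with every feature count the reporter saw fail:

```python
def compressed_kv_tokens(n_num_features: int, kv_compression_ratio: float) -> int:
    """Output size of key_compression, assuming n_tokens = n_num_features + 1 (features + [CLS])."""
    n_tokens = n_num_features + 1
    return int(kv_compression_ratio * n_tokens)


# Feature counts and outcomes as reported in this issue (2 GPUs, kv_compression_ratio=0.004).
reported = {
    2: "failed",
    150: "failed",
    152: "failed",
    252: "worked",
    10481: "worked",
    10483: "worked",
    10631: "worked",
    10633: "worked",
}

for n_features, outcome in reported.items():
    size = compressed_kv_tokens(n_features, 0.004)
    print(f"{n_features:>6} features -> {size:>2} compressed tokens (reported: {outcome})")

# Every reported failure corresponds to a zero-width compression layer, and vice versa.
assert all((compressed_kv_tokens(n, 0.004) == 0) == (o == "failed") for n, o in reported.items())
```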

The purpose of this is to reduce the number of features to make the attention faster.
However, there is no point in reducing the number of features below a certain threshold. This threshold is purely heuristic and should be chosen based on your budget and the downstream performance. In particular, it depends on the number of features.

The best scenario is when you don't need compression (i.e. kv_compression_ratio=None). If this does not fit into your budget, then choose kv_compression_ratio from the range [a, b] based on your intuition and preference, where a is the smallest value that still provides good performance in terms of metrics, and b is the largest value that still fits into your budget.
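One hard constraint on a is that the compressed size must be at least one token. A minimal sketch of that floor, under the same n_tokens = n_num_features + 1 assumption; min_kv_compression_ratio is a hypothetical helper (not part of rtdl), and min_compressed_tokens is a knob you would pick yourself:

```python
import math


def min_kv_compression_ratio(n_tokens: int, min_compressed_tokens: int = 1) -> float:
    """Hypothetical helper: a ratio r of about min_compressed_tokens / n_tokens
    such that int(r * n_tokens) >= min_compressed_tokens."""
    if not 1 <= min_compressed_tokens <= n_tokens:
        raise ValueError("min_compressed_tokens must be between 1 and n_tokens")
    ratio = min_compressed_tokens / n_tokens
    # Floating-point rounding can leave ratio * n_tokens just below the target,
    # which int() would truncate; nudge the ratio upward until it is not.
    while int(ratio * n_tokens) < min_compressed_tokens:
        ratio = math.nextafter(ratio, math.inf)
    return ratio


n_tokens = 152 + 1  # 152 numerical features + [CLS], as in the failing run
a_floor = min_kv_compression_ratio(n_tokens)
print(f"lower bound for a: {a_floor:.5f} (0.004 gives {int(0.004 * n_tokens)} compressed tokens)")
```

In practice a would sit well above this floor, since a single compressed token is unlikely to preserve much information.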

Does this help?

@CesarLeblanc
Author

Dear @Yura52,

Thanks a lot for your answer, it does help!
I had to use such a low value for the kv_compression_ratio parameter when training on the full dataset (because of my available hardware and time budget), but it's true that it doesn't make much sense to keep the same value when the first group of features (almost 99% of all the input features) is removed during an ablation study.
So for the moment I simply changed my initialization of the model by modifying three lines:

class FTT(rtdl.FTTransformer):
    def __init__(self, n_num_features=None, cat_cardinalities=None, d_token=16, n_blocks=1, attention_n_heads=4, attention_dropout=0.3, attention_initialization='kaiming', attention_normalization='LayerNorm', ffn_d_hidden=16, ffn_dropout=0.1, ffn_activation='ReGLU', ffn_normalization='LayerNorm', residual_dropout=0.0, prenormalization=True, first_prenormalization=False, last_layer_query_idx=[-1], n_tokens=None, kv_compression_ratio=0.004, kv_compression_sharing='headwise', head_activation='ReLU', head_normalization='LayerNorm', d_out=None):
        feature_tokenizer = rtdl.FeatureTokenizer( 
            n_num_features=n_num_features,
            cat_cardinalities=cat_cardinalities,
            d_token=d_token
        )
        transformer = rtdl.Transformer(
            d_token=d_token,
            n_blocks=n_blocks,
            attention_n_heads=attention_n_heads,
            attention_dropout=attention_dropout,
            attention_initialization=attention_initialization,
            attention_normalization=attention_normalization,
            ffn_d_hidden=ffn_d_hidden,
            ffn_dropout=ffn_dropout,
            ffn_activation=ffn_activation,
            ffn_normalization=ffn_normalization,
            residual_dropout=residual_dropout,
            prenormalization=prenormalization,
            first_prenormalization=first_prenormalization,
            last_layer_query_idx=last_layer_query_idx,
            n_tokens=None if int(kv_compression_ratio * (n_num_features + 1)) == 0 else n_num_features + 1,  # Modified line
            kv_compression_ratio=None if int(kv_compression_ratio * (n_num_features + 1)) == 0 else kv_compression_ratio,  # Modified line
            kv_compression_sharing=None if int(kv_compression_ratio * (n_num_features + 1)) == 0 else kv_compression_sharing,  # Modified line
            head_activation=head_activation,
            head_normalization=head_normalization,
            d_out=d_out
        )
        super(FTT, self).__init__(feature_tokenizer, transformer)

It's clearly not optimal and could be improved (e.g., by automatically setting kv_compression_ratio based on the number of input features instead of always using 0.004), but for now it is sufficient, as the code runs.
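One way to keep this workaround while evaluating the condition only once, and respecting the kv_compression_sharing argument, is a small helper that returns the three settings together. This is a sketch, not rtdl API: kv_compression_kwargs is a hypothetical name, and disabling all three settings together simply mirrors the modified lines above.

```python
from typing import Optional


def kv_compression_kwargs(
    n_num_features: int,
    kv_compression_ratio: Optional[float],
    kv_compression_sharing: Optional[str] = "headwise",
) -> dict:
    """Hypothetical helper: KV-compression settings for rtdl.Transformer, with all three
    set to None when compression is not requested or would shrink the tokens to zero."""
    n_tokens = n_num_features + 1  # numerical features + [CLS] token, as in the FTT class
    if kv_compression_ratio is None or int(kv_compression_ratio * n_tokens) == 0:
        return {"n_tokens": None, "kv_compression_ratio": None, "kv_compression_sharing": None}
    return {
        "n_tokens": n_tokens,
        "kv_compression_ratio": kv_compression_ratio,
        "kv_compression_sharing": kv_compression_sharing,
    }


print(kv_compression_kwargs(152, 0.004))    # compression disabled (0 compressed tokens)
print(kv_compression_kwargs(10633, 0.004))  # compression kept (42 compressed tokens)
```

Inside FTT.__init__, the result could be unpacked into the rtdl.Transformer(...) call with **kv_compression_kwargs(n_num_features, kv_compression_ratio, kv_compression_sharing) in place of the three explicit keyword arguments.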

However, I'm having trouble understanding why the code works on a single GPU (1x RTX 2080 Ti) but not on two (2x RTX 2080 Ti). Could you explain this? I understand your answer, but I don't get why the code still runs locally (i.e., on my own computer, which has one GPU) with 152 input features and a kv_compression_ratio of 0.004.

@Yura52
Collaborator

Yura52 commented Sep 28, 2023

I should admit I don't have a good explanation for that :)
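A partial, hedged observation on the single-GPU question (not something confirmed in the thread): on one device, PyTorch accepts a zero-output nn.Linear, so nothing forces an error there. The CPU sketch below uses a simplified single-head attention (no head split, no scaling, no output projection, and the same compression layer for keys and values), not rtdl's actual MultiheadAttention; the sizes mirror the failing run and the batch size of 8 is arbitrary. Building the layer should trigger the same "Initializing zero-element tensors is a no-op" warning that appears in the logs. The compressed keys and values end up with zero tokens, and the attention output comes out as all zeros instead of raising. If rtdl behaves the same way, the single-GPU run was likely training without useful signal from that attention layer. Why the DataParallel replica on device 1 raises instead is not explained by this sketch.

```python
import torch
import torch.nn as nn

# Sizes mirroring the failing run: 152 features + 1 [CLS] token, d_token=16.
n_tokens = 152 + 1
d_token = 16
batch_size = 8  # arbitrary small batch for illustration
kv_compression_ratio = 0.004

# int(0.004 * 153) == 0, so this layer maps 153 tokens to 0 tokens.
# PyTorch warns that initializing a zero-element tensor is a no-op, but builds it anyway.
key_compression = nn.Linear(n_tokens, int(kv_compression_ratio * n_tokens), bias=False)

torch.manual_seed(0)
q = torch.randn(batch_size, n_tokens, d_token)
k = torch.randn(batch_size, n_tokens, d_token)
v = torch.randn(batch_size, n_tokens, d_token)

# Compress along the token axis, the same pattern as the line in the traceback.
k_c = key_compression(k.transpose(1, 2)).transpose(1, 2)  # shape (8, 0, 16)
v_c = key_compression(v.transpose(1, 2)).transpose(1, 2)  # shape (8, 0, 16)

# Simplified single-head attention over zero key/value tokens.
logits = q @ k_c.transpose(1, 2)       # shape (8, 153, 0)
probs = torch.softmax(logits, dim=-1)  # still empty: nothing to normalize over
attended = probs @ v_c                 # shape (8, 153, 16): empty contraction gives zeros

print(k_c.shape, attended.shape, bool(attended.any()))
```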

@Yura52
Collaborator

Yura52 commented Oct 2, 2023

Feel free to reopen the issue if needed!

Yura52 closed this as completed on Oct 2, 2023