SDPA does not switch off dropout during evaluation #124464

Closed
Giventicket opened this issue Apr 19, 2024 · 11 comments
Labels: actionable, module: docs, oncall: transformer/mha, triaged


@Giventicket

Giventicket commented Apr 19, 2024

🐛 Describe the bug

import math

import torch


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Create the bias on the same device as the inputs so CUDA tensors work.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            # Out-of-place add so a batched float mask can broadcast.
            attn_bias = attn_bias + attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Bug Here!! train=True means dropout is applied even in eval mode.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

I found the detailed pseudocode at https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html while using flash attention.

However, the metric measured from the flash attention model changes (especially when I use dropout_p = 0.1), even though it is supposed to stay the same: as far as I know, flash attention does not affect the output. I then traced the expected bug to the code above and marked it with the annotation "Bug Here!!".
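As a sanity check on the expectation that the fused kernel should not change the output: with dropout_p=0.0, a dropout-free transcription of the documented reference math should agree with F.scaled_dot_product_attention up to floating-point rounding, so dropout is the only source of run-to-run differences. This is a sketch; reference_sdpa is a hypothetical helper name, and the shapes and tolerance are assumptions.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value, attn_mask=None, is_causal=False, scale=None):
    """Hypothetical helper: the documented reference math with dropout removed."""
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_bias.masked_fill_(~causal, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(~attn_mask, float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_bias
    return torch.softmax(attn_weight, dim=-1) @ value


torch.manual_seed(0)
# (batch, heads, seq_len, head_dim): illustrative sizes only.
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))

ref = reference_sdpa(q, k, v, is_causal=True)
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(torch.allclose(ref, fused, atol=1e-5))
```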

I ran the following two code blocks:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.backends.cudnn.deterministic = True

# Using functional dropout
class Model1(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, inputs):
        return F.dropout(inputs, p=self.p, training=True)

# Example usage
model1 = Model1(p=0.5)

inputs = torch.rand(10)
print("Random input", inputs)
print()

# Applying dropout in training mode
print('Training mode:')
print('Model 1', model1(inputs))

# Switching to evaluation mode
model1.eval()

# Applying dropout in evaluation mode
print('Evaluation mode:')
print('Model 1', model1(inputs))

and

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)
torch.backends.cudnn.deterministic = True

# Using nn.Dropout module
class Model2(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.drop_layer = nn.Dropout(p=p)

    def forward(self, inputs):
        return self.drop_layer(inputs)
# Example usage
model2 = Model2(p=0.5)

inputs = torch.rand(10)
print("Random input", inputs)
print()

# Applying dropout in training mode
print('Training mode:')
print('Model 2', model2(inputs))

# Switching to evaluation mode
model2.eval()

# Applying dropout in evaluation mode
print('Evaluation mode:')
print('Model 2', model2(inputs))

The results are as follows:

釀夅叧釀忈叧釀呩叺釂剦釁a喓 2024-04-19 釀嬦叐釀掅叜 5 04 13 釀夅叧釀忈叧釀呩叺釂剦釁a喓 2024-04-19 釀嬦叐釀掅叜 5 04 23

As you can see from Model 1, which applies dropout the same way as scaled_dot_product_attention, dropout is still applied even after I switch the model to evaluation mode. Please fix this as soon as possible. Thank you.

Versions

Collecting environment information...
PyTorch version: 2.3.0.dev20240121+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-159-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 535.154.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 80
On-line CPU(s) list: 0-79
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 2
Stepping: 7
CPU max MHz: 4000.0000
CPU min MHz: 800.0000
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 1.3 MiB (40 instances)
L1i cache: 1.3 MiB (40 instances)
L2 cache: 40 MiB (40 instances)
L3 cache: 55 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55,57,59,61,63,65,67,69,71,73,75,77,79
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] pytorch-lightning==2.1.2
[pip3] pytorch-triton==2.2.0+e28a256d71
[pip3] torch==2.3.0.dev20240121+cu121
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.2.0.dev20240121+cu121
[pip3] torchmetrics==1.3.0.post0
[pip3] torchvision==0.18.0.dev20240121+cu121
[conda] Could not collect

cc @svekars @brycebortree @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

@jbschlosser
Contributor

jbschlosser commented Apr 19, 2024

Hey @Giventicket, in your first code block, you're setting the model to evaluation mode, but still passing F.dropout(..., training=True). Does this explain what you're seeing?
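To make the distinction concrete: the functional F.dropout only looks at its training argument, so calling eval() on the surrounding module has no effect unless the module forwards self.training itself. A sketch contrasting the hard-coded pattern from Model1 with a mode-aware version (both class names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HardCodedDropout(nn.Module):
    """Mirrors Model1 from the report: training=True is hard-coded."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        return F.dropout(x, p=self.p, training=True)


class ModeAwareDropout(nn.Module):
    """Forwards the module's own train/eval flag to the functional call."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        return F.dropout(x, p=self.p, training=self.training)


x = torch.rand(1000)
hard, aware = HardCodedDropout().eval(), ModeAwareDropout().eval()
print(torch.equal(hard(x), x))   # False: dropout still applied after eval()
print(torch.equal(aware(x), x))  # True: identity in eval mode
```

This is exactly what nn.Dropout does internally, which is why Model2 behaves as expected.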

Is the purpose of your issue that this is also happening in the doc example? We'd accept a PR updating the doc to fix this.

@jbschlosser jbschlosser added the triaged, oncall: transformer/mha, and module: docs labels Apr 19, 2024
@Giventicket
Author

Giventicket commented Apr 20, 2024

Yes, still passing F.dropout(..., training=True) is the problem. Thanks for the comment and the PR.
But this is not just about the docs: the main problem is that the module malfunctions when I turn on the dropout option (i.e., pass any dropout_p greater than zero).

@Giventicket
Author

@jbschlosser, could you please take a look?

@jbschlosser
Contributor

the main problem is that the module malfunctions when I turn on the dropout option (i.e., pass any dropout_p greater than zero).

@Giventicket Not sure I'm clear on exactly what the malfunction is, but I'll hazard a guess that you're trying to use SDPA from your module and it isn't respecting the module's evaluation status. Can you work around this by manually providing dropout_p=0.0 when not in training mode?

class Model1(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, ...):
        return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))

If I'm wrong about the problem, is it possible to get more detail on the behavior you're looking for?
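A runnable version of the workaround above, assuming the module receives already-projected query, key, and value tensors (the original elides the arguments with ...; the explicit forward signature and the tensor shapes here are illustrative assumptions). In eval mode, two forward passes give identical outputs because dropout_p becomes 0.0:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model1(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, query, key, value):
        # SDPA is a plain function, so forward the module's mode by hand:
        # use self.p while training and 0.0 otherwise.
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=(self.p if self.training else 0.0)
        )


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
model = Model1(p=0.5)

model.eval()
with torch.no_grad():
    out1, out2 = model(q, k, v), model(q, k, v)
print(torch.equal(out1, out2))  # True: no dropout in eval mode
```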

@Giventicket
Author

@jbschlosser whoops, that makes sense, thanks a lot!

@anmorgunov

but I'll hazard a guess that you're trying to use SDPA from your module and it isn't respecting the module's evaluation status. Can you work around this by manually providing dropout_p=0.0 when not in training mode?

I stumbled upon this behavior as well.

Yes, it appears that SDPA doesn't respect .eval() and uses a non-zero dropout_p during inference. I think what @Giventicket was trying to say is that the pseudocode for SDPA contains this line:

attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

which suggests that train is permanently set to True, regardless of whether the model is in train or eval mode. If that's indeed how SDPA works in C, it is, at minimum, peculiar.
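A small sketch of the footgun being described: a module that hands a constant dropout_p to SDPA keeps applying dropout after eval(), because eval() only flips the module's training attribute and SDPA never sees it. FixedDropoutAttention is a hypothetical name, and the shapes are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixedDropoutAttention(nn.Module):
    """Passes a constant dropout_p and ignores self.training."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, dropout_p=self.p)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
model = FixedDropoutAttention().eval()  # eval() only sets model.training = False

with torch.no_grad():
    a, b = model(q, k, v), model(q, k, v)
print(model.training, torch.equal(a, b))  # False False: dropout is still applied
```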

Doing this, dropout_p=(self.p if self.training else 0.0), will obviously solve the problem, but I wonder why it is necessary. I can imagine PyTorch being able to detect whether the model is in eval or train mode and adjust dropout_p automatically. Is it an intentional decision to delegate this responsibility to the user? If so, I'd be curious to hear the rationale, if you're not too busy to explain.

@jbschlosser
Contributor

jbschlosser commented May 8, 2024

I can imagine PyTorch being able to detect whether the model is in eval or train mode and adjust dropout_p automatically. Is it an intentional decision to delegate this responsibility to the user? If so, I'd be curious to hear the rationale, if you're not too busy to explain.

@anmorgunov Note that SDPA (AKA F.scaled_dot_product_attention()) as mentioned throughout here is a function and not a module, so it has no concept of a model being in train or eval mode. It does exactly what it says it will with respect to its dropout_p argument. Per separation of concerns, it's up to the surrounding module that calls the SDPA function to handle train / eval status, as that's where that concept is introduced.

We could have a ScaledDotProductAttention module in torch.nn that wraps the functional SDPA and handles dropout details between train / eval. I think using this module within your model would more closely match what you're expecting, as it would respect the parent module's train() / eval() status and automagically shut off the dropout for eval mode.
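The module is only proposed here; the class below is a hypothetical sketch of such a wrapper, not a torch.nn API, and the example Block is likewise illustrative. It relies on the fact that nn.Module.eval() recurses into child modules, so the parent's mode reaches the wrapper without any extra plumbing:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Hypothetical wrapper module (not part of torch.nn)."""

    def __init__(self, dropout_p=0.0):
        super().__init__()
        self.dropout_p = dropout_p

    def forward(self, query, key, value, attn_mask=None, is_causal=False):
        # Dropout only while training; eval() on any ancestor turns it off.
        p = self.dropout_p if self.training else 0.0
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=p, is_causal=is_causal
        )


class Block(nn.Module):
    """Illustrative parent model using the wrapper as a submodule."""

    def __init__(self, dim, dropout_p=0.1):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn = ScaledDotProductAttention(dropout_p)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return self.attn(q, k, v, is_causal=True)


torch.manual_seed(0)
x = torch.randn(2, 8, 16)
block = Block(dim=16).eval()  # eval() recurses into block.attn
with torch.no_grad():
    print(block.attn.training, torch.equal(block(x), block(x)))  # False True
```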

An alternative would be to bleed the module-level training concept down into the function:

... = F.scaled_dot_product_attention(..., dropout_p=0.4, train=False)  # if train is False, ignore passed-in dropout_p and use 0.0
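The train argument in the line above is only a proposal; F.scaled_dot_product_attention does not accept it. The proposed semantics can be emulated today with a thin wrapper function. sdpa_with_train_flag is a hypothetical name used only for illustration:

```python
import torch
import torch.nn.functional as F


def sdpa_with_train_flag(query, key, value, dropout_p=0.0, train=True, **kwargs):
    """Hypothetical wrapper: if train is False, ignore dropout_p and use 0.0."""
    effective_p = dropout_p if train else 0.0
    return F.scaled_dot_product_attention(
        query, key, value, dropout_p=effective_p, **kwargs
    )


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
out = sdpa_with_train_flag(q, k, v, dropout_p=0.4, train=False)
print(torch.allclose(out, F.scaled_dot_product_attention(q, k, v)))
```

A module would then call it with train=self.training, which is the same information the workaround passes through dropout_p.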

@anmorgunov

as mentioned throughout here is a function and not a module, so it has no concept of a model being in train or eval mode

Once you spell it out, it makes perfect sense for it to behave like this and in no other way.

@jbschlosser jbschlosser changed the title from Find a bug from beta-released "scaled_dot_product_attention" to SDPA does not switch off dropout during training May 9, 2024
@soumith
Member

soumith commented May 13, 2024

If we detect eval mode and dropout_p ends up being non-zero, then maybe we print a warning (or maybe even raise an error?). That seems like the first immediate step.

Yes, SDPA is a function, not a module, but having userland avoid footguns is our imperative and responsibility, so we gotta do something about it.

@jbschlosser
Contributor

jbschlosser commented May 13, 2024

If we detect eval mode and dropout_p ends up being non-zero, then maybe we print a warning (or maybe even raise an error?). That seems like the first immediate step.

I don't think it's possible to detect the eval mode of the module that made the function call from within SDPA without a signature change. We can detect inference_mode() though, and warn / error in that case.
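A minimal sketch of the inference_mode() check, written as a hypothetical user-side wrapper (checked_sdpa and its warning text are illustrative, not a PyTorch API). It relies on torch.is_inference_mode_enabled(). Note that model.eval() alone does not enable inference mode, so a guard like this would not catch the eval()-only case discussed above:

```python
import warnings

import torch
import torch.nn.functional as F


def checked_sdpa(query, key, value, dropout_p=0.0, **kwargs):
    """Hypothetical guard: warn if dropout is requested under inference_mode()."""
    if dropout_p > 0.0 and torch.is_inference_mode_enabled():
        warnings.warn(
            f"dropout_p={dropout_p} passed to scaled_dot_product_attention "
            "inside torch.inference_mode(); did you mean dropout_p=0.0?"
        )
    return F.scaled_dot_product_attention(
        query, key, value, dropout_p=dropout_p, **kwargs
    )


q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.inference_mode():
        checked_sdpa(q, k, v, dropout_p=0.1)
print([str(w.message) for w in caught])
```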

@jbschlosser
Contributor

jbschlosser commented May 13, 2024

Reopening to resolve this. At a minimum, we should add a doc warning about this behavior.

@jbschlosser jbschlosser reopened this May 13, 2024
@jbschlosser jbschlosser changed the title from SDPA does not switch off dropout during training to SDPA does not switch off dropout during evaluation May 13, 2024
jbschlosser added a commit that referenced this issue May 15, 2024
Fixes #124464

TBD: validate formatting

[ghstack-poisoned]
5 participants