Update on "fix nn.MHA + quantized scriptability"
Fixes a post-1.8 regression in nn.MultiheadAttention + quantization scriptability that was introduced in #52537. This passes the new test introduced in that PR and fixes the repro found by @ngimel [here](https://gist.github.com/bhosmer/ef517d0774f2f10336b8140116fd6b62).
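
For context, here's a minimal sketch of the kind of repro involved, assuming the usual dynamic-quantize-then-script pattern (the wrapper module and shapes below are illustrative; the authoritative repro is in the gist linked above):

```python
import torch
import torch.nn as nn

# Illustrative wrapper around nn.MultiheadAttention; shapes are arbitrary.
class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)

    def forward(self, q, k, v):
        out, _ = self.mha(q, k, v)
        return out

m = Attn().eval()
# Dynamic quantization swaps nn.Linear submodules for quantized ones.
qm = torch.quantization.quantize_dynamic(m, {nn.Linear}, dtype=torch.qint8)
# Post-1.8 (after #52537) this scripting step failed; the fix restores it.
torch.jit.script(qm)
```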

Per the comments in #52537, quantization definitely has a by-name dependency on the `_LinearWithBias` class, which I'm reinstating here, but there may be cleaner ways to solve this - I don't really know what I'm doing 😁.

@jbschlosser @z-a-f LMK if you have ideas, happy to change this as desired. It'd be nice to get a fix into 1.9.

_[Update: now using a better name instead of `_LinearWithBias`, but this remains a short-term fix to re-suppress a quantization API usage error that should properly be raised upstream. See issue #58969]_
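
For intuition about the by-name dependency: dynamic quantization decides what to swap by looking up each submodule's exact class in a mapping, roughly as in the sketch below (the class name and helper here are simplified stand-ins, not the actual quantization code). Giving `out_proj` its own `nn.Linear` subclass lets that lookup treat it differently from a plain `nn.Linear`, e.g. by skipping it or raising a clear usage error.

```python
import torch.nn as nn

# Stand-in for the dedicated out_proj class this fix reinstates (the real name differs).
class _OutProjLinear(nn.Linear):
    pass

# Simplified illustration of class-keyed module swapping, not torch.quantization itself.
def swap_modules(model: nn.Module, mapping: dict) -> None:
    for name, child in model.named_children():
        swap_modules(child, mapping)
        factory = mapping.get(type(child))  # exact-class lookup
        if factory is not None:
            setattr(model, name, factory(child))

# Because the lookup keys on the exact class, a plain nn.Linear gets swapped for a
# quantized replacement, while _OutProjLinear is left alone (or mapped to something
# that raises the usage error mentioned above).
```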

Differential Revision: [D28593830](https://our.internmc.facebook.com/intern/diff/D28593830)

[ghstack-poisoned]
bhosmer committed May 26, 2021
1 parent 1940055 commit c4e8358
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch/nn/quantizable/modules/activation.py
@@ -70,7 +70,7 @@ def __init__(self, embed_dim: int, num_heads: int,
         self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
         self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
         # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]

         # Functionals
         self.q_scaling_product = nnq.FloatFunctional()
@@ -100,8 +100,8 @@ def from_float(cls, other):

         # Set the linear weights
         # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
-        observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
-        observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
+        observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
+        observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
         if other._qkv_same_embed_dim:
             # Use separate params
             bias = other.in_proj_bias
@@ -451,7 +451,7 @@ def _forward_impl(self,
         # Reentering the quantized zone
         attn_output = self.quant_attn_output(attn_output)
         # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
-        attn_output = self.out_proj(attn_output) # type: ignore[has-type]
+        attn_output = self.out_proj(attn_output) # type: ignore[has-type]
         attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

         if need_weights:
