Merge pull request facebookresearch#202 from fairinternal/pytorch_encoder

[CI] Unit test vs. Pytorch Encoder and Decoder 1/2
dianaml0 committed Jul 27, 2021
2 parents 281e734 + efb737f commit ce63eb3
Showing 5 changed files with 329 additions and 41 deletions.
269 changes: 269 additions & 0 deletions tests/test_pytorch_transformer_parity.py
@@ -0,0 +1,269 @@
import random
import time

import pytest
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig

BATCH = 20
SEQ = 64
EMB = 48
VOCAB = 16
HEADS = 4
DROP = 0.1
LAYERS = 2
ACTIVATION = "relu"

_devices = (
    [torch.device("cpu")]
    if not torch.cuda.is_available()
    else [torch.device("cuda")]  # save a bit on CI, we have separate cpu and gpu jobs
)

_test_config_encoder = {
    "block_config": {
        "block_type": "encoder",
        "dim_model": EMB,
        "num_layers": LAYERS,
        "layer_norm_style": "post",
        "multi_head_config": {
            "num_heads": HEADS,
            "residual_dropout": DROP,
            "use_separate_proj_weight": False,
            "bias": True,
            "attention": {
                "name": "scaled_dot_product",
                "dropout": DROP,
                "causal": False,
                "seq_len": SEQ,
            },
            "dim_model": EMB,
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": DROP,
            "activation": ACTIVATION,
            "hidden_layer_multiplier": 4,
            "dim_model": EMB,
        },
    },
}


_test_config_decoder = {
    "block_config": {
        "block_type": "decoder",
        "dim_model": EMB,
        "num_layers": LAYERS,
        "layer_norm_style": "post",
        "multi_head_config_masked": {
            "num_heads": HEADS,
            "residual_dropout": DROP,
            "dim_model": EMB,
            "use_separate_proj_weight": False,
            "bias": True,
            "attention": {
                "name": "scaled_dot_product",
                "dropout": DROP,
                "causal": False,
                "seq_len": SEQ,
            },
        },
        "multi_head_config_cross": {
            "num_heads": HEADS,
            "residual_dropout": DROP,
            "dim_model": EMB,
            "use_separate_proj_weight": False,
            "bias": True,
            "attention": {
                "name": "scaled_dot_product",
                "dropout": DROP,
                "causal": False,
                "seq_len": SEQ,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": DROP,
            "activation": ACTIVATION,
            "hidden_layer_multiplier": 4,
            "dim_model": EMB,
        },
    },
}

_test_config = [_test_config_encoder, _test_config_decoder]
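# The two configs above, stacked in this list, describe a full encoder-decoder
# transformer, mirroring the torch.nn.Transformer reference model used below.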


def _data(device):
    # The dummy task is basically to classify sequences, either pure zeroes or some noise
    input_a = torch.zeros((BATCH, SEQ, EMB), device=device)
    input_b = (torch.rand((BATCH, SEQ, EMB), device=device) * VOCAB).abs()

    target_a = torch.zeros((BATCH, SEQ), device=device)
    target_b = torch.ones((BATCH, SEQ), device=device)

    if random.random() > 0.5:
        return torch.cat([input_a, input_b], dim=0), torch.cat(
            [target_a, target_b], dim=0
        )

    return torch.cat([input_b, input_a], dim=0), torch.cat([target_b, target_a], dim=0)


def reset_seeds():
    torch.manual_seed(0)
    random.seed(0)


def step(model: torch.nn.Module, optim: torch.optim.Optimizer, device):
    model.train()
    optim.zero_grad()
    batch, target = _data(device)

    try:
        outputs = model(batch)
    except TypeError:
        # Pytorch decoder exposes target explicitly
        outputs = model(batch, tgt=batch)

    loss = torch.norm(torch.mean(outputs, dim=-1) - target)
    loss.backward()

    # Clip grad and error out if we're producing NaNs, part of the unit test
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), 10.0, norm_type=2.0, error_if_nonfinite=True
    )
    optim.step()

    return loss.item()


def evaluate(model: torch.nn.Module, device):
    batch, target = _data(device)
    model.eval()
    try:
        outputs = model(batch)
    except TypeError:
        # Pytorch decoder exposes target explicitly
        outputs = model(batch, tgt=batch)

    return torch.norm(torch.mean(outputs, dim=-1) - target).item()


def train(model, optimizer, name, steps, device):
    # Dummy training, just checking that both options give the same results
    # Same seed for everyone
    reset_seeds()
    start = time.time()
    for i in range(steps):
        loss = step(model, optimizer, device)
        print(i, name, loss)

    print("Trained {} in {:.3}s".format(name, time.time() - start))


@pytest.mark.parametrize("device", _devices)
def test_pytorch_encoder_parity(device):
    # Build both a xFormers and Pytorch model
    reset_seeds()
    model_xformers = xFormer.from_config(xFormerConfig([_test_config_encoder])).to(
        device
    )
    print(model_xformers)

    model_pytorch = torch.nn.TransformerEncoder(
        torch.nn.TransformerEncoderLayer(
            d_model=EMB,
            nhead=HEADS,
            dim_feedforward=4 * EMB,
            dropout=DROP,
            activation=ACTIVATION,
            layer_norm_eps=1e-05,
            batch_first=True,  # (batch, seq, feature)
            device=device,
        ),
        num_layers=LAYERS,
    )
    print(model_pytorch)

    optim_xformers = torch.optim.SGD(model_xformers.parameters(), lr=1e-3, momentum=0.9)
    optim_pytorch = torch.optim.SGD(model_pytorch.parameters(), lr=1e-3, momentum=0.9)

    # Check that both models can be trained to comparable results
    eval_start_xformer = evaluate(model_xformers, device)
    eval_start_pytorch = evaluate(model_pytorch, device)
    print("starting point: ", eval_start_pytorch, eval_start_xformer)
    train(model_pytorch, optim_pytorch, "pytorch", 500, device)
    train(model_xformers, optim_xformers, "xformers", 500, device)

    # Check that we can classify this dummy example
    # Arbitrary threshold
    eval_stop_xformer = evaluate(model_xformers, device)
    eval_stop_pytorch = evaluate(model_pytorch, device)
    print("end point: ", eval_stop_pytorch, eval_stop_xformer)

    fit_ratio_xformer = eval_start_xformer / eval_stop_xformer
    fit_ratio_pytorch = eval_start_pytorch / eval_stop_pytorch

    print(fit_ratio_pytorch, fit_ratio_xformer)

    # Catch a broken training
    assert fit_ratio_xformer > 60
    assert fit_ratio_pytorch > 60

    # Catch a significant difference in between the two
    assert (
        abs(eval_start_xformer - eval_start_pytorch) < 1e-1
    )  # initial eval is about 50, arbitrary limits
    assert (
        abs(eval_stop_xformer - eval_stop_pytorch) < 1e-1
    )  # final eval is about 0.74, arbitrary limits


@pytest.mark.parametrize("device", _devices)
def test_pytorch_transformer_parity(device):
    # Build both a xFormers and Pytorch model
    reset_seeds()
    model_xformers = xFormer.from_config(xFormerConfig(_test_config)).to(device)
    print(model_xformers)

    model_pytorch = torch.nn.Transformer(
        d_model=EMB,
        nhead=HEADS,
        num_encoder_layers=LAYERS,
        num_decoder_layers=LAYERS,
        dim_feedforward=4 * EMB,
        dropout=DROP,
        activation=ACTIVATION,
        layer_norm_eps=1e-05,
        batch_first=True,  # (batch, seq, feature)
        device=device,
    )
    print(model_pytorch)

    optim_xformers = torch.optim.SGD(model_xformers.parameters(), lr=1e-3, momentum=0.9)
    optim_pytorch = torch.optim.SGD(model_pytorch.parameters(), lr=1e-3, momentum=0.9)

    # Check that both models can be trained to comparable results
    eval_start_xformer = evaluate(model_xformers, device)
    eval_start_pytorch = evaluate(model_pytorch, device)
    print("starting point: ", eval_start_pytorch, eval_start_xformer)
    train(model_xformers, optim_xformers, "xformers", 100, device)
    train(model_pytorch, optim_pytorch, "pytorch", 100, device)

    # Check that we can classify this dummy example
    # Arbitrary threshold
    eval_stop_xformer = evaluate(model_xformers, device)
    eval_stop_pytorch = evaluate(model_pytorch, device)
    print("end point: ", eval_stop_pytorch, eval_stop_xformer)

    fit_ratio_xformer = eval_start_xformer / eval_stop_xformer
    fit_ratio_pytorch = eval_start_pytorch / eval_stop_pytorch

    print(fit_ratio_pytorch, fit_ratio_xformer)

    # FIXME: Should not have a discrepancy here.
    assert fit_ratio_xformer > 30
    assert fit_ratio_pytorch > 30
13 changes: 5 additions & 8 deletions tests/test_reversible.py
@@ -150,8 +150,9 @@ def data():
)

def step(model: torch.nn.Module, optim: torch.optim.Optimizer):
optim.zero_grad()
batch, target = data()
model.train()
optim.zero_grad()

outputs = model(batch)
loss = torch.norm(torch.mean(outputs, dim=-1) - target)
@@ -167,6 +168,7 @@ def step(model: torch.nn.Module, optim: torch.optim.Optimizer):

def evaluate(model: torch.nn.Module):
batch, target = data()
model.eval()
outputs = model(batch)
return torch.norm(torch.mean(outputs, dim=-1) - target).item()

@@ -178,9 +180,6 @@ def evaluate(model: torch.nn.Module):
device
)

model_reversible.train()
model_non_reversible.train()

optim_rev = torch.optim.SGD(model_reversible.parameters(), lr=1e-3, momentum=0.9)
optim_non_rev = torch.optim.SGD(
model_non_reversible.parameters(), lr=1e-3, momentum=0.9
@@ -198,7 +197,5 @@ def evaluate(model: torch.nn.Module):
# Arbitrary threshold
eval_stop_rev = evaluate(model_reversible)
eval_stop_non_rev = evaluate(model_non_reversible)
assert eval_start_rev / eval_stop_rev > 50
assert (
eval_start_non_rev / eval_stop_non_rev > 3
) # for some reason the reversible layers train very well here (?)
assert eval_start_rev / eval_stop_rev > 3
assert eval_start_non_rev / eval_stop_non_rev > 3
28 changes: 20 additions & 8 deletions xformers/components/multi_head_dispatch.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn

from xformers.components.attention import Attention # , build_attention
from xformers.components.attention import Attention


class InProjContainer(torch.nn.Module):
@@ -22,10 +22,19 @@ def __init__(
key_proj: Optional[nn.Module],
value_proj: Optional[nn.Module],
):

super().__init__()

self.query_proj = query_proj
self.key_proj = key_proj if key_proj is not None else query_proj
self.value_proj = value_proj if value_proj is not None else query_proj

# If no projection is passed for key and value, the projection from the Query (minus optional bias) is used
bias_free_query_proj = nn.Linear(
self.query_proj.in_features, self.query_proj.out_features, bias=False # type: ignore
)
bias_free_query_proj.weight = self.query_proj.weight  # share the projection weight, but not the bias

self.key_proj = key_proj if key_proj is not None else bias_free_query_proj
self.value_proj = value_proj if value_proj is not None else bias_free_query_proj

def forward(
self,
@@ -42,6 +51,7 @@ class MultiHeadDispatchConfig:
residual_dropout: float
num_heads: int
attention: Attention
bias: bool
dim_key: Optional[int]
dim_value: Optional[int]
in_proj_container: Optional[InProjContainer]
@@ -69,6 +79,7 @@ def __init__(
residual_dropout: float,
num_heads: int,
attention: Attention,
bias: bool = True,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
in_proj_container: Optional[InProjContainer] = None,
@@ -94,13 +105,16 @@ def __init__(
self.attention = attention

# key, query, value projections for all heads
# critical options are
# - are we sharing weights ?
# - are we adding biases, and if yes are they shared ?
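# (with use_separate_proj_weight=False, the key/value projections fall back to the
# bias-free copy of the query projection built in InProjContainer above)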
if attention.requires_input_projection:
self.in_proj_container = (
in_proj_container
if in_proj_container
if in_proj_container is not None
else InProjContainer(
query_proj=nn.Linear(
dim_model, dim_key, bias=False
dim_model, dim_key, bias=bias
), # NOTE: optional bias ?
key_proj=nn.Linear(dim_model, dim_key, bias=False)
if use_separate_proj_weight
@@ -115,9 +129,7 @@ def __init__(
self.resid_drop = nn.Dropout(residual_dropout, inplace=False)

# Output projection
self.proj = (
out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=False)
)
self.proj = out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias)

def _check(self, t, name):
assert (
