Commit

Pass args around to cleanup parameter lists
myleott committed Jun 15, 2018
1 parent 559eca8 commit 1235aa0
Showing 1 changed file with 42 additions and 80 deletions.
122 changes: 42 additions & 80 deletions fairseq/models/transformer.py
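
The diff collapses each module's long keyword-argument list into a single `args` namespace threaded through every constructor, so adding a hyperparameter no longer means touching every call site. A minimal sketch of the new calling convention, assuming an argparse-style namespace with the field names that appear in the diff (`src_dict` and `encoder_embed_tokens` are hypothetical stand-ins for the dictionary and embedding objects that `build_model` constructs):

    from argparse import Namespace

    # Hypothetical hyperparameter bundle; field names follow the diff.
    args = Namespace(
        encoder_embed_dim=512, encoder_ffn_embed_dim=2048,
        encoder_layers=6, encoder_attention_heads=8,
        encoder_learned_pos=False, encoder_normalize_before=False,
        decoder_embed_dim=512, decoder_ffn_embed_dim=2048,
        decoder_layers=6, decoder_attention_heads=8,
        decoder_learned_pos=False, share_decoder_input_output_embed=False,
        dropout=0.1, attention_dropout=0., relu_dropout=0.,
    )

    # Old: TransformerEncoder(src_dict, encoder_embed_tokens,
    #                         ffn_inner_dim=args.encoder_ffn_embed_dim, ...)
    # New: TransformerEncoder(args, src_dict, encoder_embed_tokens)

The trade-off is the usual one for namespace-style configuration: call sites shrink and stay stable, but each module's effective signature becomes implicit in which `args` fields it reads.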
@@ -89,44 +89,16 @@ def build_embedding(dictionary, embed_dim):
         encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
         decoder_embed_tokens = build_embedding(dst_dict, args.decoder_embed_dim)
 
-        encoder = TransformerEncoder(
-            src_dict,
-            encoder_embed_tokens,
-            ffn_inner_dim=args.encoder_ffn_embed_dim,
-            num_layers=args.encoder_layers,
-            num_attn_heads=args.encoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.encoder_learned_pos,
-        )
-        decoder = TransformerDecoder(
-            dst_dict,
-            decoder_embed_tokens,
-            ffn_inner_dim=args.decoder_ffn_embed_dim,
-            num_layers=args.decoder_layers,
-            num_attn_heads=args.decoder_attention_heads,
-            dropout=args.dropout,
-            attention_dropout=args.attention_dropout,
-            relu_dropout=args.relu_dropout,
-            normalize_before=args.encoder_normalize_before,
-            learned_pos_embed=args.decoder_learned_pos,
-            share_input_output_embed=args.share_decoder_input_output_embed,
-        )
+        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
+        decoder = TransformerDecoder(args, dst_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)
 
 
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
+        self.dropout = args.dropout
 
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
@@ -136,17 +108,13 @@ def __init__(
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, self.padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
-            learned=learned_pos_embed,
+            learned=args.encoder_learned_pos,
         )
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerEncoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerEncoderLayer(args)
+            for i in range(args.encoder_layers)
         ])
 
         self.reset_parameters()
@@ -186,15 +154,10 @@ def max_positions(self):

 class TransformerDecoder(FairseqDecoder):
     """Transformer decoder."""
-    def __init__(
-        self, dictionary, embed_tokens, ffn_inner_dim=2048,
-        num_layers=6, num_attn_heads=8, dropout=0.1, attention_dropout=0.,
-        relu_dropout=0., normalize_before=False, learned_pos_embed=False,
-        share_input_output_embed=False,
-    ):
+    def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
-        self.dropout = dropout
-        self.share_input_output_embed = share_input_output_embed
+        self.dropout = args.dropout
+        self.share_input_output_embed = args.share_decoder_input_output_embed
 
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -204,20 +167,16 @@ def __init__(
         self.embed_positions = PositionalEmbedding(
             1024, embed_dim, padding_idx,
             left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-            learned=learned_pos_embed,
+            learned=args.decoder_learned_pos,
         )
 
         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(
-                embed_dim, ffn_inner_dim, num_attn_heads, dropout=dropout,
-                attention_dropout=attention_dropout, relu_dropout=relu_dropout,
-                normalize_before=normalize_before,
-            )
-            for i in range(num_layers)
+            TransformerDecoderLayer(args)
+            for i in range(args.decoder_layers)
         ])
 
-        if not share_input_output_embed:
+        if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
 
         self.reset_parameters()
@@ -276,19 +235,19 @@ class TransformerEncoderLayer(nn.Module):
     We default to the approach in the paper, but the tensor2tensor approach can
     be enabled by setting `normalize_before=True`.
     """
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(2)])
+        self.embed_dim = args.encoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.encoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
 
     def forward(self, x, encoder_padding_mask):
         residual = x
@@ -318,20 +277,23 @@ def maybe_layer_norm(self, i, x, before=False, after=False):

 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""
-    def __init__(
-        self, embed_dim, ffn_inner_dim, num_attn_heads, dropout=0.1,
-        attention_dropout=0., relu_dropout=0., normalize_before=False,
-    ):
+    def __init__(self, args):
         super().__init__()
-        self.embed_dim = embed_dim
-        self.self_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.dropout = dropout
-        self.relu_dropout = relu_dropout
-        self.normalize_before = normalize_before
-        self.encoder_attn = MultiheadAttention(embed_dim, num_attn_heads, dropout=attention_dropout)
-        self.fc1 = nn.Linear(embed_dim, ffn_inner_dim)
-        self.fc2 = nn.Linear(ffn_inner_dim, embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(embed_dim) for i in range(3)])
+        self.embed_dim = args.decoder_embed_dim
+        self.self_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.dropout = args.dropout
+        self.relu_dropout = args.relu_dropout
+        self.normalize_before = args.encoder_normalize_before
+        self.encoder_attn = MultiheadAttention(
+            self.embed_dim, args.decoder_attention_heads,
+            dropout=args.attention_dropout,
+        )
+        self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
+        self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
+        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
 
     def forward(self, x, encoder_out, encoder_padding_mask):
         residual = x
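The encoder-layer docstring above notes that the paper's post-norm ordering is the default, with tensor2tensor-style pre-norm enabled by `normalize_before=True`. For reference, a minimal sketch of how a `maybe_layer_norm` helper with the signature shown in the final hunk header can implement that toggle (an illustration under those assumptions, not the file's exact body):

    import torch.nn as nn

    class Sublayer(nn.Module):
        """Toy container showing the pre-norm/post-norm switch."""

        def __init__(self, embed_dim, normalize_before=False):
            super().__init__()
            self.normalize_before = normalize_before
            self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(2)])

        def maybe_layer_norm(self, i, x, before=False, after=False):
            # Each call site sets exactly one of `before`/`after`.
            assert before ^ after
            # Pre-norm (normalize_before=True) fires on the `before` call;
            # post-norm (the default) fires on the `after` call.
            if after ^ self.normalize_before:
                return self.layer_norms[i](x)
            return x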
