
Commit

Add batchnorm option to hubert/wav2vec2 positional convolution layer for hubert bf16 models (#5285)

* add conv_batch_norm for hubert to support bf16

* linting

Co-authored-by: Bowen Shi <bshi@meta.com>
tuanh208 and Bowen Shi committed Aug 18, 2023
1 parent 100cd91 commit 4db2649
Showing 2 changed files with 19 additions and 6 deletions.
6 changes: 6 additions & 0 deletions fairseq/models/hubert/hubert.py
@@ -191,6 +191,12 @@ class HubertConfig(FairseqDataclass):
default=16,
metadata={"help": "number of groups for convolutional positional embedding"},
)
conv_pos_batch_norm: bool = field(
default=False,
metadata={
"help": "use batch norm instead of weight norm in conv_pos (for bf16 models)"
},
)

latent_temp: Tuple[float, float, float] = field(
default=(2, 0.5, 0.999995),
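Below is a minimal sketch of enabling the new flag when building a config in Python. It assumes fairseq is installed and that HubertConfig can be instantiated with its defaults; it is an illustration, not part of this commit.

from fairseq.models.hubert.hubert import HubertConfig

# Sketch only: use BatchNorm1d instead of weight norm in the positional conv,
# the option this commit adds for bf16 training. The default stays False,
# so existing configs are unaffected.
cfg = HubertConfig()
cfg.conv_pos_batch_norm = True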
19 changes: 13 additions & 6 deletions fairseq/models/wav2vec/wav2vec2.py
@@ -922,7 +922,7 @@ def forward(self, x):
return x


def make_conv_pos(e, k, g):
def make_conv_pos(e, k, g, is_batch_norm=False):
pos_conv = nn.Conv1d(
e,
e,
@@ -935,8 +935,12 @@ def make_conv_pos(e, k, g):
nn.init.normal_(pos_conv.weight, mean=0, std=std)
nn.init.constant_(pos_conv.bias, 0)

pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
if not is_batch_norm:
pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2)
pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU())
else:
batch_norm = nn.BatchNorm1d(e)
pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU())

return pos_conv
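
A hedged usage sketch of the updated helper follows; the values 768/128/16 mirror a typical base setup (embedding dim, conv_pos, conv_pos_groups) and are assumptions, not part of this diff.

import torch
from fairseq.models.wav2vec.wav2vec2 import make_conv_pos

# With is_batch_norm=True the helper returns Sequential(BatchNorm1d, Conv1d, SamePad, GELU)
# and skips weight norm entirely, per the change above.
pos_conv = make_conv_pos(768, 128, 16, is_batch_norm=True)

x = torch.randn(4, 768, 100)   # (batch, channels, time), the layout nn.Conv1d expects
y = pos_conv(x)
print(y.shape)                 # torch.Size([4, 768, 100]); SamePad trims the extra frame left by the even kernel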

@@ -1047,6 +1051,9 @@ def make_conv_block(e, k, g, l):
self.embedding_dim,
args.conv_pos,
args.conv_pos_groups,
is_batch_norm=args.conv_pos_batch_norm
if hasattr(args, "conv_pos_batch_norm")
else False,
)

self.layers = nn.ModuleList(
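
The hasattr guard above keeps older configs and checkpoints, which predate conv_pos_batch_norm, working unchanged. An equivalent, slightly more compact spelling (an illustration, not what the commit uses) is:

# Fall back to False when the field is absent from args.
is_batch_norm = getattr(args, "conv_pos_batch_norm", False)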
@@ -1370,7 +1377,7 @@ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
and without using ModuleList, to speed up training throughput.
"""
super().__init__()

self.adapter_num = adapter_num
self.input_dim = input_dim
self.hidden_dim = hidden_dim
@@ -1405,7 +1412,7 @@ def reset_parameters(self):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.b_b[ii], -bound, bound)

nn.init.ones_(self.ln_W)
nn.init.zeros_(self.ln_b)

@@ -1418,7 +1425,7 @@ def forward(self, x, adapter_id):
h = F.linear(h, self.W_b[ii], self.b_b[ii])
outputs = h
return outputs

def extra_repr(self):
return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))

