Internal change
PiperOrigin-RevId: 547236120
yilei authored and tensorflower-gardener committed Jul 11, 2023
1 parent 0fb661e commit 356216e
Showing 4 changed files with 87 additions and 40 deletions.
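
All four files apply the same fix: dataclass fields whose defaults were shared class-level instances (e.g. `encoders.EncoderConfig()`) now use `dataclasses.field(default_factory=...)`. Python rejects unhashable field defaults at class-definition time (as of Python 3.11 the check covers any default whose type is unhashable, which includes ordinary dataclass instances, since generating `__eq__` sets `__hash__` to None), and a shared default instance can also leak mutable state across config objects. A minimal sketch of the before/after behavior, using a hypothetical `Config` as a stand-in for `base_config.Config` / `hyperparams.Config`:

import dataclasses


@dataclasses.dataclass
class Config:
  """Hypothetical config stand-in; eq=True makes instances unhashable."""
  name: str = 'default'


# Before: one shared instance as the class-level default. On Python 3.11+
# this raises at class-definition time:
#   ValueError: mutable default <class 'Config'> for field encoder is not
#   allowed: use default_factory
#
# @dataclasses.dataclass
# class PretrainerLike:
#   encoder: Config = Config()


# After: default_factory builds a fresh Config for every instance.
@dataclasses.dataclass
class PretrainerLike:
  encoder: Config = dataclasses.field(default_factory=Config)


a, b = PretrainerLike(), PretrainerLike()
assert a.encoder == b.encoder      # equal by value...
assert a.encoder is not b.encoder  # ...but not a shared object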
4 changes: 3 additions & 1 deletion official/nlp/configs/bert.py
@@ -37,7 +37,9 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class PretrainerConfig(base_config.Config):
   """Pretrainer configuration."""
-  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  encoder: encoders.EncoderConfig = dataclasses.field(
+      default_factory=encoders.EncoderConfig
+  )
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
   mlm_activation: str = "gelu"
   mlm_initializer_range: float = 0.02
8 changes: 6 additions & 2 deletions official/nlp/configs/electra.py
@@ -31,6 +31,10 @@ class ElectraPretrainerConfig(base_config.Config):
   discriminator_loss_weight: float = 50.0
   tie_embeddings: bool = True
   disallow_correct: bool = False
-  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
-  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
+  generator_encoder: encoders.EncoderConfig = dataclasses.field(
+      default_factory=encoders.EncoderConfig
+  )
+  discriminator_encoder: encoders.EncoderConfig = dataclasses.field(
+      default_factory=encoders.EncoderConfig
+  )
   cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
44 changes: 32 additions & 12 deletions official/nlp/configs/encoders.py
@@ -295,19 +295,39 @@ class SparseMixerEncoderConfig(hyperparams.Config):
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""
   type: Optional[str] = "bert"
-  albert: AlbertEncoderConfig = AlbertEncoderConfig()
-  bert: BertEncoderConfig = BertEncoderConfig()
-  bert_v2: BertEncoderConfig = BertEncoderConfig()
-  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
-  kernel: KernelEncoderConfig = KernelEncoderConfig()
-  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
-  reuse: ReuseEncoderConfig = ReuseEncoderConfig()
-  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
-  query_bert: QueryBertConfig = QueryBertConfig()
-  fnet: FNetEncoderConfig = FNetEncoderConfig()
-  sparse_mixer: SparseMixerEncoderConfig = SparseMixerEncoderConfig()
+  albert: AlbertEncoderConfig = dataclasses.field(
+      default_factory=AlbertEncoderConfig
+  )
+  bert: BertEncoderConfig = dataclasses.field(default_factory=BertEncoderConfig)
+  bert_v2: BertEncoderConfig = dataclasses.field(
+      default_factory=BertEncoderConfig
+  )
+  bigbird: BigBirdEncoderConfig = dataclasses.field(
+      default_factory=BigBirdEncoderConfig
+  )
+  kernel: KernelEncoderConfig = dataclasses.field(
+      default_factory=KernelEncoderConfig
+  )
+  mobilebert: MobileBertEncoderConfig = dataclasses.field(
+      default_factory=MobileBertEncoderConfig
+  )
+  reuse: ReuseEncoderConfig = dataclasses.field(
+      default_factory=ReuseEncoderConfig
+  )
+  xlnet: XLNetEncoderConfig = dataclasses.field(
+      default_factory=XLNetEncoderConfig
+  )
+  query_bert: QueryBertConfig = dataclasses.field(
+      default_factory=QueryBertConfig
+  )
+  fnet: FNetEncoderConfig = dataclasses.field(default_factory=FNetEncoderConfig)
+  sparse_mixer: SparseMixerEncoderConfig = dataclasses.field(
+      default_factory=SparseMixerEncoderConfig
+  )
   # If `any` is used, the encoder building relies on any.BUILDER.
-  any: hyperparams.Config = hyperparams.Config()
+  any: hyperparams.Config = dataclasses.field(
+      default_factory=hyperparams.Config
+  )
 
 
 @gin.configurable
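
For context, `EncoderConfig` is a one-of config: `type` names which of the sub-config fields above is active, and encoder building uses only the selected one (or `any.BUILDER` when `type` is `any`). A rough sketch of that selection mechanism, with hypothetical `MiniOneOf`, `BertLike`, and `AlbertLike` classes rather than the real `hyperparams.OneOfConfig` implementation:

import dataclasses
from typing import Optional


@dataclasses.dataclass
class BertLike:
  num_layers: int = 12


@dataclasses.dataclass
class AlbertLike:
  num_layers: int = 12
  embedding_width: int = 128


@dataclasses.dataclass
class MiniOneOf:
  """Hypothetical stand-in for hyperparams.OneOfConfig."""
  type: Optional[str] = 'bert'
  bert: BertLike = dataclasses.field(default_factory=BertLike)
  albert: AlbertLike = dataclasses.field(default_factory=AlbertLike)

  def get(self):
    """Returns the sub-config selected by `type`."""
    return getattr(self, self.type)


cfg = MiniOneOf(type='albert')
assert isinstance(cfg.get(), AlbertLike)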
71 changes: 46 additions & 25 deletions official/projects/edgetpu/nlp/configs/params.py
@@ -52,20 +52,31 @@ class OrbitParams(base_config.Config):
 @dataclasses.dataclass
 class OptimizerParams(optimization.OptimizationConfig):
   """Optimizer parameters for MobileBERT-EdgeTPU."""
-  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
-      type='adamw',
-      adamw=optimization.AdamWeightDecayConfig(
-          weight_decay_rate=0.01,
-          exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias']))
-  learning_rate: optimization.LrConfig = optimization.LrConfig(
-      type='polynomial',
-      polynomial=optimization.PolynomialLrConfig(
-          initial_learning_rate=1e-4,
-          decay_steps=1000000,
-          end_learning_rate=0.0))
-  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
-      type='polynomial',
-      polynomial=optimization.PolynomialWarmupConfig(warmup_steps=10000))
+  optimizer: optimization.OptimizerConfig = dataclasses.field(
+      default_factory=lambda: optimization.OptimizerConfig(  # pylint: disable=g-long-lambda
+          type='adamw',
+          adamw=optimization.AdamWeightDecayConfig(
+              weight_decay_rate=0.01,
+              exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
+          ),
+      )
+  )
+  learning_rate: optimization.LrConfig = dataclasses.field(
+      default_factory=lambda: optimization.LrConfig(  # pylint: disable=g-long-lambda
+          type='polynomial',
+          polynomial=optimization.PolynomialLrConfig(
+              initial_learning_rate=1e-4,
+              decay_steps=1000000,
+              end_learning_rate=0.0,
+          ),
+      )
+  )
+  warmup: optimization.WarmupConfig = dataclasses.field(
+      default_factory=lambda: optimization.WarmupConfig(  # pylint: disable=g-long-lambda
+          type='polynomial',
+          polynomial=optimization.PolynomialWarmupConfig(warmup_steps=10000),
+      )
+  )
 
 
 @dataclasses.dataclass
@@ -144,16 +155,26 @@ class EdgeTPUBERTCustomParams(base_config.Config):
     distill_ground_truth_ratio: A float number representing the ratio between
       distillation output and ground truth.
   """
-  train_datasest: DatasetParams = DatasetParams()
-  eval_dataset: DatasetParams = DatasetParams()
-  teacher_model: Optional[PretrainerModelParams] = PretrainerModelParams()
-  student_model: PretrainerModelParams = PretrainerModelParams()
+  train_datasest: DatasetParams = dataclasses.field(
+      default_factory=DatasetParams
+  )
+  eval_dataset: DatasetParams = dataclasses.field(default_factory=DatasetParams)
+  teacher_model: Optional[PretrainerModelParams] = dataclasses.field(
+      default_factory=PretrainerModelParams
+  )
+  student_model: PretrainerModelParams = dataclasses.field(
+      default_factory=PretrainerModelParams
+  )
   teacher_model_init_checkpoint: str = ''
   student_model_init_checkpoint: str = ''
-  layer_wise_distillation: LayerWiseDistillationParams = (
-      LayerWiseDistillationParams())
-  end_to_end_distillation: EndToEndDistillationParams = (
-      EndToEndDistillationParams())
-  optimizer: OptimizerParams = OptimizerParams()
-  runtime: RuntimeParams = RuntimeParams()
-  orbit_config: OrbitParams = OrbitParams()
+  layer_wise_distillation: LayerWiseDistillationParams = dataclasses.field(
+      default_factory=LayerWiseDistillationParams
+  )
+  end_to_end_distillation: EndToEndDistillationParams = dataclasses.field(
+      default_factory=EndToEndDistillationParams
+  )
+  optimizer: OptimizerParams = dataclasses.field(
+      default_factory=OptimizerParams
+  )
+  runtime: RuntimeParams = dataclasses.field(default_factory=RuntimeParams)
+  orbit_config: OrbitParams = dataclasses.field(default_factory=OrbitParams)
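
One wrinkle in the OptimizerParams hunk above: when the default needs constructor arguments, the class itself cannot serve as the factory, so each call is wrapped in a zero-argument lambda that runs once per instance; the `# pylint: disable=g-long-lambda` comments silence the Google style check on multi-line lambdas. A self-contained sketch of the pattern, with simplified stand-ins for the `optimization` config classes:

import dataclasses
from typing import List


@dataclasses.dataclass
class AdamWeightDecayConfig:
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: List[str] = dataclasses.field(
      default_factory=list)


@dataclasses.dataclass
class OptimizerConfig:
  type: str = 'sgd'
  adamw: AdamWeightDecayConfig = dataclasses.field(
      default_factory=AdamWeightDecayConfig)


@dataclasses.dataclass
class OptimizerParamsLike:
  # The default needs arguments, so a zero-argument lambda is the factory;
  # it is invoked once per OptimizerParamsLike instance, never at class scope.
  optimizer: OptimizerConfig = dataclasses.field(
      default_factory=lambda: OptimizerConfig(
          type='adamw',
          adamw=AdamWeightDecayConfig(
              weight_decay_rate=0.01,
              exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
          ),
      )
  )


p = OptimizerParamsLike()
assert p.optimizer.type == 'adamw'
assert p.optimizer is not OptimizerParamsLike().optimizer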
