Enable global weight decay to TBE (Backend) (#2516)
Summary:
Pull Request resolved: #2516

With the existing implementation for sparse embedding tables with rowwise adagrad, weight decay is applied to a row only when its ID (and corresponding embedding row) appears in the current training batch. Rows that do not show up are neither updated nor decayed, so the embedding table effectively receives only *local*, not *global*, weight decay.

This diff adds an option to compensate for the missed decay steps by scaling the weights with a `global weight decay` value, using the formula from csmiler below:
```
global_weight_decay = (1 - learning_rate * weight_decay)^(current_iter - prev_iter - 1)
```
where `prev_iter` is the last iteration in which this ID (and its corresponding embedding row) showed up.
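
For illustration, here is a minimal Python sketch (hypothetical helper and variable names, not the actual FBGEMM kernel code) of how this compensation factor and the per-row `prev_iter` bookkeeping would fit together:
```
# Hypothetical sketch only; FBGEMM performs this inside the generated CUDA kernels.
def apply_global_weight_decay(weights, prev_iter, row_ids, learning_rate,
                              weight_decay, current_iter):
    """Scale rows that reappear at current_iter to make up for the decay
    steps they missed while absent, then record the current iteration."""
    weight_decay_base = 1.0 - learning_rate * weight_decay
    for row in row_ids:
        missed_steps = current_iter - prev_iter[row] - 1
        # Decay the row as if it had been decayed in every intervening batch.
        weights[row] = [w * weight_decay_base ** missed_steps for w in weights[row]]
        prev_iter[row] = current_iter

# Example: learning_rate=0.1, weight_decay=0.01, row last seen at iteration 10
# and reappearing at iteration 15 -> factor = (1 - 0.001)^4 ≈ 0.9960
```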
 ---
**Usage:**
set
```
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
weight_decay_mode = WeightDecayMode.DECOUPLE_GLOBAL
```

e.g.,
```
tbe = SplitTableBatchedEmbeddingBagsCodegen(
            embedding_specs=[
                (E, D, managed_option, ComputeDevice.CUDA) for (E, D) in zip(Es, Ds)
            ],
            optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
            learning_rate=0.1,
            eps=0.1,
            output_dtype=output_dtype,
            pooling_mode=pooling_mode,
            weight_decay=0.01,
            weight_decay_mode=WeightDecayMode.DECOUPLE_GLOBAL,
        )
```
Relevant diffs:
D53866750
D55660277
D55660762

Reviewed By: sryap

Differential Revision: D56285676

fbshipit-source-id: 5a2c95aaf366b0893c16f4780edb607e96b7dad0
spcyppt authored and facebook-github-bot committed May 3, 2024
1 parent d641676 commit c1f7a66
Showing 15 changed files with 923 additions and 351 deletions.
29 changes: 28 additions & 1 deletion fbgemm_gpu/FbgemmGpu.cmake
@@ -89,6 +89,10 @@ set(VBE_OPTIMIZERS
rowwise_adagrad_with_counter
sgd)

# Optimizers with GWD support
set(GWD_OPTIMIZERS
rowwise_adagrad)

# Individual optimizers (not fused with SplitTBE backward)
set(DEFUSED_OPTIMIZERS
rowwise_adagrad)
@@ -156,7 +160,10 @@ set(gen_gpu_kernel_source_files
if(NOT USE_ROCM)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_forward_split_weighted_v2_kernel.cu"
"gen_embedding_forward_split_unweighted_v2_kernel.cu")
"gen_embedding_forward_split_unweighted_v2_kernel.cu"
"gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_gwd_codegen_cuda.cu"
)
endif()

foreach(wdesc dense split)
@@ -187,6 +194,14 @@ foreach(wdesc weighted unweighted)
"gen_embedding_backward_${wdesc}_vbe_split_device_kernel.cuh")
endforeach()

# Generate GWD files
if(NOT USE_ROCM)
foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_forward_split_${wdesc}_gwd_kernel.cu")
endforeach()
endif()

set(gen_cpu_source_files
"gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp"
"gen_embedding_forward_quantized_weighted_codegen_cpu.cpp"
@@ -253,6 +268,18 @@ foreach(optimizer ${VBE_OPTIMIZERS})
endforeach()
endforeach()

if(NOT USE_ROCM)
foreach(optimizer ${GWD_OPTIMIZERS})
# GWD is not supported in nobag
foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_cuda.cu"
"gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_split_${wdesc}_gwd_kernel_warp.cu")
endforeach()
endforeach()
endif()

foreach(optimizer ${DEFUSED_OPTIMIZERS})
list(APPEND gen_defused_optim_source_files
"gen_embedding_optimizer_${optimizer}_split.cpp"
31 changes: 29 additions & 2 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -35,14 +35,17 @@ def render_backward_templates(
optimizer: str,
filename_format: str,
kwargs: Dict[str, Any],
is_gwd: bool = False,
) -> None:
if not kwargs.get("has_gpu_support"):
return
vbe_options = [True, False] if kwargs.get("has_vbe_support") else [False]
vbe_options = (
[True, False] if (kwargs.get("has_vbe_support") and not is_gwd) else [False]
)
template = CodeTemplate.load(template_filepath)

for weighted in [True, False]:
for nobag in [True, False]:
for nobag in [True, False] if (not is_gwd) else [False]:
for vbe in vbe_options:
if (not nobag or (not weighted and not vbe)) and (
not kwargs.get("dense") or not vbe
@@ -56,6 +59,7 @@ def render_backward_templates(
is_index_select=False,
kdesc=wdesc,
**kwargs,
is_gwd=is_gwd,
)

@staticmethod
@@ -90,6 +94,29 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
filename_format,
kwargs,
)
# Generate the global weight decay CUDA kernels
if kwargs.get("has_global_weight_decay_support") and not args.is_rocm:
for template_filepath, filename_format in [
(
"training/backward/embedding_backward_split_kernel_cta_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu",
),
(
"training/backward/embedding_backward_split_kernel_warp_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu",
),
(
"training/backward/embedding_backward_split_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_cuda.cu",
),
]:
BackwardSplitGenerator.render_backward_templates(
template_filepath,
optimizer,
filename_format,
kwargs,
is_gwd=True,
)

# Generate optimizer kernel
CodeTemplate.load(
27 changes: 27 additions & 0 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -8,15 +8,20 @@
# pyre-strict
# flake8: noqa F401

import argparse
import sys
from typing import List

try:
from .common import CodeTemplate
from .scripts_argsparse import args
except ImportError:
# pyre-ignore[21]
from common import CodeTemplate

# pyre-ignore[21]
from scripts_argsparse import args


class ForwardSplitGenerator:
@staticmethod
@@ -26,6 +31,7 @@ def render_forward_templates(
dense_options: List[bool],
nobag_options: List[bool],
vbe_options: List[bool],
is_gwd: bool = False,
) -> None:
template = CodeTemplate.load(template_filepath)
for dense in dense_options:
@@ -51,6 +57,7 @@ def render_forward_templates(
nobag=nobag,
vbe=vbe,
is_index_select=False,
is_gwd=is_gwd,
)

@staticmethod
@@ -98,6 +105,16 @@ def generate_kernels() -> None:
nobag_options=[False], # nobag is not used
vbe_options=[True, False],
)
# Generate the CUDA host code for global weight decay
if not args.is_rocm:
ForwardSplitGenerator.render_forward_templates(
"training/forward/embedding_forward_split_template.cu",
"gen_embedding_forward_{}_gwd_codegen_cuda.cu",
dense_options=[False],
nobag_options=[False], # nobag is not used
vbe_options=[False],
is_gwd=True,
)

# Generate the meta kernels
ForwardSplitGenerator.render_forward_templates(
@@ -116,6 +133,16 @@ def generate_kernels() -> None:
nobag_options=[True, False],
vbe_options=[True, False],
)
# Generate the global weight decay CUDA kernels
if not args.is_rocm:
ForwardSplitGenerator.render_forward_templates(
"training/forward/embedding_forward_split_kernel_template.cu",
"gen_embedding_forward_{}_gwd_kernel.cu",
dense_options=[False],
nobag_options=[False],
vbe_options=[False],
is_gwd=True,
)

# Generate the v2 CUDA kernels
ForwardSplitGenerator.render_forward_templates(
45 changes: 44 additions & 1 deletion fbgemm_gpu/codegen/genscript/jinja_environment.py
@@ -298,6 +298,48 @@ def has_experimental_support(
return not dense and not nobag and not vbe and not is_index_select and not is_rocm


def is_valid_gwd_config(
dense: bool,
nobag: bool,
vbe: bool,
is_index_select: bool,
is_rocm: bool,
has_global_weight_decay_support: bool = True,
) -> bool:
"""
Check if the given combination of configs is valid for global weight decay support
- `has_global_weight_decay_support` indicates whether global weight decay is available for
an optimizer; not all configs of such an optimizer offer global weight decay support
- any updates to the configs need to be reflected in embedding_backward_split_host_template.cpp
- global weight decay does not support dense, nobag, vbe, is_index_select, and is_rocm
"""
return (
not dense
and not nobag
and not vbe
and not is_index_select
and not is_rocm
and has_global_weight_decay_support
)


def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:
"""
For the global weight decay kernel, compute the global weight decay value
and update prev_iter to the current iteration.
This is used in both warp and cta kernels.
"""
if is_global_weight_decay_kernel:
return """
const auto global_weight_decay = std::pow(weight_decay_base, iter - prev_iter_dev[linear_index] - 1);
if (threadIdx.x == 0) {
prev_iter_dev[linear_index] = iter;
}
"""
else:
return ""


################################################################################
# Register Helper Functions in Jinja Environment
################################################################################
@@ -311,7 +353,8 @@ def has_experimental_support(
env.globals["dispatch_vec_blocking_kernel"] = dispatch_vec_blocking_kernel
env.globals["is_valid_forward_config"] = is_valid_forward_config
env.globals["has_experimental_support"] = has_experimental_support

env.globals["is_valid_gwd_config"] = is_valid_gwd_config
env.globals["compute_global_weight_decay"] = compute_global_weight_decay

################################################################################
# Filter functions in Jinja Environment
21 changes: 19 additions & 2 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -41,6 +41,7 @@ def dense() -> Dict[str, Any]:
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -84,6 +85,7 @@ def adagrad() -> Dict[str, Any]:
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -191,7 +193,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
if (weight_decay_mode == 1) {
// L2 regularization
correction = 1.0 - multiplier * weight_decay;
} else if (weight_decay_mode == 2) {
} else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
// Decoupled weight decay
correction = 1.0 - learning_rate * weight_decay;
} else {
@@ -221,7 +223,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
if (weight_decay_mode == 1) {
// L2 regularization
correction = 1.0 - multiplier * weight_decay;
} else if (weight_decay_mode == 2) {
} else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
// Decoupled weight decay
correction = 1.0 - learning_rate * weight_decay;
} else {
@@ -252,6 +254,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": True,
}


@@ -282,6 +285,7 @@ def approx_rowwise_adagrad() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -387,6 +391,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -422,6 +427,7 @@ def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -592,6 +598,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
}


@@ -640,6 +647,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -717,6 +725,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -740,6 +749,7 @@ def sgd() -> Dict[str, Any]:
"has_cpu_support": True,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
}


@@ -763,6 +773,7 @@ def approx_sgd() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": False,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -840,6 +851,7 @@ def lamb() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -931,6 +943,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -986,6 +999,7 @@ def adam() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -1060,6 +1074,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -1124,6 +1139,7 @@ def lars_sgd() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}


@@ -1141,4 +1157,5 @@ def none_optimizer() -> Dict[str, Any]:
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": False,
"has_global_weight_decay_support": False,
}