When APEX AMP is available, functions for FP16 and FP32 processing by the recurrence kernel are registered and used, depending on the amp_recurrence_fp16 switch
visionscaper committed Jan 31, 2021
1 parent 7833577 commit feedd77
Showing 1 changed file with 60 additions and 10 deletions.
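For readers unfamiliar with the mechanism this commit relies on: APEX AMP lets a module register free functions whose tensor arguments should be cast before the call, via amp.register_half_function (cast to FP16) and amp.register_float_function (cast to FP32). The standalone sketch below illustrates that pattern; the demo_recurrence_* names and the doubling bodies are illustrative only and are not part of this commit.

import sys

import torch
from apex import amp  # requires NVIDIA apex to be installed


def demo_recurrence_fp16(u: torch.Tensor) -> torch.Tensor:
    # Once registered and amp.initialize(...) has run, apex casts `u` to FP16 before this body executes
    return u * 2.0


def demo_recurrence_fp32(u: torch.Tensor) -> torch.Tensor:
    # Once registered and amp.initialize(...) has run, apex casts `u` to FP32 before this body executes
    return u * 2.0


_this_module = sys.modules[__name__]
amp.register_half_function(_this_module, "demo_recurrence_fp16")   # inputs cast to FP16
amp.register_float_function(_this_module, "demo_recurrence_fp32")  # inputs cast to FP32

Registration only records the module/name pair; the casting wrappers are installed when amp.initialize(...) runs with a patching opt_level such as O1. That is why the commit registers thin pass-throughs to SRU_Compute_GPU.apply at import time and only dispatches through them when APEX_AMP_AVAILABLE is set.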
70 changes: 60 additions & 10 deletions sru/ops.py
@@ -18,6 +18,37 @@
)


def apex_amp_sru_compute_gpu_fp16(*args, **kwargs):
    return _apex_amp_sru_compute_gpu(*args, **kwargs)


def apex_amp_sru_compute_gpu_fp32(*args, **kwargs):
    return _apex_amp_sru_compute_gpu(*args, **kwargs)


def _apex_amp_sru_compute_gpu(*args, **kwargs):
    # Will already have been imported and cached at this point
    from .cuda_functional import SRU_Compute_GPU

    return SRU_Compute_GPU.apply(*args, **kwargs)


try:
    from apex import amp
    APEX_AMP_AVAILABLE = True

    import sys
    current_module = sys.modules[__name__]

    # TODO : remove debug statement
    print("SRU: OPS: APEX AMP available, registering apex_amp_sru_compute_gpu for different Tensor precision types ...")

    amp.register_half_function(current_module, "apex_amp_sru_compute_gpu_fp16")   # Will cast input arguments to FP16
    amp.register_float_function(current_module, "apex_amp_sru_compute_gpu_fp32")  # Will cast input arguments to FP32
except ImportError:
    APEX_AMP_AVAILABLE = False


@torch.jit.script
def elementwise_recurrence_cpu(U: Tensor,
                               x: Tensor,
@@ -94,21 +125,22 @@ def elementwise_recurrence_gpu(U: Tensor,
"""Elementwise forward operation of SRU on GPU.
"""
# Imported and cached her for the first time, retrieved from cache in _apex_amp_sru_compute_gpu
from .cuda_functional import SRU_Compute_GPU

cast = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float

U = cast(U)
x = cast(x)
weight_c = cast(weight_c)
bias = cast(bias)
c_init = cast(c_init)
scale_x = cast(scale_x) if scale_x is not None else scale_x
dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c

    in_autocast = getattr(torch, 'is_autocast_enabled', lambda: False)()
    if in_autocast:
        with torch.cuda.amp.autocast(enabled=False):
            cast = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float

            U = cast(U)
            x = cast(x)
            weight_c = cast(weight_c)
            bias = cast(bias)
            c_init = cast(c_init)
            scale_x = cast(scale_x) if scale_x is not None else scale_x
            dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c

            return SRU_Compute_GPU.apply(
                U,
                x,
@@ -123,6 +155,24 @@ def elementwise_recurrence_gpu(U: Tensor,
                dropout_mask_c,
                mask_pad
            )
    elif APEX_AMP_AVAILABLE:
        apex_amp_sru_compute_gpu = \
            apex_amp_sru_compute_gpu_fp16 if amp_recurrence_fp16 else apex_amp_sru_compute_gpu_fp32

        return apex_amp_sru_compute_gpu(
            U,
            x,
            weight_c,
            bias,
            c_init,
            activation_type,
            hidden_size,
            bidirectional,
            has_skip_term,
            scale_x,
            dropout_mask_c,
            mask_pad
        )
    else:
        return SRU_Compute_GPU.apply(
            U,

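A rough end-to-end usage sketch of the new dispatch path: when native autocast is not active but APEX AMP has been initialized, elementwise_recurrence_gpu routes the call through the registered apex_amp_sru_compute_gpu_fp16 or apex_amp_sru_compute_gpu_fp32 wrapper selected by amp_recurrence_fp16. The example assumes the SRU constructor exposes the amp_recurrence_fp16 flag (as in this repository) and that apex is installed; the sizes, learning rate, and opt_level are illustrative.

import torch
from apex import amp
from sru import SRU

# SRU expects input of shape (length, batch, input_size)
x = torch.randn(50, 32, 128, device="cuda")

model = SRU(input_size=128, hidden_size=128, num_layers=2,
            amp_recurrence_fp16=True).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# opt_level O1 patches the functions registered in sru.ops,
# so the FP16 wrapper handles the recurrence kernel inputs
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

output, states = model(x)  # output: (length, batch, hidden_size)
loss = output.float().mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()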