Skip to content

Commit

Permalink
Merge pull request microsoft#19 from ROCmSoftwarePlatform/IFU-master-…
Browse files Browse the repository at this point in the history
…2021-06-07

IFU-master-2021-06-07
  • Loading branch information
jithunnair-amd committed Jun 30, 2021
2 parents 5b0fac7 + dad3b5f commit d98da5c
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 82 deletions.
109 changes: 54 additions & 55 deletions deepspeed/ops/sparse_attention/matmul.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _sdd_matmul(a,
'TN': block * pack,
'TMN': block * block * pack * pack,
'BLOCK': block,
'TK': TK,
'TK': 32,
'TYPE': dtype,
'STRIDE_AM': '1' if trans_a else 'lda',
'STRIDE_AK': 'lda' if trans_a else '1',
Expand All @@ -185,11 +185,11 @@ def _sdd_matmul(a,
'TZ': 1,
'NAME': 'sdd_kernel'
}
_sparse_matmul.sdd_cache[key] = triton.kernel(matmul,
defines=defines,
num_warps=[1,
2,
4])
_sparse_matmul.sdd_cache[key] = triton.kernel(
matmul,
defines=defines,
device=torch.device('cuda'),
num_warps=4)
#_sparse_matmul.sdd_cache[key] = triton.kernel(src, defines=defines, num_warps=[1, 2, 4])

kernel = _sparse_matmul.sdd_cache[key]
Expand All @@ -201,31 +201,30 @@ def _sdd_matmul(a,
max_width = 49152
total = 0 if bench else None
for off_width in range(0, width, max_width):
current = kernel(a,
b,
c,
a.stride(2),
b.stride(2),
block,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
b.stride(1),
c.stride(0),
AS2,
AS2,
AS3,
off_width,
lut,
locks,
num_lock,
grid=lambda opt:
[opt.d('TZ'),
min(max_width,
width - off_width),
AS0],
bench=bench)
current = kernel(
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
b.stride(2),
block,
a.stride(0),
b.stride(0),
c.stride(0),
a.stride(1),
b.stride(1),
c.stride(0),
AS2,
AS2,
AS3,
off_width,
lut.data_ptr(),
locks.data_ptr(),
num_lock,
grid=lambda opt: [opt.TZ,
min(max_width,
width - off_width),
AS0])
total = total + current if bench else None
time[0] = total
# save for backward pass
Expand Down Expand Up @@ -361,10 +360,10 @@ def _dds_matmul(a,
# kernel
key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _sparse_matmul.dds_cache:
TM = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
#TM = (64, 128) if dtype == torch.float32 else (64, 128, 256)
TK = 8 if dtype == torch.float32 else 16
defines = {
'TM': TM,
'TM': 128,
'TN': block,
'TK': TK,
'BLOCK': block,
Expand All @@ -380,7 +379,8 @@ def _dds_matmul(a,
}
_sparse_matmul.dds_cache[key] = triton.kernel(matmul,
defines=defines,
num_warps=[4])
device=torch.device('cuda'),
num_warps=4)
#_sparse_matmul.dds_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
kernel = _sparse_matmul.dds_cache[key]
# output
Expand All @@ -390,9 +390,9 @@ def _dds_matmul(a,
CS3 = AS2 if trans_c else BS2
locks = _sparse_matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
time[0] = kernel(a,
b,
c,
time[0] = kernel(a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.stride(2),
block,
c.stride(2),
Expand All @@ -406,14 +406,13 @@ def _dds_matmul(a,
BS2,
0,
0,
lut,
locks,
lut.data_ptr(),
locks.data_ptr(),
num_locks,
grid=lambda opt: [width,
triton.cdiv(AS2,
opt.d('TM')),
AS0],
bench=bench)
opt.TM),
AS0])
return c

@staticmethod
Expand Down Expand Up @@ -446,11 +445,11 @@ def _dsd_matmul(a,
# kernel
key = (block, a.dtype, b.dtype, trans_a, trans_b, trans_c)
if key not in _sparse_matmul.dsd_cache:
TN = [64, 128] if dtype == torch.float32 else [64, 128, 256]
TK = [8] if dtype == torch.float32 else [16]
#TN = (64, 128) if dtype == torch.float32 else (64, 128, 256)
TK = 8 if dtype == torch.float32 else 16
defines = {
'TM': block,
'TN': TN,
'TN': 128,
'TK': TK,
'BLOCK': block,
'TYPE': dtype,
Expand All @@ -465,7 +464,8 @@ def _dsd_matmul(a,
}
_sparse_matmul.dsd_cache[key] = triton.kernel(matmul,
defines=defines,
num_warps=[4])
device=torch.device('cuda'),
num_warps=4)
#_sparse_matmul.dsd_cache[key] = triton.kernel(src, defines=defines, num_warps=[4])
kernel = _sparse_matmul.dsd_cache[key]
# output
Expand All @@ -475,9 +475,9 @@ def _dsd_matmul(a,
CS3 = AS1 if trans_c else BS3
locks = _sparse_matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device)
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
time[0] = kernel(a,
b,
c,
time[0] = kernel(a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
block,
b.stride(2),
c.stride(2),
Expand All @@ -491,14 +491,13 @@ def _dsd_matmul(a,
AS1,
0,
0,
lut,
locks,
lut.data_ptr(),
locks.data_ptr(),
num_locks,
grid=lambda opt: [width,
triton.cdiv(BS3,
opt.d('TN')),
BS0],
bench=bench)
opt.TN),
BS0])
return c

fn = {
Expand Down
29 changes: 20 additions & 9 deletions deepspeed/ops/sparse_attention/softmax.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def make_kernel(cache,
attn_mask_mode)
if key not in cache:
defines = {
'TM': [1],
'TN': [TN],
'TM': 1,
'TN': TN,
'TYPE': dtype,
'BLOCK': block,
'INFINITY': {
Expand All @@ -96,7 +96,10 @@ def make_kernel(cache,
defines['APPLY_ATTN_MASK'] = True
if attn_mask_mode == 'mul':
defines['ATTN_MASK_MUL'] = True
kernel = triton.kernel(src, defines=defines, num_warps=[num_warps])
kernel = triton.kernel(src,
defines=defines,
device=torch.device('cuda'),
num_warps=num_warps)
cache[key] = kernel
return cache[key]

Expand Down Expand Up @@ -162,15 +165,16 @@ def forward(ctx,
kp_mask_mode,
attn_mask_mode)
M = x.shape[0]
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.d('TM')), M]
grid = lambda opt: [triton.cdiv(spdims[0] * spdims[1] * block, opt.TM), M]

# run kernel
time[0] = kernel(x, scale, lut, rpe, key_padding_mask, attn_mask,\
time[0] = kernel(x.data_ptr(), scale, lut.data_ptr(), rpe.data_ptr(), key_padding_mask.data_ptr(), attn_mask.data_ptr(),\
num_blocks, maxlut,\
x.stride(0),\
stride_zrpe, stride_hrpe, stride_srpe,\
stride_zrpe, stride_hrpe,\
stride_srpe,\
stride_zkpm, stride_zattnm,\
grid=grid, bench=bench)
grid=grid)
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
Expand Down Expand Up @@ -209,10 +213,17 @@ def backward(ctx, dx):
M = x.shape[0]
grid = lambda opt: [
triton.cdiv(ctx.spdims[0] * ctx.spdims[1] * ctx.block,
opt.d('TM')),
opt.TM),
M
]
kernel(x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), grid=grid)
kernel(x.data_ptr(),
ctx.scale,
dx.data_ptr(),
lut.data_ptr(),
ctx.maxlut,
x.stride(0),
dx.stride(0),
grid=grid)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None


Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ def _set_batch_related_parameters(self):
#either none of the three parameters are provided or just gradient_accumulation_step is provided
else:
assert False, \
'Either train_batch_size or micro_batch_per_gpu needs to be provided'
'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

def _configure_train_batch_size(self):
self._set_batch_related_parameters()
Expand Down
9 changes: 7 additions & 2 deletions op_builder/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def include_paths(self):
return ['csrc/aio/py_lib', 'csrc/aio/common']

def cxx_args(self):
return [
args = [
'-g',
'-Wall',
'-O0',
Expand All @@ -41,9 +41,14 @@ def cxx_args(self):
'-march=native',
'-fopenmp',
'-laio',
self.simd_width()
]

simd_width = self.simd_width()
if len(simd_width) > 0:
args.append(simd_width)

return args

def extra_ldflags(self):
return ['-laio']

Expand Down
16 changes: 7 additions & 9 deletions op_builder/sparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,20 @@ def cxx_args(self):

def is_compatible(self):
# Check to see if llvm and cmake are installed since they are dependencies
required_commands = ['llvm-config|llvm-config-9', 'cmake']
command_status = list(map(self.command_exists, required_commands))
deps_compatible = all(command_status)
#required_commands = ['llvm-config|llvm-config-9', 'cmake']
#command_status = list(map(self.command_exists, required_commands))
#deps_compatible = all(command_status)

# torch-cpu will not have a cuda version
if torch.version.cuda is None:
cuda_compatible = False
self.warning(f"{self.NAME} cuda is not available from torch")
else:
major, minor = torch.version.cuda.split('.')[:2]
cuda_compatible = int(major) == 10 and int(minor) >= 1
cuda_compatible = (int(major) == 10
and int(minor) >= 1) or (int(major) >= 11)
if not cuda_compatible:
self.warning(
f"{self.NAME} requires CUDA version 10.1+, does not currently support >=11 or <10.1"
)
self.warning(f"{self.NAME} requires CUDA version 10.1+")

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
Expand All @@ -48,5 +47,4 @@ def is_compatible(self):
f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
)

return super().is_compatible(
) and deps_compatible and torch_compatible and cuda_compatible
return super().is_compatible() and torch_compatible and cuda_compatible
2 changes: 1 addition & 1 deletion requirements/requirements-sparse_attn.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1 @@
triton==0.2.3
triton
13 changes: 8 additions & 5 deletions tests/unit/test_sparse_attention.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def sparse_to_dense(w, mask, block, zero=0):

def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype]
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)


Expand Down Expand Up @@ -189,6 +189,7 @@ def run_softmax_reference(x, scale, dx, kp_mask, attn_mask, layout, block):
def run_softmax_sparse(x, scale, dx, kp_mask, attn_mask, layout, block):
from deepspeed.ops.sparse_attention import Softmax
sparse_softmax = Softmax(layout, block, bench=False)

dx = dense_to_sparse(dx, layout, block)
x = dense_to_sparse(x, layout, block)
x.retain_grad()
Expand Down Expand Up @@ -239,13 +240,14 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo

def _skip_on_cuda_compatability():
#pytest.skip("Skip these tests for now until we get our docker image fixed.")
if torch.cuda.get_device_capability()[0] != 7:
pytest.skip("needs compute capability 7; v100")
if torch.cuda.get_device_capability()[0] >= 7:
pytest.skip("needs higher compute capability than 7")
cuda_major = int(torch.version.cuda.split('.')[0]) * 10
cuda_minor = int(torch.version.cuda.split('.')[1])
cuda_version = cuda_major + cuda_minor
if cuda_version != 101 and cuda_version != 102:
pytest.skip("requires cuda 10.1 or 10.2")
if (cuda_version != 101 and cuda_version != 102) and \
(cuda_version != 111 and cuda_version != 110):
pytest.skip("requires cuda 10.1 or 10.2 or 11.0 or 11.1")


@pytest.mark.parametrize("block", [16, 32])
Expand All @@ -261,6 +263,7 @@ def test_softmax(block, width, dtype):
layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask = init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, layout=None)
ref_y, ref_dx = run_softmax_reference(x, scale, dx, kp_mask, bool_attn_mask, layout, block)
st_y, st_dx = run_softmax_sparse(x, scale, dx, kp_mask, fp_attn_mask, layout, block)

assert allclose(ref_y, st_y)
assert allclose(ref_dx, st_dx)

Expand Down

0 comments on commit d98da5c

Please sign in to comment.