Skip to content

Commit

Permalink
multi-head self attention
Browse files Browse the repository at this point in the history
  • Loading branch information
undertherain committed Sep 16, 2020
1 parent 2beaa06 commit 3261bdf
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 0 deletions.
Empty file.
5 changes: 5 additions & 0 deletions benchmarker/modules/problems/attention/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from benchmarker.modules.problems.images_randomized import gen_data


def get_data(params):
return gen_data(params)
8 changes: 8 additions & 0 deletions benchmarker/modules/problems/attention/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import argparse


def set_extra_params(params, unparsed_args):
parser = argparse.ArgumentParser(description='Benchmark kernel')
parser.add_argument('--cnt_heads', type=int, default=8)
args = parser.parse_args(unparsed_args)
params["problem"].update(vars(args))
21 changes: 21 additions & 0 deletions benchmarker/modules/problems/attention/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch.nn as nn


class Net(nn.MultiheadAttention):
def forward(self, data):
super().forward(data, data, data)


def get_kernel(params):
assert params["mode"] == "inference"
# expected sizes: cnt_itmes, len_seq, dims
net = Net(embed_dim=params["problem"]["size"][2],
num_heads=params["problem"]["cnt_heads"],
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None)

return net
20 changes: 20 additions & 0 deletions test/pytorch/test_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import logging
import unittest

from benchmarker.benchmarker import run

logging.basicConfig(level=logging.DEBUG)


class PytorchLstmTest(unittest.TestCase):
def test_attention(self):
args = [
"--framework=pytorch",
"--problem=attention",
"--problem_size=2,2,4",
"--cnt_heads=2",
"--batch_size=1",
"--nb_epoch=1",
"--mode=inference",
]
run(args)

0 comments on commit 3261bdf

Please sign in to comment.