-
Notifications
You must be signed in to change notification settings - Fork 84
/
Bert.h
78 lines (62 loc) · 2.28 KB
/
Bert.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#ifndef CUBERT_BERT_H
#define CUBERT_BERT_H
#include <string>
#include <unordered_map>
#include "cuBERT.h"
#include "cuBERT/op_bert/BertEmbeddings.h"
#include "cuBERT/op_att/Transformer.h"
#include "cuBERT/op_bert/BertPooler.h"
#include "cuBERT/op_out/AdditionalOutputLayer.h"
namespace cuBERT {
/// Top-level model object that wires together the embedding layer, the
/// transformer stack, the pooler and an additional classifier output layer,
/// together with the input and output buffers they operate on.
///
/// @tparam T element type of the model variables and outputs. The
///           output_to_float() note ("convert half to float if necessary")
///           suggests T is float or a half type — confirm.
///
/// NOTE(review): the members below are raw pointers (sub-layers and buffers)
/// that ~Bert() presumably releases. A member-wise copy would share them and
/// release them twice, so copying is disabled.
template <typename T>
class Bert {
public:
    /// Constructs the model from pre-loaded variables.
    ///
    /// @param var                 variable pointers keyed by name; presumably
    ///                            checkpoint variable names -> weight data
    ///                            (TODO confirm against the loader).
    /// @param max_batch_size      presumably the upper bound on the batch_size
    ///                            passed to the methods below — confirm.
    /// @param seq_length          sequence length.
    /// @param vocab_size          vocabulary size.
    /// @param type_vocab_size     number of segment / token-type ids.
    /// @param hidden_size         hidden size (default 768).
    /// @param num_hidden_layers   number of hidden layers (default 12).
    /// @param num_attention_heads number of attention heads (default 12).
    /// @param intermediate_size   intermediate size (default 3072).
    /// @param num_labels          number of classifier labels (default 1).
    explicit Bert(const std::unordered_map<std::string, T *> &var,
                  size_t max_batch_size,
                  size_t seq_length,
                  size_t vocab_size,
                  size_t type_vocab_size,
                  size_t hidden_size = 768,
                  size_t num_hidden_layers = 12,
                  size_t num_attention_heads = 12,
                  size_t intermediate_size = 3072,
                  size_t num_labels = 1);

    virtual ~Bert();

    // Non-copyable: copies would alias the pointer members declared below and
    // the destructor would then release the same resources twice.
    Bert(const Bert &) = delete;
    Bert &operator=(const Bert &) = delete;

    /// Pre-computes buffers for the given batch size (see buffer_filled).
    void _pre_compute(size_t batch_size);

    /// Runs the model on a batch of inputs.
    /// NOTE(review): each input array presumably holds batch_size * seq_length
    /// elements — confirm against the implementation.
    void compute(size_t batch_size, int *input_ids, int8_t *input_mask, int8_t *segment_ids);

    // Output methods: write results into caller-provided buffers, which may be
    // CPU or GPU memory ("cpu/gpu outputs").
    void logits(size_t batch_size, T *logits, T *probs);
    void pooled_output(size_t batch_size, T *pooled_output);
    void sequence_output(size_t batch_size, T *sequence_output);
    void embedding_output(size_t batch_size, T *embedding_output);
    void output(size_t batch_size, cuBERT_Output* output);

    /// Like output(), but the output is always float: half values are
    /// converted to float if necessary.
    void output_to_float(size_t batch_size, cuBERT_Output* output);

private:
    // Type-erased handles; presumably a cuBLAS handle and a CUDA stream kept
    // as void* so this header need not include CUDA headers — confirm.
    void* cublas;
    void* stream;

    size_t max_batch_size;
    size_t seq_length;
    size_t hidden_size;
    size_t num_labels;

    // Sub-layers of the model.
    BertEmbeddings<T> *bert_embeddings;
    Transformer<T> *transformer;
    Pooler<T> *bert_pooler;
    ClassifierOutputLayer<T> *additional_output_layer;

    // input buffer
    int *input_ids_buf;
    int8_t *input_mask_buf;
    int8_t *segment_ids_buf;

    // cpu/gpu output buffers
    T *_embedding_output;
    T *_sequence_output;
    T *_pooled_output;
    T *_logits;
    T *_probs;

    // for pre-compute
    // FIXME: _sequence_output will be flushed
    bool buffer_filled;
};
}
#endif //CUBERT_BERT_H