import random
from typing import Any, Sequence, Tuple


def ids_tensor(shape, vocab_size, rng=None, name=None):
    # Creates a random int64 tensor of the given shape with values in
    # [0, vocab_size). The `name` argument is unused and kept only for
    # API compatibility.
    import torch

    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
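
# Illustrative usage (a sketch, not from the original file):
#
#     ids = ids_tensor([2, 3], vocab_size=10)
#     assert ids.shape == (2, 3) and ids.dtype == torch.int64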


def _prepare_config_and_inputs(
    batch_size: int,
    seq_length: int,
    vocab_size: int,
    type_sequence_label_size: int = 2,
    type_vocab_size: int = 16,
    num_labels: int = 3,
    num_choices: int = 4,
    use_input_mask: bool = False,
    use_token_type_ids: bool = False,
    use_labels: bool = False,
) -> Tuple[Any, ...]:
    import torch

    input_ids = ids_tensor([batch_size, seq_length], vocab_size)

    input_mask = None
    if use_input_mask:
        input_mask = torch.tril(torch.ones(batch_size, seq_length))

    token_type_ids = None
    if use_token_type_ids:
        assert type_vocab_size > 0, "type_vocab_size must be positive"
        token_type_ids = ids_tensor([batch_size, seq_length], type_vocab_size)

    sequence_labels = None
    token_labels = None
    choice_labels = None
    if use_labels:
        assert type_sequence_label_size > 0, "type_sequence_label_size must be positive"
        assert num_labels > 0, "num_labels must be positive"
        assert num_choices > 0, "num_choices must be positive"
        sequence_labels = ids_tensor([batch_size], type_sequence_label_size)
        token_labels = ids_tensor([batch_size, seq_length], num_labels)
        choice_labels = ids_tensor([batch_size], num_choices)

    return (
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    )
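
# Illustrative usage (a sketch, not from the original file): with
# use_input_mask=True only input_ids and input_mask are populated; the
# remaining entries stay None unless the matching use_* flag is set.
#
#     input_ids, _, input_mask, _, _, _ = _prepare_config_and_inputs(
#         batch_size=2, seq_length=8, vocab_size=99, use_input_mask=True
#     )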


def get_mistral_model(
    input_dims: Sequence[Tuple[int, int]] = ((13, 7), (14, 7), (15, 8)),
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=99,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
    sliding_window=4096,
    _attn_implementation="eager",  # needed to avoid graph breaks
    with_mask: bool = True,
):
    """
    Returns a wrapped Mistral model together with a list of example input
    tuples, one per ``(batch_size, seq_length)`` pair in ``input_dims``.
    See `MistralConfig
    <https://huggingface.co/docs/transformers/main/en/model_doc/mistral#transformers.MistralConfig>`_.
    The parameters are chosen for a unit test configuration.
    """
    import torch
    from transformers import MistralConfig
    from transformers.models.mistral.modeling_mistral import MistralModel

    config = MistralConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        sliding_window=sliding_window,
    )
    if _attn_implementation:
        config._attn_implementation = _attn_implementation

    def generate_example_inputs(batch: int, seq: int, vocab_size: int, with_mask: bool):
        (
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = _prepare_config_and_inputs(
            batch_size=batch,
            seq_length=seq,
            vocab_size=vocab_size,
            use_input_mask=with_mask,
        )
        if with_mask:
            return input_ids, input_mask
        return (input_ids,)

    if with_mask:

        class MistralModelWrapper(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.model = MistralModel(config)

            def forward(self, input_ids, attention_mask):
                model_output = self.model(input_ids, attention_mask=attention_mask)
                return model_output.to_tuple()

    else:

        class MistralModelWrapper(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.model = MistralModel(config)

            def forward(self, input_ids):
                model_output = self.model(input_ids)
                return model_output.to_tuple()

    example_args_collection = []
    for b, s in input_dims:
        example_args_collection.append(
            generate_example_inputs(b, s, vocab_size, with_mask)
        )

    return MistralModelWrapper(config), example_args_collection
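

# Minimal usage sketch (illustrative, not part of the original helper).
# Assumes torch and transformers are installed; the input_dims below are
# hypothetical small sizes chosen for a quick smoke test.
if __name__ == "__main__":
    model, example_args_collection = get_mistral_model(
        input_dims=((2, 8),),
        with_mask=True,
    )
    for args in example_args_collection:
        outputs = model(*args)  # tuple from ModelOutput.to_tuple()
        print(len(outputs), outputs[0].shape)  # last_hidden_state: (2, 8, 32)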