-
Notifications
You must be signed in to change notification settings - Fork 348
/
Copy pathtest_chronos.py
388 lines (319 loc) · 12.5 KB
/
test_chronos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import pytest
import torch
from chronos import (
BaseChronosPipeline,
ChronosConfig,
ChronosPipeline,
MeanScaleUniformBins,
)
from test.util import validate_tensor
def test_base_chronos_pipeline_loads_from_huggingface():
    """Smoke test: ``from_pretrained`` must accept the tiny T5 identifier.

    The test passes as long as the call below does not raise; nothing is
    asserted about the returned pipeline.
    """
    model_identifier = "amazon/chronos-t5-tiny"
    BaseChronosPipeline.from_pretrained(model_identifier, device_map="cpu")
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
    """Round-trip the tokenizer's own bin centers through encode and decode.

    With the scale pinned to one, feeding ``tokenizer.centers`` through
    ``_input_transform`` and then ``output_transform`` is expected to give
    back exactly the same values.
    """
    vocab_size = n_numerical_tokens + n_special_tokens
    config_kwargs = dict(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
        n_tokens=vocab_size,
        n_special_tokens=n_special_tokens,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=True,
        model_type="seq2seq",
        context_length=512,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
    tokenizer = ChronosConfig(**config_kwargs).create_tokenizer()
    assert isinstance(tokenizer, MeanScaleUniformBins)

    # A batch of one series made of the bin centers themselves.
    centers_batch = tokenizer.centers.unsqueeze(0)
    # A scale of one switches the rescaling off.
    unit_scale = torch.ones((1,))

    token_ids, _, _ = tokenizer._input_transform(centers_batch, scale=unit_scale)

    # output_transform is fed a sample dimension between batch and time.
    ids_with_sample_dim = token_ids.unsqueeze(1)
    decoded = tokenizer.output_transform(ids_with_sample_dim, scale=unit_scale)

    assert (decoded[0, 0, :] == centers_batch).all()
@pytest.mark.xfail
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_fixed_data(
    n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool
):
    """Tokenize two hand-written series and compare against fixed token ids.

    Marked ``xfail``: this test is currently expected to fail.

    NOTE(review): ``output_transform`` is called here with
    ``tokenizer_state=`` while the other tests in this file pass the scale as
    ``scale=`` or positionally -- confirm whether this keyword is stale.
    """
    vocab_size = n_numerical_tokens + n_special_tokens
    expected_length = 3

    tokenizer = ChronosConfig(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
        n_tokens=vocab_size,
        n_special_tokens=n_special_tokens,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=use_eos_token,
        model_type="seq2seq",
        context_length=512,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    ).create_tokenizer()

    context = torch.tensor([[-3.7, 3.7], [-42.0, 42.0]])
    n_series, _ = context.shape

    token_ids, _attention_mask, scale = tokenizer.context_input_transform(context)

    assert token_ids.shape == (n_series, expected_length + int(use_eos_token))

    # Expected id per column: 0 (the configured pad_token_id value), the
    # first id after the special tokens, the last id of the vocabulary, and,
    # when enabled, 1 (the configured eos_token_id value).
    expected_ids = [0, n_special_tokens, vocab_size - 1]
    if use_eos_token:
        expected_ids.append(1)
    for column, expected_id in enumerate(expected_ids):
        assert all(
            token_ids[:, column] == torch.tensor([expected_id]).repeat(n_series)
        )

    # Decode every non-special token id once for each series.
    all_value_ids = torch.arange(n_special_tokens, vocab_size).unsqueeze(0)
    samples = tokenizer.output_transform(
        all_value_ids.repeat(n_series, 1, 1),
        tokenizer_state=scale,
    )

    assert (samples[:, 0, [0, -1]] == context).all()
@pytest.mark.xfail
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_random_data(use_eos_token: bool):
    """Check output shapes of the tokenizer on NaN-containing context.

    Marked ``xfail``: this test is currently expected to fail. Only shapes are
    asserted, never values.
    """
    context_length = 8
    n_tokens = 256
    n_special_tokens = 2

    tokenizer = ChronosConfig(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
        n_tokens=n_tokens,
        n_special_tokens=n_special_tokens,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=use_eos_token,
        model_type="seq2seq",
        context_length=context_length,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    ).create_tokenizer()

    nan = torch.nan
    context = torch.tensor(
        [
            [nan, nan, 1.0, 1.1, nan, 2.0],
            [3.0, nan, 3.9, 4.0, 4.1, 4.9],
        ]
    )

    token_ids, attention_mask, scale = tokenizer.context_input_transform(context)

    # Leading (batch) dimensions of the input, then context_length plus one
    # extra position when the EOS token is enabled.
    expected_shape = (*context.shape[:-1], context_length + int(use_eos_token))
    assert token_ids.shape == expected_shape
    assert attention_mask.shape == expected_shape
    assert scale.shape == context.shape[:1]

    # Random non-special ids, with the two extreme ids pinned at the corners.
    sample_shape = (2, 10, 4)
    sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=sample_shape)
    sample_ids[0, 0, 0] = n_special_tokens
    sample_ids[-1, -1, -1] = n_tokens - 1

    samples = tokenizer.output_transform(sample_ids, scale)

    assert samples.shape == sample_shape
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
    """Exercise ``ChronosPipeline.predict`` with three input forms.

    For a 2-D tensor, a list of 1-D tensors and a single 1-D tensor, the test
    checks the shape and dtype of the returned samples for a short horizon,
    that a horizon of 65 raises ``ValueError`` when
    ``limit_prediction_length=True``, and that the same horizon succeeds
    otherwise.
    """
    pipeline = ChronosPipeline.from_pretrained(
        Path(__file__).parent / "dummy-chronos-model",
        device_map="cpu",
        torch_dtype=model_dtype,
    )
    context = (10 * torch.rand(size=(4, 16)) + 10).to(dtype=input_dtype)

    def check_predictions(build_input, batch_size, long_horizon_kwargs):
        # ``build_input`` is called afresh for every predict call so that each
        # call receives a newly built input object.
        short = pipeline.predict(build_input(), num_samples=12, prediction_length=3)
        validate_tensor(short, shape=(batch_size, 12, 3), dtype=torch.float32)

        with pytest.raises(ValueError):
            pipeline.predict(
                build_input(),
                num_samples=7,
                prediction_length=65,
                limit_prediction_length=True,
            )

        long = pipeline.predict(
            build_input(),
            num_samples=7,
            prediction_length=65,
            **long_horizon_kwargs,
        )
        validate_tensor(long, shape=(batch_size, 7, 65), dtype=torch.float32)

    # input: tensor of shape (batch_size, context_length)
    check_predictions(lambda: context, 4, {"limit_prediction_length": False})

    # input: batch_size-long list of tensors of shape (context_length,)
    check_predictions(lambda: list(context), 4, {"limit_prediction_length": False})

    # input: tensor of shape (context_length,); the long-horizon call leaves
    # ``limit_prediction_length`` unspecified here.
    check_predictions(lambda: context[0, ...], 1, {})
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
@pytest.mark.parametrize("prediction_length", [3, 65])
@pytest.mark.parametrize(
    "quantile_levels", [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], [0.1, 0.5, 0.9]]
)
def test_pipeline_predict_quantiles(
    model_dtype: torch.dtype,
    input_dtype: torch.dtype,
    prediction_length: int,
    quantile_levels: list[float],
):
    """Check shapes and dtypes returned by ``ChronosPipeline.predict_quantiles``.

    Args:
        model_dtype: dtype passed as ``torch_dtype`` when loading the pipeline.
        input_dtype: dtype the random context is cast to before prediction.
        prediction_length: forecast horizon requested from the pipeline.
        quantile_levels: quantile levels to request; one output slice is
            expected per level.

    The same assertions are run for three input forms: a 2-D tensor, a list of
    1-D tensors, and a single 1-D tensor.
    """
    pipeline = ChronosPipeline.from_pretrained(
        Path(__file__).parent / "dummy-chronos-model",
        device_map="cpu",
        torch_dtype=model_dtype,
    )
    context = 10 * torch.rand(size=(4, 16)) + 10
    context = context.to(dtype=input_dtype)

    num_expected_quantiles = len(quantile_levels)

    # Each entry pairs an input form with the batch size expected in the output.
    inputs_and_batch_sizes = (
        (context, 4),  # tensor of shape (batch_size, context_length)
        (list(context), 4),  # batch_size-long list of (context_length,) tensors
        (context[0, ...], 1),  # tensor of shape (context_length,)
    )

    for model_input, batch_size in inputs_and_batch_sizes:
        quantiles, mean = pipeline.predict_quantiles(
            model_input,
            num_samples=12,
            prediction_length=prediction_length,
            quantile_levels=quantile_levels,
        )
        validate_tensor(
            quantiles,
            (batch_size, prediction_length, num_expected_quantiles),
            dtype=torch.float32,
        )
        validate_tensor(mean, (batch_size, prediction_length), dtype=torch.float32)
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
    """Check shapes and dtypes returned by ``ChronosPipeline.embed``.

    The embedding is expected in the model dtype and the scale in float32,
    for a 2-D tensor, a list of 1-D tensors and a single 1-D tensor.
    """
    pipeline = ChronosPipeline.from_pretrained(
        Path(__file__).parent / "dummy-chronos-model",
        device_map="cpu",
        torch_dtype=model_dtype,
    )
    d_model = pipeline.model.model.config.d_model
    context = (10 * torch.rand(size=(4, 16)) + 10).to(dtype=input_dtype)

    # 16 context steps, plus one position when the EOS token is in use.
    expected_embed_length = 16
    if pipeline.model.config.use_eos_token:
        expected_embed_length += 1

    def check_embedding(model_input, batch_size):
        embedding, scale = pipeline.embed(model_input)
        validate_tensor(
            embedding,
            shape=(batch_size, expected_embed_length, d_model),
            dtype=model_dtype,
        )
        validate_tensor(scale, shape=(batch_size,), dtype=torch.float32)

    # input: tensor of shape (batch_size, context_length)
    check_embedding(context, 4)

    # input: batch_size-long list of tensors of shape (context_length,)
    check_embedding(list(context), 4)

    # input: tensor of shape (context_length,)
    check_embedding(context[0, ...], 1)
@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
def test_tokenizer_number_of_buckets(n_tokens):
    """Pin the number of bin centers and boundaries the tokenizer creates."""
    config_kwargs = dict(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
        n_tokens=n_tokens,
        n_special_tokens=2,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=True,
        model_type="seq2seq",
        context_length=512,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
    config = ChronosConfig(**config_kwargs)
    tokenizer = config.create_tokenizer()

    numerical_token_count = config.n_tokens - config.n_special_tokens

    # An early bug left the tokenizer with one bucket too many. That layout is
    # kept on purpose so it stays consistent with the originally trained
    # models; token ids are clipped to at most `n_tokens - 1` instead, which
    # avoids out-of-bounds errors.
    n_centers = len(tokenizer.centers)
    n_boundaries = len(tokenizer.boundaries)
    assert n_centers == (numerical_token_count - 1)
    assert n_boundaries == numerical_token_count
@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
def test_token_clipping(n_tokens):
    """A very large input value must map to the last token id, ``n_tokens - 1``."""
    config_kwargs = dict(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs={"low_limit": -15, "high_limit": 15},
        n_tokens=n_tokens,
        n_special_tokens=2,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=True,
        model_type="seq2seq",
        context_length=512,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
    config = ChronosConfig(**config_kwargs)
    tokenizer = config.create_tokenizer()

    # This value lands in the largest bucket ...
    huge_value = 1e22
    single_value_context = torch.tensor([[huge_value]])
    unit_scale = torch.tensor([1])

    token_ids, _, _ = tokenizer._input_transform(
        context=single_value_context, scale=unit_scale
    )

    # ... and its token id is clipped to n_tokens - 1.
    last_token_id = config.n_tokens - 1
    assert token_ids[0, 0] == last_token_id