-
Notifications
You must be signed in to change notification settings - Fork 246
/
Copy pathrun.cpp
949 lines (835 loc) · 30.6 KB
/
run.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
/*
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
*/
/* Inference for Llama-2 Transformer model in pure C++ */
#include "sentencepiece.h"
#include "tiktoken.h"
#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstdlib>
#include <ctype.h>
#include <iterator>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <time.h>
#ifdef DEBUG
#include <cassert>
#include <iostream>
#endif
#if defined(__AOTI_MODEL__) || (defined(__ET_MODEL__) && defined(USE_ATENLIB))
#include <torch/torch.h>
#endif
#ifdef __AOTI_MODEL__
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
torch::Device aoti_device(torch::kCPU);
#else // __ET_MODEL__
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor_ptr.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#if defined(ET_USE_ADAPTIVE_THREADS)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif
using exec_aten::ScalarType;
using executorch::extension::make_tensor_ptr;
using executorch::extension::TensorPtr;
using torch::executor::EValue;
using torch::executor::Module;
using torch::executor::Result;
using executorch::runtime::MemoryManager;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Error;
#endif
using tokenizers::SPTokenizer;
using tokenizers::Tiktoken;
using tokenizers::Tokenizer;
// UNWRAP(x): evaluates `x` exactly once. `x` must yield a Result-like object
// exposing ok(), error() and get().
// - On failure: prints the numeric error code to stderr and exits with
//   EXIT_FAILURE.
// - On success: yields the contained value, moved out of the result.
// Binding to a local first matters. The previous version expanded `x` at every
// use, so UNWRAP(tokenizer->decode(...)) ran the tokenizer twice.
// Relies on a GNU statement expression (GCC/Clang).
#define UNWRAP(x)                                                              \
  ({                                                                           \
    auto &&unwrap_result_ = (x);                                               \
    if (!unwrap_result_.ok()) {                                                \
      fprintf(stderr, "Got error code %" PRIu32 "\n",                          \
              static_cast<uint32_t>(unwrap_result_.error()));                  \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
    std::move(unwrap_result_.get());                                           \
  })
// ----------------------------------------------------------------------------
// Transformer model
// Model families this runner knows how to tokenize and prompt. The numeric
// values are the integers reported by the model's "tokenizer_type" metadata
// and are translated by get_model_type() below.
enum ModelType {
  UNKNOWN_MODEL = 0,
  LLAMA2_MODEL = 2,
  LLAMA3_MODEL = 3,
};

// Translates the integer tokenizer type into a ModelType. 2 and 3 map to the
// Llama 2 / Llama 3 families; every other value yields UNKNOWN_MODEL.
ModelType get_model_type(int model_int) {
  if (model_int == LLAMA2_MODEL) {
    return LLAMA2_MODEL;
  }
  if (model_int == LLAMA3_MODEL) {
    return LLAMA3_MODEL;
  }
  return UNKNOWN_MODEL;
}
// Sizes that drive buffer allocation for one run of the model.
struct Config {
  int vocab_size; // number of float logits copied out of each forward() call
  int seq_len;    // number of token slots in RunState::toks (max positions)
};
// Host-side buffers that forward() writes on every step.
struct RunState {
  float *logits; // latest output logits (Config::vocab_size floats)
  int64_t *toks; // token fed at each position so far; no kv-cache here
};
// Everything needed to run the exported model: its sizes, the host buffers
// forward() fills, a metadata map, and the backend runner object. The runner
// is created by build_transformer() and deleted by free_transformer().
struct Transformer {
  Config config;  // the hyperparameters of the architecture (the blueprint)
  RunState state; // buffers for the "wave" of activations in the forward pass
  // NOTE(review): nothing visible in this file reads or writes this map.
  std::unordered_map<std::string, std::string> metadata;
#ifdef __AOTI_MODEL__
  torch::inductor::AOTIModelPackageLoader *runner; // AOT Inductor package
#else // __ET_MODEL__
  Module *runner; // ExecuTorch module
#endif
};
// Allocates the RunState buffers sized from `p`:
//   s->logits  p->vocab_size floats (forward() copies logits here)
//   s->toks    p->seq_len int64 token slots (forward() indexes it by position)
// Both buffers are zero-initialized. Exits with EXIT_FAILURE if a non-empty
// allocation fails.
void malloc_run_state(RunState *s, Config *p) {
  // we calloc instead of malloc to keep valgrind happy
  s->logits = (float *)calloc(p->vocab_size, sizeof(float));
  s->toks = (int64_t *)calloc(p->seq_len, sizeof(int64_t));
  // calloc(0, ...) may legitimately return NULL (e.g. `-n 0` makes seq_len 0),
  // so NULL only signals failure when a non-empty buffer was requested.
  const bool logits_failed = !s->logits && p->vocab_size != 0;
  const bool toks_failed = !s->toks && p->seq_len != 0;
  if (logits_failed || toks_failed) {
    fprintf(stderr, "malloc failed!\n");
    exit(EXIT_FAILURE);
  }
}
// Releases the buffers allocated by malloc_run_state(). The pointers are left
// dangling; free(NULL) is a harmless no-op.
void free_run_state(RunState *s) {
  float *logits_buf = s->logits;
  int64_t *token_buf = s->toks;
  free(logits_buf);
  free(token_buf);
}
// Reads a Config header from the start of the binary file `checkpoint` into
// `config`. A negative vocab_size in the header flags unshared weights. That
// flag is not needed here, so only the magnitude is kept.
// Exits with EXIT_FAILURE if the file cannot be opened or the header cannot
// be read.
// NOTE(review): nothing in this file calls this function (main only mentions
// it in a comment).
void read_checkpoint(char *checkpoint, Config *config) {
  FILE *file = fopen(checkpoint, "rb");
  if (!file) {
    fprintf(stderr, "Couldn't open file %s\n", checkpoint);
    exit(EXIT_FAILURE);
  }
  // read in the config header; only the header is needed, so close the file
  // right away instead of leaking the handle.
  const size_t items_read = fread(config, sizeof(Config), 1, file);
  fclose(file);
  if (items_read != 1) {
    fprintf(stderr, "Couldn't read config header from %s\n", checkpoint);
    exit(EXIT_FAILURE);
  }
  // negative vocab size is hacky way of signaling unshared weights. bit yikes.
  config->vocab_size = abs(config->vocab_size);
}
// Creates the backend runner for the model at `model_path` and stores it in
// t->runner. The Transformer owns it; free_transformer() deletes it. Only
// t->runner is set here. Config and RunState are filled in by the caller.
void build_transformer(Transformer *t, char *model_path) {
#ifdef __AOTI_MODEL__
  auto *loader = new torch::inductor::AOTIModelPackageLoader(model_path);
  t->runner = loader;
#else //__ET_MODEL__
  // Going by the LoadMode name: mmap the .pte and try mlock, ignoring mlock
  // errors. TODO confirm against the ExecuTorch Module docs.
  auto *module =
      new Module(model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
  t->runner = module;
#endif
}
// Releases what was allocated for `t`: the RunState host buffers (from
// malloc_run_state()) and then the backend runner (from build_transformer()).
void free_transformer(Transformer *t) {
  RunState *buffers = &t->state;
  free_run_state(buffers);
  auto *backend = t->runner;
  delete backend;
}
// ----------------------------------------------------------------------------
// neural net blocks; the dynamics of the Transformer
// In-place softmax over x[0..size). The maximum is subtracted before
// exponentiating so expf() cannot overflow. `size` must be >= 1.
void softmax(float *x, int size) {
  float peak = x[0];
  for (int k = 1; k < size; ++k) {
    peak = (x[k] > peak) ? x[k] : peak;
  }
  float total = 0.0f;
  for (int k = 0; k < size; ++k) {
    const float e = expf(x[k] - peak);
    x[k] = e;
    total += e;
  }
  for (float *it = x; it != x + size; ++it) {
    *it /= total;
  }
}
// Runs the exported model on a single `token` at position `pos` and returns
// transformer->state.logits. That buffer now holds config.vocab_size float
// logits and is overwritten on every call.
// The token is also recorded in state.toks[pos], so `pos` must be less than
// config.seq_len.
// On the ExecuTorch path, a failed forward() exits with EXIT_FAILURE.
float *forward(Transformer *transformer, int token, int pos) {
  Config *p = &transformer->config;
  RunState *s = &transformer->state;
  s->toks[pos] = token;
  // Both backends are told these are 64-bit integer tensors (torch::kLong /
  // ScalarType::Long), so the storage must be int64_t. `long` is only 32 bits
  // on LLP64 platforms such as Windows, where the runtime would read past the
  // end of each one-element buffer.
  int64_t token_buffer[1] = {token};
  int64_t pos_buffer[1] = {pos};
#ifdef DEBUG
  std::cerr << "token: " << token << " pos: " << pos << "\n";
#endif
#ifdef __AOTI_MODEL__
  torch::Tensor token_tensor =
      torch::from_blob(token_buffer, {1, 1}, torch::kLong);
  torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
  std::vector<torch::Tensor> inputs{token_tensor.to(aoti_device),
                                    pos_tensor.to(aoti_device)};
  // Bring the first output back to the host as float32 before copying it.
  torch::Tensor result = transformer->runner->run(inputs)[0]
                             .to(torch::dtype(torch::kFloat32))
                             .to(torch::kCPU);
  auto logits = result[0].data_ptr();
  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
#else // __ET_MODEL__
  TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
  TensorPtr tokens_managed =
      make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
  std::vector<EValue> inputs;
  auto tmp1 = EValue(tokens_managed);
  auto tmp2 = EValue(pos_managed);
  inputs.push_back(tmp1);
  inputs.push_back(tmp2);
  Result<std::vector<EValue>> outputs_res =
      transformer->runner->forward(inputs);
  if (!outputs_res.ok()) {
    fprintf(stderr, "Executorch forward() failed with error 0x%" PRIx32 ".\n",
            static_cast<uint32_t>(outputs_res.error()));
    exit(EXIT_FAILURE);
  }
  std::vector<EValue> result = outputs_res.get();
  // HACK: the rest of this runner assumes that logits must be float,
  // so we simply convert them rather than plumbing
  // templating/switch-on-type through the rest of this file.
  const auto &result_tensor = result[0].toTensor();
  ET_SWITCH_REALHBBF16_TYPES(
      result_tensor.scalar_type(), unused, "forward", CTYPE, [&]() {
        const CTYPE *logits = result_tensor.const_data_ptr<CTYPE>();
        std::transform(logits, logits + p->vocab_size, s->logits,
                       [](auto x) { return static_cast<float>(x); });
      });
#endif
  return s->logits;
}
// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
// A (probability, token id) pair. sample_topp() sorts an array of these by
// descending probability.
struct ProbIndex {
  float prob; // probability of the token
  int index;  // token id, i.e. its position in the probability vector
};
// Sampling configuration, plus the scratch buffer and RNG state that
// sampling mutates.
struct Sampler {
  int vocab_size;               // length of the logit/probability vectors
  ProbIndex *probindex;         // buffer used in top-p sampling
  float temperature;            // 0.0f selects greedy argmax in sample()
  float topp;                   // top-p is used only when 0 < topp < 1
  unsigned long long rng_state; // xorshift state advanced by random_u32()
};
// Greedy decoding: returns the index of the largest of probabilities[0..n).
// Ties go to the lowest index, because only a strictly larger value replaces
// the current best.
int sample_argmax(float *probabilities, int n) {
  int best = 0;
  for (int i = 1; i < n; ++i) {
    if (probabilities[best] < probabilities[i]) {
      best = i;
    }
  }
  return best;
}
// Draws an index from a discrete distribution by inverse-CDF lookup.
// `probabilities` must sum to 1. `coin` is uniform in [0, 1), usually from
// random_f32().
// Returns the first index whose running cumulative sum exceeds `coin`. If
// rounding keeps the total at or below `coin`, returns n - 1.
int sample_mult(float *probabilities, int n, float coin) {
  float running = 0.0f;
  int idx = 0;
  while (idx < n) {
    running += probabilities[idx];
    if (running > coin) {
      return idx;
    }
    ++idx;
  }
  return n - 1; // in case of rounding errors
}
int compare(const void *a, const void *b) {
ProbIndex *a_ = (ProbIndex *)a;
ProbIndex *b_ = (ProbIndex *)b;
if (a_->prob > b_->prob)
return -1;
if (a_->prob < b_->prob)
return 1;
return 0;
}
// Top-p ("nucleus") sampling: samples from the smallest set of tokens whose
// cumulative probability exceeds `topp`. Tokens with very low probability are
// never chosen, so generation is less likely to go "off the rails".
//   probabilities  n probabilities that sum to 1 (e.g. softmax output)
//   topp           nucleus threshold; sample() only calls this for 0 < topp < 1
//   probindex      scratch buffer with room for at least n entries
//   coin           uniform random number in [0, 1), usually from random_f32()
// Returns the sampled token index.
int sample_topp(float *probabilities, int n, float topp, ProbIndex *probindex,
                float coin) {
  int n0 = 0;
  // quicksort indices in descending order of probabilities
  // values smaller than (1 - topp) / (n - 1) cannot be part of the result
  // so for efficiency we crop these out as candidates before sorting
  const float cutoff = (1.0f - topp) / (n - 1);
  for (int i = 0; i < n; i++) {
    if (probabilities[i] >= cutoff) {
      probindex[n0].index = i;
      probindex[n0].prob = probabilities[i];
      n0++;
    }
  }
  if (n0 == 0) {
    // Every probability fell below the cutoff. The largest probability is at
    // least 1/n, so this only happens when topp < 1/n, or when n == 1 (the
    // cutoff divides by zero). In both cases the single most likely token
    // already exceeds topp, so it alone is the nucleus.
    // Returning it here also avoids reading probindex[-1] below.
    int best = 0;
    for (int i = 1; i < n; i++) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }
  qsort(probindex, n0, sizeof(ProbIndex), compare);
  // truncate the list where cumulative probability exceeds topp
  float cumulative_prob = 0.0f;
  int last_idx = n0 - 1; // in case of rounding errors consider all elements
  for (int i = 0; i < n0; i++) {
    cumulative_prob += probindex[i].prob;
    if (cumulative_prob > topp) {
      last_idx = i;
      break; // we've exceeded topp by including last_idx
    }
  }
  // sample from the truncated list
  float r = coin * cumulative_prob;
  float cdf = 0.0f;
  for (int i = 0; i <= last_idx; i++) {
    cdf += probindex[i].prob;
    if (r < cdf) {
      return probindex[i].index;
    }
  }
  return probindex[last_idx].index; // in case of rounding errors
}
// Initializes `sampler` with the given sampling settings and RNG seed. It also
// allocates the vocab_size-entry scratch buffer used by top-p sampling;
// free_sampler() releases it.
// Exits with EXIT_FAILURE if that allocation fails.
void build_sampler(Sampler *sampler, int vocab_size, float temperature,
                   float topp, unsigned long long rng_seed) {
  sampler->vocab_size = vocab_size;
  sampler->temperature = temperature;
  sampler->topp = topp;
  sampler->rng_state = rng_seed;
  // buffer only used with nucleus sampling; may not need but it's ~small
  sampler->probindex =
      (ProbIndex *)malloc(sampler->vocab_size * sizeof(ProbIndex));
  // Fail fast, as malloc_run_state() does, instead of letting sample_topp()
  // write through a null pointer later. malloc(0) may return NULL
  // legitimately, so that case is not treated as a failure.
  if (!sampler->probindex && sampler->vocab_size != 0) {
    fprintf(stderr, "malloc failed!\n");
    exit(EXIT_FAILURE);
  }
}
// Releases the top-p scratch buffer allocated by build_sampler().
void free_sampler(Sampler *sampler) {
  ProbIndex *scratch = sampler->probindex;
  free(scratch);
}
unsigned int random_u32(unsigned long long *state) {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
*state ^= *state >> 12;
*state ^= *state << 25;
*state ^= *state >> 27;
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long *state) { // random float32 in [0,1)
return (random_u32(state) >> 8) / 16777216.0f;
}
// Chooses the next token id from `logits` (sampler->vocab_size entries).
// - temperature == 0: greedy argmax; `logits` is left untouched.
// - Otherwise: `logits` is scaled by 1/temperature and overwritten in place
//   with softmax probabilities, and one RNG draw selects the token.
//   topp outside (0, 1) samples from the full distribution; topp inside
//   (0, 1) uses nucleus sampling.
int sample(Sampler *sampler, float *logits) {
  const int n = sampler->vocab_size;
  if (sampler->temperature == 0.0f) {
    // greedy argmax sampling: take the token with the highest probability
    return sample_argmax(logits, n);
  }
  const float temp = sampler->temperature;
  for (int i = 0; i < n; ++i) {
    logits[i] /= temp;
  }
  softmax(logits, n);
  // flip a (float) coin (this is our source of entropy for sampling)
  const float coin = random_f32(&sampler->rng_state);
  const bool full_distribution = sampler->topp <= 0 || sampler->topp >= 1;
  return full_distribution
             ? sample_mult(logits, n, coin)
             : sample_topp(logits, n, sampler->topp, sampler->probindex, coin);
}
// Creates the tokenizer that matches `model_type` and loads it from
// `tokenizer_path`: SentencePiece for LLAMA2_MODEL, Tiktoken for
// LLAMA3_MODEL. The caller owns the result and releases it with
// free_tokenizer().
// Exits with EXIT_FAILURE for other model types or if loading fails.
Tokenizer *build_tokenizer(const char *tokenizer_path, ModelType model_type) {
  Tokenizer *tokenizer = NULL;
  switch (model_type) {
  case LLAMA2_MODEL:
    tokenizer = new SPTokenizer();
    break;
  case LLAMA3_MODEL:
    tokenizer = new Tiktoken();
    break;
  default:
    fprintf(stderr, "No tokenizer defined for model type %d.\n", model_type);
    exit(EXIT_FAILURE);
  }
  // load() reports failure through its return value. Ignoring it, as the code
  // previously did, left a half-initialized tokenizer that failed later and
  // far from the cause. The namespace is spelled out because `Error` may name
  // ExecuTorch's error type in this file.
  const auto load_status = tokenizer->load(tokenizer_path);
  if (load_status != tokenizers::Error::Ok) {
    fprintf(stderr, "Failed to load tokenizer from %s (error %d).\n",
            tokenizer_path, static_cast<int>(load_status));
    delete tokenizer;
    exit(EXIT_FAILURE);
  }
  return tokenizer;
}
// Destroys a tokenizer created by build_tokenizer(); a null pointer is ignored.
void free_tokenizer(Tokenizer *tokenizer) {
  if (tokenizer != nullptr) {
    delete tokenizer;
  }
}
// ----------------------------------------------------------------------------
// utilities: safe printing and time
// Writes a decoded token piece to stdout, skipping output that could be
// harmful to a terminal. A piece may be a raw byte token, and some bytes are
// control codes, backspace, etc.
// Nothing is printed for NULL, for an empty string, or for a single-byte
// piece that is neither printable nor whitespace.
void safe_printf(const char *piece) {
  if (piece == NULL || *piece == '\0') {
    return;
  }
  const bool single_byte = piece[1] == '\0';
  if (single_byte) {
    const unsigned char byte_val = static_cast<unsigned char>(piece[0]);
    const bool visible = isprint(byte_val) || isspace(byte_val);
    if (!visible) {
      return; // bad byte, don't print it
    }
  }
  printf("%s", piece);
}
// Returns a timestamp in milliseconds for benchmarking the model speed.
// CLOCK_MONOTONIC is used so elapsed-time measurements are not skewed by
// wall-clock adjustments (NTP, manual changes) that affect CLOCK_REALTIME.
// The value is only meaningful as a difference between two calls.
long time_in_ms() {
  struct timespec now;
  clock_gettime(CLOCK_MONOTONIC, &now);
  return now.tv_sec * 1000 + now.tv_nsec / 1000000;
}
// ----------------------------------------------------------------------------
// generation loop
// Feeds `prompt_tokens` to the transformer one token per position, starting
// at position `pos`. After the prompt it keeps feeding sampled tokens and
// prints decoded text to stdout as it goes.
//
// Each loop iteration feeds `token` at `pos` and then picks `next`:
// - while the prompt lasts, `next` is the following prompt token;
// - after that, `next` is sampled from the resulting logits.
// What gets printed in an iteration is the piece for `next`
// (tokenizer->decode(token, next)), never `token` itself. The first prompt
// token is therefore not printed; it is assumed to be a BOS or similar token.
//
//   prompt_tokens      tokens fed first; if empty, `pos` is returned unchanged.
//   pos                position of prompt_tokens[0].
//   stop_tokens        the loop ends after one of these is fed as `token` once
//                      the final prompt token has been consumed. Stop tokens
//                      earlier in the prompt do not end it. When `next` is a
//                      stop token, "\n" is printed instead of its piece.
//   stop_pos           last position that may be fed; -1 means no limit.
//   print_prompt       also print the pieces that come from the prompt.
//   print_tok_per_sec  report throughput on stderr when done.
// Returns the position just past the last token fed.
unsigned generate_from_prompt_tokens(Transformer *transformer,
                                     Tokenizer *tokenizer, Sampler *sampler,
                                     const std::vector<uint64_t> &prompt_tokens,
                                     unsigned pos,
                                     const std::vector<uint64_t> &stop_tokens,
                                     int stop_pos, bool print_prompt,
                                     bool print_tok_per_sec) {
  if (prompt_tokens.size() == 0) {
    return pos;
  }
  uint64_t next;                 // will store the next token in the sequence
  uint64_t token;                // stores the current token to feed into the transformer
  bool done_with_prompt;         // whether we are done processing prompt
  bool found_stop_token = false; // whether we've found the stop_token after
                                 // processing prompt_tokens
  unsigned pos_in_prompt = 0;    // position relative to start of prompt
  long start = 0;                // timer start (initialized after first token)
  // If stop_pos == -1, we go until we find stop_token
  // If stop_pos >= 0, we go until we find stop_token or pos exceeds stop_pos.
  // (pos is unsigned; the -1 check short-circuits before the comparison.)
  while (!found_stop_token && (stop_pos == -1 || pos <= stop_pos)) {
    // Get token and next
    if (pos_in_prompt < prompt_tokens.size()) {
      // Token comes from prompt
      token = prompt_tokens[pos_in_prompt++];
      float *logits = forward(transformer, token, pos);
      // Next token is either from prompt or if on last
      // prompt token, next is sampled
      if (pos_in_prompt < prompt_tokens.size()) {
        next = prompt_tokens[pos_in_prompt];
      } else {
        next = sample(sampler, logits);
      }
    } else {
      // Token comes from next sampled from previous round.
      token = next;
      float *logits = forward(transformer, token, pos);
      next = sample(sampler, logits);
    }
    // True from the iteration that fed the last prompt token onward.
    done_with_prompt = (pos_in_prompt >= prompt_tokens.size());
    // we terminate on finding the stop_token if we are done processing the
    // prompt (stop_tokens in the prompt do not terminate the loop)
    if (done_with_prompt && (std::find(stop_tokens.begin(), stop_tokens.end(),
                                       token) != stop_tokens.end())) {
      found_stop_token = true;
    }
    // We print next in each iteration of the loop, not token
    if (!found_stop_token && (print_prompt || done_with_prompt)) {
      // The stop_token is printed as newline
      bool next_is_stop = std::find(stop_tokens.begin(), stop_tokens.end(),
                                    next) != stop_tokens.end();
      if (next_is_stop) {
        printf("\n");
      } else {
        std::string piece = UNWRAP(tokenizer->decode(token, next));
        safe_printf(piece.c_str()); // same as printf("%s", piece), but skips
                                    // "unsafe" bytes
        fflush(stdout);
      }
    }
    // init the timer here because the first iteration can be slower
    // NOTE(review): `start` is only set when the loop passes through pos == 0.
    // A call that begins at pos > 0 with print_tok_per_sec would measure from
    // start == 0 and report a wrong rate. chat() passes false for that flag.
    if (pos == 0) {
      start = time_in_ms();
    }
    pos++;
  }
  // report achieved tok/s (pos-1 because the timer starts after first
  // iteration); the count includes prompt positions.
  if (print_tok_per_sec && pos > 1) {
    long end = time_in_ms();
    fprintf(stderr, "\n\nachieved tok/s: %f\n",
            (pos - 1) / (double)(end - start) * 1000);
  }
  return pos;
}
// Generate mode. Encodes `prompt` (default "Once upon a time") with a leading
// BOS, then prints the prompt and the sampled continuation. Output stops when
// a model-specific stop token is produced or `steps` positions have been fed.
// Reports tok/s on stderr. Does nothing when steps == 0.
// Exits with EXIT_FAILURE for model types other than LLAMA2/LLAMA3.
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
              const char *prompt, int steps, ModelType model_type) {
  if (prompt == NULL) {
    prompt = "Once upon a time";
  }
  if (steps == 0) {
    return;
  }
  std::vector<uint64_t> prompt_tokens;
  std::vector<uint64_t> stop_tokens;
  if (model_type == LLAMA2_MODEL) {
    prompt_tokens = UNWRAP(tokenizer->encode(prompt, 1, 0));
    stop_tokens.push_back(tokenizer->eos_tok());
  } else if (model_type == LLAMA3_MODEL) {
    prompt_tokens = UNWRAP(tokenizer->encode(prompt, 1, 0));
    // Llama 3 output may end with either special token.
    const uint64_t end_of_text =
        UNWRAP(tokenizer->encode("<|end_of_text|>", 0, 0))[0];
    const uint64_t end_of_turn =
        UNWRAP(tokenizer->encode("<|eot_id|>", 0, 0))[0];
    stop_tokens = {end_of_text, end_of_turn};
  } else {
    fprintf(stderr, "Generate does not support model type %d.\n", model_type);
    exit(EXIT_FAILURE);
  }
  generate_from_prompt_tokens(transformer, tokenizer, sampler, prompt_tokens,
                              /*pos=*/0,
                              /*stop_tokens=*/stop_tokens,
                              /*stop_pos=*/steps - 1,
                              /*print_prompt=*/true,
                              /*print_tok_per_sec=*/true);
}
// Prints `guide`, then reads one line from stdin into `buffer` (capacity
// `bufsize`), up to but not including the '\n'.
// - A line longer than bufsize - 1 is truncated; the rest stays in stdin for
//   the next read.
// - On EOF or read error, `buffer` is set to "" so callers never see stale or
//   uninitialized contents. (Previously it was left untouched, and callers
//   pass uninitialized stack arrays.)
void read_stdin(const char *guide, char *buffer, size_t bufsize) {
  printf("%s", guide);
  // The guide has no trailing newline, so flush explicitly to make sure it is
  // visible before we block on input.
  fflush(stdout);
  if (bufsize == 0) {
    return;
  }
  if (fgets(buffer, static_cast<int>(bufsize), stdin) == NULL) {
    buffer[0] = '\0';
    return;
  }
  size_t len = strlen(buffer);
  if (len > 0 && buffer[len - 1] == '\n') {
    buffer[len - 1] = '\0'; // strip newline
  }
}
// ----------------------------------------------------------------------------
// chat loop
// I manually inspected the tokens for a few chat conversations compared to
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.
// Builds the token sequence for the first chat turn.
// - System and user prompts come from the CLI when given. Otherwise they are
//   read from stdin; the system prompt is optional.
// - The prompts are rendered into the model's chat template and encoded.
// - Llama 2 gets a BOS token from the encoder; the Llama 3 template carries
//   <|begin_of_text|> itself.
// Exits with EXIT_FAILURE for unsupported model types.
std::vector<uint64_t> get_initial_prompt_tokens(const char *cli_system_prompt,
                                                const char *cli_user_prompt,
                                                Tokenizer *tokenizer,
                                                ModelType model_type) {
  // Start empty so an unsuccessful stdin read still leaves valid strings.
  char system_prompt[512] = "";
  char user_prompt[512] = "";
  char rendered_prompt[512 * 2 + 200]; // the prompt template is ~170
                                       // characters. We use 200 to be safe.
  // CLI prompts come straight from argv and can be any length. Copy them with
  // a bounded snprintf (truncating to 511 chars) instead of strcpy, which
  // would overflow these fixed-size buffers.
  if (cli_system_prompt != NULL) {
    snprintf(system_prompt, sizeof(system_prompt), "%s", cli_system_prompt);
  } else {
    read_stdin("Enter system prompt (optional): ", system_prompt,
               sizeof(system_prompt));
  }
  if (cli_user_prompt != NULL) {
    snprintf(user_prompt, sizeof(user_prompt), "%s", cli_user_prompt);
  } else {
    read_stdin("User: ", user_prompt, sizeof(user_prompt));
  }
  std::vector<uint64_t> tokens;
  switch (model_type) {
  case LLAMA2_MODEL:
    if (system_prompt[0] != '\0') {
      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
               "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]", system_prompt,
               user_prompt);
    } else {
      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
               "[INST] %s [/INST]", user_prompt);
    }
    // We need to add BOS token here and not in template because llama2
    // tokenizer does not pattern match special tokens
    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 1, 0));
    break;
  case LLAMA3_MODEL:
    if (system_prompt[0] != '\0') {
      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
               "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
               "\n\n%s<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<"
               "|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
               system_prompt, user_prompt);
    } else {
      snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
               "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%"
               "s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
               user_prompt);
    }
    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 0, 0));
    break;
  default:
    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
    exit(EXIT_FAILURE);
  }
#ifdef DEBUG
  std::cerr << "Start of rendered prompt:" << std::endl;
  std::cerr << rendered_prompt;
  std::cerr << "End of rendered prompt:" << std::endl;
  std::cerr << "Encoded prompt: ";
  for (size_t i = 0; i < tokens.size(); i++) {
    std::cerr << tokens[i] << ", ";
  }
  std::cerr << std::endl << std::flush;
#endif
  return tokens;
}
// Reads the next user message from stdin and returns it rendered into the
// model's per-turn chat template and encoded. Llama 2 turns get a BOS token
// from the encoder; Llama 3 turns do not.
// Exits with EXIT_FAILURE for unsupported model types.
std::vector<uint64_t> get_next_user_prompt_tokens(Tokenizer *tokenizer,
                                                  ModelType model_type) {
  // Start empty: on EOF, read_stdin may leave the buffer untouched, and an
  // uninitialized buffer would then be formatted into the prompt.
  char user_prompt[512] = "";
  char rendered_prompt[512 + 150]; // the prompt template is ~100 characters. We
                                   // use 150 to be safe.
  read_stdin("User: ", user_prompt, sizeof(user_prompt));
  std::vector<uint64_t> tokens;
  switch (model_type) {
  case LLAMA2_MODEL:
    snprintf(rendered_prompt, sizeof(rendered_prompt) - 1, "[INST] %s [/INST]",
             user_prompt);
    // We need to add BOS token here and not in template because llama2
    // tokenizer does not pattern match special tokens
    tokens = UNWRAP(tokenizer->encode(rendered_prompt, /*bos*/ 1, /*eos*/ 0));
    break;
  case LLAMA3_MODEL:
    snprintf(rendered_prompt, sizeof(rendered_prompt) - 1,
             "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_"
             "header_id|>assistant<|end_header_id|>\n\n",
             user_prompt);
    tokens = UNWRAP(tokenizer->encode(rendered_prompt, 0, 0));
    break;
  default:
    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
    exit(EXIT_FAILURE);
  }
#ifdef DEBUG
  std::cerr << "Start of rendered prompt:" << std::endl;
  std::cerr << rendered_prompt;
  std::cerr << "End of rendered prompt:" << std::endl;
  std::cerr << "Encoded prompt: ";
  for (size_t i = 0; i < tokens.size(); i++) {
    std::cerr << tokens[i] << ", ";
  }
  std::cerr << std::endl << std::flush;
#endif
  return tokens;
}
// Interactive chat mode.
// - The first turn uses the CLI system/user prompts when given; later turns
//   read the user message from stdin.
// - Each assistant reply is printed until the model's end-of-turn token.
// - The session ends once `steps` positions have been fed; a reply may be cut
//   off at that point.
// Does nothing when steps == 0. Exits with EXIT_FAILURE for unsupported model
// types.
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          const char *cli_user_prompt, const char *cli_system_prompt,
          unsigned steps, ModelType model_type) {
  if (steps == 0) {
    return;
  }
  uint64_t eot_token;
  if (model_type == LLAMA2_MODEL) {
    // llama2 uses EOS as EOT token
    eot_token = tokenizer->eos_tok();
  } else if (model_type == LLAMA3_MODEL) {
    eot_token = UNWRAP(tokenizer->encode("<|eot_id|>", 0, 0))[0];
  } else {
    fprintf(stderr, "Chat does not support model type %d.\n", model_type);
    exit(EXIT_FAILURE);
  }
  const std::vector<uint64_t> stop_tokens{eot_token};
  unsigned pos = 0;
  while (pos < steps) {
    const bool first_turn = (pos == 0);
    const std::vector<uint64_t> prompt_tokens =
        first_turn ? get_initial_prompt_tokens(cli_system_prompt,
                                               cli_user_prompt, tokenizer,
                                               model_type)
                   : get_next_user_prompt_tokens(tokenizer, model_type);
    printf("Assistant: ");
    // stop_pos = steps - 1 may stop the model mid-reply; -1 would let it
    // finish the reply instead.
    pos = generate_from_prompt_tokens(transformer, tokenizer, sampler,
                                      prompt_tokens, pos,
                                      /*stop_tokens=*/stop_tokens,
                                      /*stop_pos=*/steps - 1,
                                      /*print_prompt=*/false,
                                      /*print_tok_per_sec=*/false);
  }
}
// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING
// Prints command-line usage to stderr and exits with EXIT_FAILURE.
// The help text must agree with the defaults and parsing in main().
void error_usage() {
  fprintf(stderr, "Usage: run <model_path> [options]\n");
  fprintf(stderr,
          "Example: run model.{so,pte} -n 256 -i \"Once upon a time\"\n");
  fprintf(stderr, "Options:\n");
  fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
  fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1], "
                  "default 0.9\n");
  fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
  // main() defaults steps to 128, and 0 steps generates nothing.
  fprintf(stderr, " -n <int> number of steps to run for, default 128\n");
  fprintf(stderr, " -i <string> input prompt\n");
  fprintf(stderr, " -z <string> path to tokenizer\n");
  fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
  fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
  fprintf(stderr,
          " -v <int> (optional) vocab size, default is model-specific.\n");
  // main() reads the model type from the model itself; -l does not select it.
  fprintf(stderr, " -l <int> (optional) llama version (2 or 3); currently "
                  "unused, the model type is read from the model.\n");
  exit(EXIT_FAILURE);
}
// Command-line entry point. Steps:
// 1. Parse options.
// 2. Load the model (AOT Inductor package or ExecuTorch .pte).
// 3. Determine the model type from the model's tokenizer_type metadata.
// 4. Build the tokenizer, run-state buffers, and sampler.
// 5. Run generate or chat mode.
// Returns 0 on success. Bad usage and fatal errors exit with EXIT_FAILURE,
// through error_usage() or the helper functions.
int main(int argc, char *argv[]) {
  // default parameters
  char *model_path = NULL;
  char *tokenizer_path = NULL;
  float temperature =
      1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
  float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
                     // but slower
  int steps = 128;   // number of steps to run for
  const char *prompt = NULL;       // prompt string
  unsigned long long rng_seed = 0; // seed rng with time by default
  const char *mode = "generate";   // generate|chat
  char *system_prompt =
      NULL; // the (optional) system prompt to use in chat mode
  int vocab_size = -1; // -1 means "use the tokenizer's vocab size"
  // NOTE(review): llama_ver is parsed from -l but only appears in an error
  // message below. The model type actually comes from the model.
  int llama_ver = 2;
#if defined(ET_USE_ADAPTIVE_THREADS)
  // Resize the ExecuTorch threadpool to the number of performant cores, when
  // that count is available (non-zero).
  uint32_t num_performant_cores =
      torch::executorch::cpuinfo::get_num_performant_cores();
  if (num_performant_cores > 0) {
    torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool(
        num_performant_cores);
  }
#endif
  // poor man's C argparse so we can override the defaults above from the
  // command line
  if (argc >= 2) {
    model_path = argv[1];
  } else {
    error_usage();
  }
  for (int i = 2; i < argc; i += 1) {
    // do some basic validation
    // When i + 1 == argc this reads argv[argc], which the language guarantees
    // is a null pointer. parm is only used in that case if the value is
    // attached to the flag; otherwise error_usage() exits below.
    char *parm = argv[i+1];
    // uniarg means the arg comes right after the letter in accordance with posix
    // (e.g. "-n256" rather than "-n 256")
    int uniarg = strlen(argv[i]) > 2;
    if (argv[i][0] != '-') {
      error_usage();
    } // must start with dash
    if (strlen(argv[i]) < 2) {
      error_usage();
    } // must have at least dash '-' and option letter
    if (uniarg) {
      parm=&argv[i][2];
    } else if (i + 1 >= argc) {
      error_usage();
    } // must have arg after option if flag is not contiguous to option
    // read in the args
    if (argv[i][1] == 't') {
      temperature = atof(parm);
    } else if (argv[i][1] == 'p') {
      topp = atof(parm);
    } else if (argv[i][1] == 's') {
      rng_seed = atoi(parm);
    } else if (argv[i][1] == 'n') {
      steps = atoi(parm);
    } else if (argv[i][1] == 'v') {
      vocab_size = atoi(parm);
    } else if (argv[i][1] == 'i') {
      prompt = parm;
    } else if (argv[i][1] == 'z') {
      tokenizer_path = parm;
    } else if (argv[i][1] == 'm') {
      mode = parm;
    } else if (argv[i][1] == 'y') {
      system_prompt = parm;
    } else if (argv[i][1] == 'l') {
      llama_ver = atoi(parm);
    } else {
      error_usage();
    }
    // account for parameter: skip the separate value argument, if any
    i += (uniarg)?0:1;
  }
  // NOTE(review): unreachable in practice; argc < 2 already exited via
  // error_usage() above. The message also lacks a trailing newline.
  if (model_path == NULL) {
    fprintf(stderr, "No model_path provided.");
    error_usage();
  }
  Transformer transformer;
  build_transformer(&transformer, model_path);
#ifdef __AOTI_MODEL__
  // The package metadata selects the device and the tokenizer/model type.
  auto aoti_metadata = transformer.runner->get_metadata();
  aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
                    ? torch::Device(torch::kCPU)
                    : torch::Device(torch::kCUDA);
  ModelType model_type = get_model_type(std::stoi(aoti_metadata["tokenizer_type"]));
#else // __ET_MODEL__
  Error load_status = transformer.runner->load();
  ET_CHECK_MSG(
      load_status == torch::executor::Error::Ok,
      "program::load() failed with status 0x%" PRIx32,
      static_cast<uint32_t>(load_status));
  // Scratch memory for loading the small "tokenizer_type" method below.
  static std::array<uint8_t, 4 * 1024U * 1024U> method_allocator_pool; // 4MB
  MemoryAllocator method_allocator{MemoryAllocator(
      sizeof(method_allocator_pool), method_allocator_pool.data())};
  MemoryManager memory_manager(&method_allocator, nullptr);
  // The program exports a "tokenizer_type" method; its first output is the
  // integer model type.
  auto tokenizer_method = transformer.runner->program()->load_method("tokenizer_type", &memory_manager);
  Error execute_status = tokenizer_method->execute();
  ET_CHECK_MSG(
      execute_status == torch::executor::Error::Ok,
      "method::execute() failed with status 0x%" PRIx32,
      static_cast<uint32_t>(execute_status));
  auto tokenizer_type = tokenizer_method->get_output(0).toInt();
  ModelType model_type = get_model_type(tokenizer_type);
#endif
  // NOTE(review): model_type comes from the model, not from -l, so this
  // message (and the llama_ver it prints) can be misleading.
  if (model_type == UNKNOWN_MODEL) {
    fprintf(stderr, "Unknown model type passed by -l argument. Received l=%d.",
            llama_ver);
    error_usage();
  }
  if (tokenizer_path == NULL) {
    fprintf(stderr, "No tokenizer_path provided.");
    error_usage();
  }
  // parameter validation/overrides
  // (rng_seed is unsigned, so "<= 0" means "== 0", i.e. no -s given or -s 0)
  if (rng_seed <= 0)
    rng_seed = (unsigned int)time(NULL);
  if (temperature < 0.0)
    temperature = 0.0;
  if (topp < 0.0 || 1.0 < topp)
    topp = 0.9;
  if (steps < 0)
    steps = 0;
  Tokenizer *tokenizer = build_tokenizer(tokenizer_path, model_type);
  // If no vocab size was provided with -v, use the tokenizer's vocab size
  if (vocab_size == -1) {
    vocab_size = tokenizer->vocab_size();
  }
  // read in the Config and the Weights from the model
  // read_checkpoint(model_path, &t->config);
  // allocate the RunState buffers; seq_len == steps bounds the positions that
  // forward() may record in state.toks
  transformer.config.vocab_size = vocab_size;
  transformer.config.seq_len = steps;
  malloc_run_state(&transformer.state, &transformer.config);
  Sampler sampler;
  build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
  if (strcmp(mode, "generate") == 0) {
    generate(&transformer, tokenizer, &sampler, prompt, steps, model_type);
  } else if (strcmp(mode, "chat") == 0) {
    chat(&transformer, tokenizer, &sampler, prompt, system_prompt, steps,
         model_type);
  } else {
    fprintf(stderr, "unknown mode: %s\n", mode);
    error_usage();
  }
  // memory and file handles cleanup
  free_sampler(&sampler);
  free_tokenizer(tokenizer);
  free_transformer(&transformer);
  return 0;
}
#endif