diff --git a/libdrgn/build-aux/gen_c_keywords_inc_strswitch.py b/libdrgn/build-aux/gen_c_keywords_inc_strswitch.py index 8676e88d9..1201900b5 100755 --- a/libdrgn/build-aux/gen_c_keywords_inc_strswitch.py +++ b/libdrgn/build-aux/gen_c_keywords_inc_strswitch.py @@ -7,6 +7,7 @@ "_Bool", "_Complex", "char", + "class", "const", "double", "enum", @@ -35,12 +36,15 @@ def main() -> None: print(f'\t[{token_kind}] = "{keyword}",') print("};") print() - print("static int identifier_token_kind(const char *s, size_t len)") + print("static int identifier_token_kind(const char *s, size_t len, bool cpp)") print("{") print("\t@memswitch (s, len)@") for token_kind, keyword in keywords: print(f'\t@case "{keyword}"@') - print(f"\t\treturn {token_kind};") + if keyword == "class": + print(f"\t\treturn cpp ? {token_kind} : C_TOKEN_IDENTIFIER;") + else: + print(f"\t\treturn {token_kind};") print("\t@default@") print(f"\t\treturn C_TOKEN_IDENTIFIER;") print("\t@endswitch@") diff --git a/libdrgn/language.h b/libdrgn/language.h index 3a38549ff..a68e97676 100644 --- a/libdrgn/language.h +++ b/libdrgn/language.h @@ -52,7 +52,8 @@ typedef struct drgn_error *drgn_format_object_fn(const struct drgn_object *, size_t, enum drgn_format_object_flags, char **); -typedef struct drgn_error *drgn_find_type_fn(struct drgn_program *prog, +typedef struct drgn_error *drgn_find_type_fn(const struct drgn_language *lang, + struct drgn_program *prog, const char *name, const char *filename, struct drgn_qualified_type *ret); diff --git a/libdrgn/language_c.c b/libdrgn/language_c.c index 3d71277e8..eeb91d0fd 100644 --- a/libdrgn/language_c.c +++ b/libdrgn/language_c.c @@ -1642,6 +1642,7 @@ enum { MAX_QUALIFIER_TOKEN = C_TOKEN_ATOMIC, C_TOKEN_STRUCT, C_TOKEN_UNION, + C_TOKEN_CLASS, C_TOKEN_ENUM, MAX_KEYWORD_TOKEN = C_TOKEN_ENUM, C_TOKEN_LPAREN, @@ -1656,9 +1657,15 @@ enum { #include "c_keywords.inc" -struct drgn_error *drgn_lexer_c(struct drgn_lexer *lexer, +struct drgn_c_family_lexer { + struct drgn_lexer lexer; + bool cpp; +}; + +struct drgn_error *drgn_c_family_lexer_func(struct drgn_lexer *lexer, struct drgn_token *token) { const char *p = lexer->p; + bool cpp = ((struct drgn_c_family_lexer *)lexer)->cpp; while (isspace(*p)) p++; @@ -1698,7 +1705,8 @@ struct drgn_error *drgn_lexer_c(struct drgn_lexer *lexer, p++; } while (isalnum(*p) || *p == '_'); token->kind = identifier_token_kind(token->value, - p - token->value); + p - token->value, + cpp); } else if ('0' <= *p && *p <= '9') { token->kind = C_TOKEN_NUMBER; if (*p++ == '0' && *p == 'x') { @@ -2021,16 +2029,19 @@ static const enum drgn_primitive_type specifier_kind[NUM_SPECIFIER_STATES] = { enum drgn_primitive_type c_parse_specifier_list(const char *s) { struct drgn_error *err; - struct drgn_lexer lexer; + struct drgn_c_family_lexer c_family_lexer; + struct drgn_lexer *lexer = &c_family_lexer.lexer; enum c_type_specifier specifier = SPECIFIER_NONE; enum drgn_primitive_type primitive = DRGN_NOT_PRIMITIVE_TYPE; - drgn_lexer_init(&lexer, drgn_lexer_c, s); + c_family_lexer.cpp = false; + + drgn_lexer_init(lexer, drgn_c_family_lexer_func, s); for (;;) { struct drgn_token token; - err = drgn_lexer_pop(&lexer, &token); + err = drgn_lexer_pop(lexer, &token); if (err) { drgn_error_destroy(err); goto out; @@ -2049,7 +2060,7 @@ enum drgn_primitive_type c_parse_specifier_list(const char *s) primitive = specifier_kind[specifier]; out: - drgn_lexer_deinit(&lexer); + drgn_lexer_deinit(lexer); return primitive; } @@ -2106,6 +2117,7 @@ c_parse_specifier_qualifier_list(struct drgn_program *prog, identifier_len = token.len; } else if (token.kind == C_TOKEN_STRUCT || token.kind == C_TOKEN_UNION || + token.kind == C_TOKEN_CLASS || token.kind == C_TOKEN_ENUM) { if (identifier) { return drgn_error_format(DRGN_ERROR_SYNTAX, @@ -2145,6 +2157,8 @@ c_parse_specifier_qualifier_list(struct drgn_program *prog, kind = DRGN_TYPE_STRUCT; } else if (tag_token == C_TOKEN_UNION) { kind = DRGN_TYPE_UNION; + } else if (tag_token == C_TOKEN_CLASS) { + kind = DRGN_TYPE_CLASS; } else if (tag_token == C_TOKEN_ENUM) { kind = DRGN_TYPE_ENUM; } else if (identifier) { @@ -2451,31 +2465,35 @@ c_type_from_declarator(struct drgn_program *prog, return err; } -static struct drgn_error *c_find_type(struct drgn_program *prog, +static struct drgn_error *c_family_find_type(const struct drgn_language *lang, + struct drgn_program *prog, const char *name, const char *filename, struct drgn_qualified_type *ret) { struct drgn_error *err; - struct drgn_lexer lexer; + struct drgn_c_family_lexer c_family_lexer; + struct drgn_lexer *lexer = &c_family_lexer.lexer; struct drgn_token token; - drgn_lexer_init(&lexer, drgn_lexer_c, name); + c_family_lexer.cpp = lang == &drgn_language_cpp; - err = c_parse_specifier_qualifier_list(prog, &lexer, filename, ret); + drgn_lexer_init(lexer, drgn_c_family_lexer_func, name); + + err = c_parse_specifier_qualifier_list(prog, lexer, filename, ret); if (err) goto out; - err = drgn_lexer_pop(&lexer, &token); + err = drgn_lexer_pop(lexer, &token); if (err) goto out; if (token.kind != C_TOKEN_EOF) { struct c_declarator *outer = NULL, *inner; - err = drgn_lexer_push(&lexer, &token); + err = drgn_lexer_push(lexer, &token); if (err) return err; - err = c_parse_abstract_declarator(prog, &lexer, &outer, &inner); + err = c_parse_abstract_declarator(prog, lexer, &outer, &inner); if (err) { while (outer) { struct c_declarator *next; @@ -2491,7 +2509,7 @@ static struct drgn_error *c_find_type(struct drgn_program *prog, if (err) goto out; - err = drgn_lexer_pop(&lexer, &token); + err = drgn_lexer_pop(lexer, &token); if (err) goto out; if (token.kind != C_TOKEN_EOF) { @@ -2503,26 +2521,29 @@ static struct drgn_error *c_find_type(struct drgn_program *prog, err = NULL; out: - drgn_lexer_deinit(&lexer); + drgn_lexer_deinit(lexer); return err; } -static struct drgn_error *c_bit_offset(struct drgn_program *prog, - struct drgn_type *type, - const char *member_designator, - uint64_t *ret) +static struct drgn_error *c_family_bit_offset(struct drgn_program *prog, + struct drgn_type *type, + const char *member_designator, + uint64_t *ret) { struct drgn_error *err; - struct drgn_lexer lexer; + struct drgn_c_family_lexer c_family_lexer; + struct drgn_lexer *lexer = &c_family_lexer.lexer; int state = INT_MIN; uint64_t bit_offset = 0; - drgn_lexer_init(&lexer, drgn_lexer_c, member_designator); + c_family_lexer.cpp = prog->lang == &drgn_language_cpp; + + drgn_lexer_init(lexer, drgn_c_family_lexer_func, member_designator); for (;;) { struct drgn_token token; - err = drgn_lexer_pop(&lexer, &token); + err = drgn_lexer_pop(lexer, &token); if (err) goto out; @@ -2636,7 +2657,7 @@ static struct drgn_error *c_bit_offset(struct drgn_program *prog, } out: - drgn_lexer_deinit(&lexer); + drgn_lexer_deinit(lexer); return err; } @@ -3421,8 +3442,8 @@ LIBDRGN_PUBLIC const struct drgn_language drgn_language_c = { .format_type_name = c_format_type_name, .format_type = c_format_type, .format_object = c_format_object, - .find_type = c_find_type, - .bit_offset = c_bit_offset, + .find_type = c_family_find_type, + .bit_offset = c_family_bit_offset, .integer_literal = c_integer_literal, .bool_literal = c_bool_literal, .float_literal = c_float_literal, @@ -3450,8 +3471,8 @@ LIBDRGN_PUBLIC const struct drgn_language drgn_language_cpp = { .format_type_name = c_format_type_name, .format_type = c_format_type, .format_object = c_format_object, - .find_type = c_find_type, - .bit_offset = c_bit_offset, + .find_type = c_family_find_type, + .bit_offset = c_family_bit_offset, .integer_literal = c_integer_literal, .bool_literal = c_bool_literal, .float_literal = c_float_literal, diff --git a/libdrgn/lexer.h b/libdrgn/lexer.h index cdb2cec5b..6b4c73fab 100644 --- a/libdrgn/lexer.h +++ b/libdrgn/lexer.h @@ -126,7 +126,7 @@ struct drgn_error *drgn_lexer_peek(struct drgn_lexer *lexer, struct drgn_token *token); /* Exported only for testing. */ -struct drgn_error *drgn_lexer_c(struct drgn_lexer *lexer, +struct drgn_error *drgn_c_family_lexer_func(struct drgn_lexer *lexer, struct drgn_token *token); /** @} */ diff --git a/libdrgn/python/test.c b/libdrgn/python/test.c index c6733f98b..76ab024e8 100644 --- a/libdrgn/python/test.c +++ b/libdrgn/python/test.c @@ -60,7 +60,7 @@ DRGNPY_PUBLIC struct drgn_error *drgn_test_lexer_func(struct drgn_lexer *lexer, DRGNPY_PUBLIC struct drgn_error *drgn_test_lexer_c(struct drgn_lexer *lexer, struct drgn_token *token) { - return drgn_lexer_c(lexer, token); + return drgn_c_family_lexer_func(lexer, token); } DRGNPY_PUBLIC bool drgn_test_path_iterator_next(struct path_iterator *it, diff --git a/libdrgn/type.c b/libdrgn/type.c index 00b7bdb07..d0b266c5a 100644 --- a/libdrgn/type.c +++ b/libdrgn/type.c @@ -1400,7 +1400,8 @@ drgn_program_find_type(struct drgn_program *prog, const char *name, const char *filename, struct drgn_qualified_type *ret) { struct drgn_error *err; - err = drgn_program_language(prog)->find_type(prog, name, filename, ret); + const struct drgn_language *lang = drgn_program_language(prog); + err = lang->find_type(lang, prog, name, filename, ret); if (err != &drgn_not_found) return err; diff --git a/tests/libdrgn.py b/tests/libdrgn.py index dba4dbb14..106af840b 100644 --- a/tests/libdrgn.py +++ b/tests/libdrgn.py @@ -120,6 +120,10 @@ class _drgn_lexer(ctypes.Structure): ] +class _drgn_c_family_lexer(ctypes.Structure): + _fields_ = [("lexer", _drgn_lexer), ("cpp", ctypes.c_bool)] + + drgn_lexer_func = ctypes.CFUNCTYPE( ctypes.POINTER(_drgn_error), ctypes.POINTER(_drgn_lexer), @@ -152,7 +156,7 @@ class _drgn_lexer(ctypes.Structure): ] -drgn_lexer_c = drgn_lexer_func.in_dll(_drgn_cdll, "drgn_test_lexer_c") +drgn_c_family_lexer_func = drgn_lexer_func.in_dll(_drgn_cdll, "drgn_test_lexer_c") drgn_test_lexer_func = drgn_lexer_func.in_dll(_drgn_cdll, "drgn_test_lexer_func") @@ -175,6 +179,7 @@ class C_TOKEN(enum.IntEnum): ATOMIC = auto() STRUCT = auto() UNION = auto() + CLASS = auto() ENUM = auto() LPAREN = auto() RPAREN = auto() @@ -203,10 +208,12 @@ def __repr__(self): class Lexer: - def __init__(self, func, str): - self._lexer = _drgn_lexer() + def __init__(self, func, str, cpp=False): + self._c_family_lexer = _drgn_c_family_lexer() + self._lexer = self._c_family_lexer.lexer self._func = func self._str = str.encode() + self._c_family_lexer.cpp = cpp _drgn_cdll.drgn_test_lexer_init( ctypes.pointer(self._lexer), self._func, self._str ) diff --git a/tests/test_language_c.py b/tests/test_language_c.py index 29839e5a0..ba2c30239 100644 --- a/tests/test_language_c.py +++ b/tests/test_language_c.py @@ -14,7 +14,7 @@ container_of, ) from tests import MockProgramTestCase, TestCase -from tests.libdrgn import C_TOKEN, Lexer, drgn_lexer_c +from tests.libdrgn import C_TOKEN, Lexer, drgn_c_family_lexer_func class TestPrettyPrintTypeName(MockProgramTestCase): @@ -711,8 +711,8 @@ def test_function_no_name(self): class TestLexer(TestCase): - def lex(self, s): - lexer = Lexer(drgn_lexer_c, s) + def lex(self, s, cpp=False): + lexer = Lexer(drgn_c_family_lexer_func, s, cpp) while True: token = lexer.pop() if token.kind == C_TOKEN.EOF: @@ -720,7 +720,7 @@ def lex(self, s): yield token def test_empty(self): - lexer = Lexer(drgn_lexer_c, "") + lexer = Lexer(drgn_c_family_lexer_func, "") for i in range(64): self.assertEqual(lexer.pop().kind, C_TOKEN.EOF) @@ -738,7 +738,7 @@ def test_symbols(self): def test_keywords(self): s = """void char short int long signed unsigned _Bool float double - _Complex const restrict volatile _Atomic struct union enum""" + _Complex const restrict volatile _Atomic struct union class enum""" tokens = [ C_TOKEN.VOID, C_TOKEN.CHAR, @@ -757,10 +757,37 @@ def test_keywords(self): C_TOKEN.ATOMIC, C_TOKEN.STRUCT, C_TOKEN.UNION, + C_TOKEN.IDENTIFIER, C_TOKEN.ENUM, ] self.assertEqual([token.kind for token in self.lex(s)], tokens) + def test_cpp_keywords(self): + s = """void char short int long signed unsigned _Bool float double + _Complex const restrict volatile _Atomic struct union class enum""" + tokens = [ + C_TOKEN.VOID, + C_TOKEN.CHAR, + C_TOKEN.SHORT, + C_TOKEN.INT, + C_TOKEN.LONG, + C_TOKEN.SIGNED, + C_TOKEN.UNSIGNED, + C_TOKEN.BOOL, + C_TOKEN.FLOAT, + C_TOKEN.DOUBLE, + C_TOKEN.COMPLEX, + C_TOKEN.CONST, + C_TOKEN.RESTRICT, + C_TOKEN.VOLATILE, + C_TOKEN.ATOMIC, + C_TOKEN.STRUCT, + C_TOKEN.UNION, + C_TOKEN.CLASS, + C_TOKEN.ENUM, + ] + self.assertEqual([token.kind for token in self.lex(s, cpp=True)], tokens) + def test_identifiers(self): s = "_ x foo _bar baz1" tokens = s.split() diff --git a/tests/test_program.py b/tests/test_program.py index 53ea9d6c0..0634096a7 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -18,6 +18,7 @@ ProgramFlags, Qualifiers, TypeKind, + TypeMember, host_platform, ) from tests import ( @@ -548,6 +549,33 @@ def test_tagged_type(self): self.assertIdentical(self.prog.type("union option"), self.option_type) self.assertIdentical(self.prog.type("enum color"), self.color_type) + def test_class_type(self): + struct_class = self.prog.struct_type( + "class", + 8, + (TypeMember(self.prog.pointer_type(self.prog.void_type()), "ptr"),), + ) + class_point = self.prog.class_type( + "Point", + 8, + ( + TypeMember(self.prog.int_type("int", 4, True), "x", 0), + TypeMember(self.prog.int_type("int", 4, True), "y", 32), + ), + ) + self.types.append(struct_class) + self.types.append(class_point) + self.prog.language = Language.C + self.assertIdentical(self.prog.type("struct class"), struct_class) + self.prog.language = Language.CPP + self.assertRaisesRegex( + SyntaxError, + "expected identifier after 'struct'", + self.prog.type, + "struct class", + ) + self.assertIdentical(self.prog.type("class Point"), class_point) + def test_typedef(self): self.types.append(self.pid_type) self.assertIdentical(self.prog.type("pid_t"), self.pid_type)