-
-
Notifications
You must be signed in to change notification settings - Fork 88
/
autogen_protobuf_extensions.py
executable file
·152 lines (111 loc) · 4.26 KB
/
autogen_protobuf_extensions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
#
# This is a quick-hack that extends ProtocolMessage with a method called inner
# that will return the correct submessage based on the type.
#
# Update with this command:
# ./scripts/autogen_protobuf_extensions.py > pyatv/mrp/protobuf/__init__.py
#
"""Simple hack to auto-generate protobuf handling code."""
import sys
import os
from collections import namedtuple
# New messages re-using inner message of another type
REUSED_MESSAGES = {'DEVICE_INFO_MESSAGE': 'DEVICE_INFO_UPDATE_MESSAGE'}

# Package holding the generated *_pb2 modules. With dots turned into path
# separators it is also the directory the .proto files are read from.
BASE_PACKAGE = 'pyatv.mrp.protobuf'

# Template for the generated module; placeholders are filled in by main()
# via str.format, so literal braces are doubled ({{ }}). Indentation inside
# the template is significant: it becomes the indentation of the generated
# Python code, so the function body and dict body must stay indented.
OUTPUT_TEMPLATE = """\"\"\"Simplified extension handling for protobuf messages.

THIS CODE IS AUTO-GENERATED - DO NOT EDIT!!!
\"\"\"

from pyatv.mrp.protobuf.ProtocolMessage_pb2 import ProtocolMessage


{packages}


{messages}


_EXTENSION_LOOKUP = {{
    {extensions}
}}


{constants}


def _inner_message(self):
    extension = _EXTENSION_LOOKUP.get(self.type, None)
    if extension:
        return self.Extensions[extension]

    raise Exception('unknown type: ' + str(self.type))


ProtocolMessage.inner = _inner_message  # type: ignore
"""
# Everything the generator needs to know about one message type:
#   module   - generated protobuf module name, e.g. 'FooMessage_pb2'
#   title    - message class name, e.g. 'FooMessage'
#   accessor - extension attribute looked up on the module, e.g. 'fooMessage'
#   const    - constant of the ProtocolMessage Type enum, e.g. 'FOO_MESSAGE'
MessageInfo = namedtuple('MessageInfo', 'module title accessor const')
def extract_message_info():
    """Get information about all messages of interest.

    Parses the ``Type`` enum in ``ProtocolMessage.proto`` and yields one
    MessageInfo per enum constant that has a matching ``<Title>.proto`` file
    in the same directory. Constants without such a file are skipped.

    Yields:
        MessageInfo tuples; for the constant ``SET_STATE_MESSAGE`` this is
        ('SetStateMessage_pb2', 'SetStateMessage', 'setStateMessage',
        'SET_STATE_MESSAGE').
    """
    base_path = os.path.join(*BASE_PACKAGE.split('.'))
    filename = os.path.join(base_path, 'ProtocolMessage.proto')

    with open(filename, 'r', encoding='utf-8') as file:
        types_found = False

        for line in file:
            stripped = line.strip()

            # Look for the Type enum
            if stripped == 'enum Type {':
                types_found = True
                continue
            if types_found and stripped == '}':
                break
            if not types_found:
                continue

            # Comments inside the enum are not constants
            if stripped.startswith('//'):
                continue

            # "FOO_MESSAGE = 4;" -> "FOO_MESSAGE" (also copes with "FOO=4;").
            # Blank lines yield no words and are skipped; previously they
            # produced an empty constant and crashed on title[0] below.
            words = stripped.split('=')[0].split()
            if not words:
                continue
            constant = words[0]

            title = constant.title().replace(
                '_', '').replace('Hid', 'HID')  # Hack...
            accessor = title[0].lower() + title[1:]

            if not os.path.exists(os.path.join(base_path, title + '.proto')):
                continue

            yield MessageInfo(
                title + '_pb2', title, accessor, constant)
def extract_unreferenced_messages():
    """Get messages not referenced anywhere.

    Scans every ``.proto`` file in the protobuf directory (except
    ProtocolMessage.proto) for top-level ``message`` declarations.

    Yields:
        (module, message) tuples such as ('Foo_pb2', 'Foo'). Messages that
        are also referenced from ProtocolMessage may be yielded as well.
    """
    base_path = os.path.join(*BASE_PACKAGE.split('.'))
    for filename in os.listdir(base_path):
        stem, extension = os.path.splitext(filename)
        if extension != '.proto' or stem == 'ProtocolMessage':
            continue

        path = os.path.join(base_path, filename)
        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                # Only unindented lines, i.e. top-level messages; nested
                # messages are indented and deliberately ignored.
                if not line.startswith('message'):
                    continue

                # Tokenise so "message Foo{", "message  Foo {" and tabs all
                # yield "Foo" rather than "Foo{" or an empty name.
                words = line.replace('{', ' ').split()
                if len(words) >= 2 and words[0] == 'message':
                    yield stem + '_pb2', words[1]
def main():
    """Script starts somewhere around here.

    Collects all protobuf messages, renders OUTPUT_TEMPLATE with the matching
    imports, extension lookup entries and Type constants, and writes the
    result to stdout.

    Returns:
        Exit code for sys.exit (always 0).
    """
    message_names = set()
    packages = []
    messages = []
    extensions = []
    constants = []

    # Extract everything needed to generate output file
    for info in extract_message_info():
        message_names.add(info.title)
        packages.append(
            'from {0} import {1}'.format(BASE_PACKAGE, info.module))
        messages.append(
            'from {0}.{1} import {2}'.format(
                BASE_PACKAGE, info.module, info.title))
        constants.append('{0} = ProtocolMessage.{0}'.format(info.const))

        # Map the type itself, plus any type re-using its inner message
        type_consts = [info.const]
        reused = REUSED_MESSAGES.get(info.const)
        if reused:
            type_consts.append(reused)
        for type_const in type_consts:
            extensions.append(
                'ProtocolMessage.{0}: {1}.{2},'.format(
                    type_const, info.module, info.accessor))

    # Look for remaining messages
    for module_name, message_name in extract_unreferenced_messages():
        if message_name not in message_names:
            message_names.add(message_name)
            messages.append('from {0}.{1} import {2}'.format(
                BASE_PACKAGE, module_name, message_name))

    # Continuation entries get 4 spaces so the generated _EXTENSION_LOOKUP
    # dict body uses standard indentation (previously a single space).
    output = OUTPUT_TEMPLATE.format(
        packages='\n'.join(sorted(packages)),
        messages='\n'.join(sorted(messages)),
        extensions='\n    '.join(sorted(extensions)),
        constants='\n'.join(sorted(constants)))

    # The template already ends with a newline; print() would append a
    # second one and leave a blank line at the end of the generated file.
    sys.stdout.write(output)
    return 0
if __name__ == '__main__':
    # Raising SystemExit hands main()'s return value to the interpreter as
    # the process exit status (equivalent to sys.exit(main())).
    raise SystemExit(main())