-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathxml_dict.py
126 lines (95 loc) · 3.45 KB
/
xml_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import re
import sys
import itertools
import functools
from dataclasses import dataclass, field
def strip_namespace(s):
    """Return *s* without a leading ElementTree-style ``{namespace-uri}`` prefix.

    Everything up to and including the last ``}`` is removed; a string with no
    ``}`` is returned unchanged.  Used for both tag and attribute names.
    """
    return s.rpartition("}")[2]
@dataclass
class xml_node:
    """One element of an XML tree held as plain Python data.

    ``children`` and ``parent`` are left out of ``repr`` so printing a node does
    not dump (or loop through) the whole tree.  ``==`` is the dataclass-generated
    structural comparison over tag, attributes, text, namespaces and children;
    ``parent`` is excluded from it (see the field comment).
    """

    tag: str
    attributes: dict = field(default_factory=dict)
    text: str = ""
    # Namespace map; to_data() stores lxml's ``nsmap`` here.  May be None.
    namespaces: dict = field(default=None, repr=False)
    children: list = field(default_factory=list, repr=False)
    # Back-reference to the enclosing node.  Excluded from comparison because it
    # points *up* the tree: comparing it would recurse parent -> children ->
    # parent ... and raise RecursionError for two distinct, equal trees.
    parent: object = field(default=None, repr=False, compare=False)

    def __hash__(self):
        # Identity hash so (mutable) nodes can be dict keys / set members.
        # NOTE: ``==`` is structural, so distinct-but-equal nodes hash
        # differently; sets and dicts therefore behave as identity collections.
        return id(self)

    def child_with_tag(self, tag):
        """Return the first direct child whose tag equals *tag*, or None."""
        return next(self.children_with_tag(tag), None)

    def children_with_tag(self, tag):
        """Yield the direct children whose tag equals *tag*, in stored order."""
        for ch in self.children:
            if ch.tag == tag:
                yield ch

    def apply(self, fn):
        """Rebuild this subtree by calling *fn* on every node, parents first.

        *fn* receives an original node and returns the node to use in its
        place.  The returned node's ``children`` are replaced by the transformed
        original children, whose ``parent`` links are pointed at it.  The new
        root's ``parent`` is left as *fn* set it.
        """
        new_self = fn(self)
        children = [c.apply(fn) for c in self.children]
        new_self.children = children
        for c in children:
            c.parent = new_self
        return new_self

    def strip_namespaces(self):
        """Return a copy of this subtree with ``{uri}`` prefixes removed.

        Tags and attribute names are passed through ``strip_namespace``; the
        ``namespaces`` map of each node is shared with the original, not copied.
        """
        def inner(node):
            return xml_node(
                strip_namespace(node.tag),
                {strip_namespace(k): v for k, v in node.attributes.items()},
                node.text,
                node.namespaces,
            )
        return self.apply(inner)

    def recursive_print(self, file=None, level=0):
        """Print this subtree as indented pseudo-XML (values are not escaped).

        Args:
            file: Text stream to write to.  ``None`` (the default) means the
                *current* ``sys.stdout``, looked up at call time so that
                ``contextlib.redirect_stdout`` is honoured.
            level: Nesting depth; each level indents by two spaces.
        """
        if file is None:
            file = sys.stdout
        attr_pairs = "".join(f' {k}="{v}"' for k, v in self.attributes.items())
        has_body = bool(self.text or self.children)
        # Nodes with neither text nor children are printed self-closing.
        close = "" if has_body else " /"
        indent = " " * (level*2)
        print(f"{indent}<{self.tag}{attr_pairs}{close}>", file=file)
        if has_body:
            if self.text:
                print(f"{indent} {self.text}", file=file)
            for c in self.children:
                c.recursive_print(file=file, level=level+1)
            print(f"{indent}</{self.tag}>", file=file)
import lxml.etree as ET
# Attribute names that to_data() drops when converting lxml elements.
IGNORED_ATTRS = tuple()
# Element tags that to_data() skips entirely (the element and its subtree).
IGNORED_TAGS = tuple()
def flatmap(func, *iterable):
    """Lazily map *func* over the iterable(s) and flatten the results one level.

    Arguments are zipped exactly as ``map`` does; every call of *func* must
    return an iterable, whose items are yielded in order.  Nothing is computed
    until the returned iterator is consumed.
    """
    per_item_results = map(func, *iterable)
    return itertools.chain.from_iterable(per_item_results)
def to_data(t, parent=None):
    """Convert the lxml element *t* (and its subtree) into ``xml_node`` objects.

    This is a generator.  It yields exactly one ``xml_node`` for an element.
    It yields nothing for nodes whose ``tag`` is not a string (comments,
    entities, ...) or whose tag is listed in ``IGNORED_TAGS``.

    Attributes named in ``IGNORED_ATTRS`` are dropped.  ``t.text`` is
    whitespace-stripped, and whitespace-only text becomes ``""``; tail text is
    not read.  Each child's ``parent`` is set to the new node.  The *parent*
    argument is accepted but not used.
    """
    if not isinstance(t.tag, str) or t.tag in IGNORED_TAGS:
        # Comment/Entity/... nodes, or elements deliberately skipped.
        return
    # Convert the children first; skipped children contribute nothing.
    kids = list(flatmap(to_data, t))
    attrs = {
        name: value
        for name, value in (t.attrib or {}).items()
        if name not in IGNORED_ATTRS
    }
    node = xml_node(
        tag=t.tag,
        attributes=attrs,
        text=(t.text or "").strip(),
        namespaces=t.nsmap,
        children=kids,
    )
    for kid in kids:
        kid.parent = node
    yield node
def read(fn, encodings=("utf-8", "windows-1252")):
    """Parse the XML file *fn* and return its root as an ``xml_node`` tree.

    Each encoding in *encodings* is tried in turn and forced on the parser,
    overriding any ``<?xml ... encoding=...?>`` declaration.  The first
    encoding that parses without an XML syntax error wins.

    Why UTF-8 comes first: UTF-8 decoding is strict, so legacy single-byte
    files fail it and fall through to windows-1252.  windows-1252 maps almost
    every byte, so trying it first almost never failed.  Under that old order,
    genuine UTF-8 input was silently decoded as mojibake.  Pass
    ``encodings=("windows-1252", "utf-8")`` to get the previous order back.

    NOTE(review): *fn* is handed straight to ``lxml.etree.parse``.  If callers
    pass an open file object instead of a path, it is not rewound between
    attempts.  TODO confirm callers pass paths.

    Raises:
        ValueError: if *encodings* is empty.
        lxml.etree.XMLSyntaxError: the last attempt's error, if every encoding
            fails.  Other errors (e.g. an unreadable file) now propagate
            immediately instead of triggering a retry.
    """
    if not encodings:
        raise ValueError("read() needs at least one encoding to try")
    last_error = None
    for encoding in encodings:
        parser = ET.XMLParser(encoding=encoding)
        try:
            tree = ET.parse(fn, parser=parser)
        except ET.XMLSyntaxError as err:
            # Undecodable or malformed under this encoding: try the next one.
            last_error = err
            continue
        return next(to_data(tree.getroot()))
    raise last_error
def serialize(di, ofn):
    """Write ``di[0]`` (an ``xml_node`` tree) to *ofn* as an XML document.

    Only the first element of *di* is written.  Each node becomes an lxml
    element carrying its tag, its namespace map, its text (when non-empty) and
    its attributes (in iteration order).  The tree is indented with tabs and
    written with an XML declaration.
    """

    def to_element(node, parent_elem=None):
        # The root becomes a free-standing Element; descendants attach to their parent.
        if parent_elem is not None:
            elem = ET.SubElement(parent_elem, node.tag, nsmap=node.namespaces)
        else:
            elem = ET.Element(node.tag, nsmap=node.namespaces)
        if node.text:
            elem.text = node.text
        for attr_name, attr_value in node.attributes.items():
            elem.set(attr_name, attr_value)
        for child_node in node.children:
            to_element(child_node, elem)
        return elem

    document = ET.ElementTree(to_element(di[0]))
    ET.indent(document, "\t")
    document.write(ofn, pretty_print=True, xml_declaration=True)