Skip to content

Commit

Permalink
Merge pull request #1506 from ashishpriyadarshiCIC/test-script-xml
Browse files Browse the repository at this point in the history
Test script xml
  • Loading branch information
henrykironde committed Sep 4, 2020
2 parents 1e02815 + 4613934 commit b3a62e2
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 33 deletions.
15 changes: 11 additions & 4 deletions retriever/engines/xmlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from retriever.lib.dummy import DummyConnection
from retriever.lib.models import Engine
from retriever.lib.tools import open_fr, open_fw
from retriever.lib.engine_tools import xml2csv, sort_csv
from retriever.lib.engine_tools import sort_csv, xml2csv_test


class engine(Engine):
Expand Down Expand Up @@ -121,9 +121,16 @@ def to_csv(self, sort=True, path=None, select_columns=None):
os.path.join(
path if path else '',
os.path.splitext(os.path.basename(table_item[0]))[0] + '.csv'))
csv_outfile = xml2csv(table_item[0],
outputfile=outputfile,
header_values=header)
empty_rows = 1
if hasattr(self.script, "empty_rows"):
empty_rows = self.script.empty_rows
input_file = table_item[0]
header_values = header

csv_outfile = xml2csv_test(input_file,
outputfile,
header_values,
row_tag="row")
sort_csv(csv_outfile, encoding=self.encoding)

def get_connection(self):
Expand Down
10 changes: 10 additions & 0 deletions retriever/lib/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from retriever.lib.engine_tools import geojson2csv
from retriever.lib.engine_tools import sqlite2csv
from retriever.lib.engine_tools import json2csv
from retriever.lib.engine_tools import xml2csv
from retriever.lib.warning import Warning


Expand Down Expand Up @@ -605,6 +606,15 @@ def process_json2csv(self, src_path, path_to_csv, headers, encoding=ENCODING):
encoding=encoding,
row_key=None)

def process_xml2csv(self,
                    src_path,
                    path_to_csv,
                    header_values=None,
                    empty_rows=1,
                    encoding=ENCODING):
    """Convert a located XML source file to CSV.

    The conversion is delegated to ``xml2csv`` and only happens when
    ``self.find_file(src_path)`` reports a truthy result; otherwise the
    call is a no-op. Nothing is returned in either case.

    src_path: path of the XML file to convert.
    path_to_csv: destination path handed to ``xml2csv``.
    header_values: forwarded unchanged to ``xml2csv``.
    empty_rows: forwarded unchanged to ``xml2csv``.
    encoding: encoding forwarded to ``xml2csv`` for the written CSV.
    """
    if not self.find_file(src_path):
        # Source file could not be located; nothing to convert.
        return
    xml2csv(src_path,
            path_to_csv,
            header_values=header_values,
            empty_rows=empty_rows,
            encoding=encoding)

def extract_gz(
self,
archive_path,
Expand Down
54 changes: 51 additions & 3 deletions retriever/lib/engine_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from hashlib import md5
from io import StringIO as NewFile
from retriever.lib.defaults import HOME_DIR, ENCODING

import xml.etree.ElementTree as ET
import os
Expand All @@ -30,9 +29,11 @@
from pandas.io.json import json_normalize
from collections import OrderedDict

warnings.filterwarnings("ignore")
from retriever.lib.defaults import HOME_DIR, ENCODING
from retriever.lib.tools import open_fr, open_csvw, open_fw

warnings.filterwarnings("ignore")

TEST_ENGINES = dict()


Expand Down Expand Up @@ -225,7 +226,7 @@ def sqlite2csv(input_file, output_file, table_name=None, encoding=ENCODING):
return output_file


def xml2csv(input_file, outputfile=None, header_values=None, row_tag="row"):
def xml2csv_test(input_file, outputfile=None, header_values=None, row_tag="row"):
"""Convert xml to csv.
Function is used for only testing and can handle the file of the size.
Expand Down Expand Up @@ -377,3 +378,50 @@ def set_proxy():
for i in proxies:
os.environ[i] = os.environ[proxy]
break


def xml2dict(data, node, level):
    """Collect the values found below ``node`` into per-tag lists in ``data``.

    ``data`` is mutated in place and nothing is returned. Every tag that
    appears below ``node`` gets a key in ``data`` (created as an empty list
    the first time it is seen), so pure container tags end up with an empty
    list.

    For each direct child of ``node`` the attribute mapping (when non-empty)
    and the text (when it is not blank) are gathered per tag. They are then
    stored as follows:

    * if the children of ``node`` yielded values for exactly one tag, each
      value is appended individually to ``data[tag]``;
    * otherwise one entry per tag is appended: the single value itself, or
      the list of values when the tag occurred more than once.

    data: mapping of tag name -> list of values, updated in place.
    node: ``xml.etree.ElementTree.Element`` whose children are processed.
    level: recursion depth; it is only incremented and passed on to the
        recursive call and does not influence the result.
    """
    vals = {}
    for child in node:
        key = child.tag.strip()
        # Register the tag even when it carries no value of its own.
        data.setdefault(key, [])
        if child.attrib:
            vals.setdefault(key, []).append(child.attrib)
        if child.text and child.text.strip():
            # The text is stored as-is; strip() is only used for the blank check.
            vals.setdefault(key, []).append(child.text)
        # Use len() to test for sub-elements: truth-testing an Element is
        # deprecated and its result is not guaranteed to stay the same.
        if len(child) > 0:
            xml2dict(data, child, level + 1)

    single_tag = len(vals) == 1
    for key, found in vals.items():
        if single_tag:
            data[key].extend(found)
        else:
            data[key].append(found if len(found) > 1 else found[0])


def xml2csv(input_file, output_file, header_values=None, empty_rows=1, encoding=ENCODING):
    """Convert an XML file to a CSV file.

    The XML tree is flattened with ``xml2dict`` into one list of values per
    tag; each tag becomes a CSV column.

    input_file: path (or file object) of the XML document to parse.
    output_file: destination passed to ``DataFrame.to_csv``.
    header_values: accepted for interface compatibility; currently unused.
    empty_rows: when greater than zero, the value-less ``"row"`` container
        column is dropped from the output. It is also passed to
        ``xml2dict`` as the starting level.
    encoding: encoding used when writing the CSV.

    Returns ``output_file``.
    """
    root = ET.parse(input_file).getroot()
    columns = OrderedDict()
    xml2dict(columns, root, empty_rows)

    # Drop the "row" container column once. Popping it unconditionally inside
    # a range(empty_rows) loop raised KeyError for empty_rows > 1 and for
    # documents that contain no <row> tag at all.
    if empty_rows > 0:
        columns.pop("row", None)

    # Build with tags as the index and transpose, so each tag is a column.
    frame = pd.DataFrame.from_dict(columns, orient='index').transpose()
    frame.to_csv(output_file, index=False, encoding=encoding)
    return output_file
2 changes: 1 addition & 1 deletion retriever/lib/load_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def read_json(json_file):

if isinstance(json_object, dict) and "resources" in json_object.keys():
# Note::formats described by frictionless data may need to change
tabular_exts = {"csv", "tab", "geojson", "sqlite", "db", "json"}
tabular_exts = {"csv", "tab", "geojson", "sqlite", "db", "json", "xml"}
vector_exts = {"shp", "kmz"}
raster_exts = {"tif", "tiff", "bil", "hdr", "h5", "hdf5", "hr", "image"}
for resource_item in json_object["resources"]:
Expand Down
14 changes: 14 additions & 0 deletions retriever/lib/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,20 @@ def process_tables(self, table_obj, url):

self.engine.process_json2csv(src_path, path_to_csv, schema_fields)

if hasattr(table_obj, "xml_data"):
src_path = self.engine.format_filename(table_obj.xml_data)
path_to_csv = self.engine.format_filename(table_obj.path)
self.engine.download_file(url, table_obj.xml_data)
schema_fields = None
empty_rows = 1
if hasattr(table_obj, "empty_rows"):
empty_rows = table_obj.empty_rows
if hasattr(table_obj, "schema") and hasattr(table_obj.schema, "fields"):
if table_obj.schema.fields:
schema_fields = table_obj.schema.fields

self.engine.process_xml2csv(src_path, path_to_csv, schema_fields, empty_rows)

if hasattr(table_obj, "path"):
self.engine.auto_create_table(table_obj, url=url, filename=table_obj.path)
else:
Expand Down
2 changes: 1 addition & 1 deletion scripts/mammal_masses.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
{
"dialect": {
"missingValues": [
-999
-999, "None"
],
"header_rows": 0
},
Expand Down
10 changes: 5 additions & 5 deletions test/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
from distutils.dir_util import copy_tree
from imp import reload

import pytest
import retriever as rt
from retriever.engines import engine_list
from retriever.lib.defaults import DATA_DIR
from retriever.lib.load_json import read_json

import pytest
from retriever.lib.engine_tools import getmd5
from retriever.engines import engine_list
from retriever.lib.load_json import read_json

# Set postgres password, Appveyor service needs the password given
# The Travis service obtains the password from the config file.
Expand Down Expand Up @@ -55,6 +54,7 @@
('bird_size', '98dcfdca19d729c90ee1c6db5221b775'),
('mammal_masses', '6fec0fc63007a4040d9bbc5cfcd9953e'),
('portal-project-teaching', 'f81620d5f5550b81062e427542e96fa5'),
('county-emergency-management-offices', '75fcadc47cf38f3650a7686e074c7211'),
('nuclear-power-plants', 'b932543c4fb311357a9616a870226a6b')
]

Expand Down Expand Up @@ -222,7 +222,7 @@ def test_mysql_regression(dataset, expected, tmpdir):

# xml_engine is failing for nuclear-power-plants
# dataset as it contains a special character
@pytest.mark.parametrize("dataset, expected", db_md5[:4])
@pytest.mark.parametrize("dataset, expected", db_md5[:5])
def test_xmlengine_regression(dataset, expected, tmpdir):
"""Check for xmlengine regression."""
xml_engine.opts = {
Expand Down
38 changes: 19 additions & 19 deletions test/test_retriever.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""Tests for the Data Retriever"""
import os
import subprocess
import random
import subprocess

import pytest
import requests

import retriever as rt
from retriever.lib.engine import Engine
from retriever.lib.table import TabularDataset
from retriever.lib.templates import BasicTextTemplate
from retriever.lib.cleanup import correct_invalid_value
from retriever.lib.engine import Engine
from retriever.lib.engine_tools import getmd5
from retriever.lib.engine_tools import xml2csv
from retriever.lib.engine_tools import json2csv
from retriever.lib.engine_tools import xml2csv_test
from retriever.lib.table import TabularDataset
from retriever.lib.templates import BasicTextTemplate

try:
from retriever.lib.engine_tools import geojson2csv
Expand Down Expand Up @@ -50,6 +50,10 @@
("null_data_json", ["""[{"User":"Alex","id":"US1","Age":"25","kt":"2.0","qt":"1.00"},{"User":"Tom","id":"US2","Age":"20","kt":"0.0","qt":"1.0"},{"User":"Dan","id":"44","Age":"2","kt":"0","qt":"1"},{"User":"Kim","id":"654","Age":"","kt":"","qt":""}]"""], ["User", "id", "Age", "kt", "qt"], None, ['User,id,Age,kt,qt', 'Alex,US1,25,2.0,1.00', 'Tom,US2,20,0.0,1.0', 'Dan,44,2,0,1', 'Kim,654,,,'])
]

xml2csv_dataset = [
("simple_xml", ["""<root><row><User>Alex</User><Country>US</Country><Age>25</Age></row><row><User>Ben</User><Country>US</Country><Age>24</Age></row></root>"""], ["User", "Country", "Age"], 1, ['User,Country,Age', 'Alex,US,25', 'Ben,US,24'])
]

# Main paths
HOMEDIR = os.path.expanduser('~')
file_location = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -583,25 +587,21 @@ def test_sqlite2csv(test_name, db_name, sqlite_data_url, table_name, expected):
os.remove(db_name)
assert header_val == expected

def test_xml2csv():
@pytest.mark.parametrize("test_name, xml_data, header_values, empty_rows, expected", xml2csv_dataset)
def test_xml2csv(test_name, xml_data, header_values, empty_rows, expected):
"""Test xml2csv function.
Creates an XML file and checks the CSV output generated from it.
"""
xml_file = create_file(['<root>', '<row>',
'<User>Alex</User>',
'<Country>US</Country>',
'<Age>25</Age>', '</row>',
'<row>', '<User>Ben</User>',
'<Country>US</Country>',
'<Age>24</Age>',
'</row>', '</root>'], 'output.xml')

output_xml = xml2csv(xml_file, "output_xml.csv",
header_values=["User", "Country", "Age"])
xml_file = create_file(xml_data, 'output.xml')
input_file = xml_file
outputfile = "output_xml.csv"

output_xml = xml2csv_test(input_file, outputfile, header_values, row_tag="row")

obs_out = file_2list(output_xml)
os.remove(output_xml)
assert obs_out == ['User,Country,Age', 'Alex,US,25', 'Ben,US,24']
assert obs_out == expected


def test_sort_file():
Expand Down

0 comments on commit b3a62e2

Please sign in to comment.