
Commit

Improved docstrings, tests passing, added deps to conda in travis script, made pyrate repo and algos always load
willu47 committed Mar 7, 2016
1 parent 3df073a commit 0bd4c3b
Showing 8 changed files with 130 additions and 50 deletions.
81 changes: 78 additions & 3 deletions pyrate/algorithms/aisparser.py
@@ -19,7 +19,7 @@ def parse_timestamp(s):
def int_or_null(s):
    if len(s) == 0:
        return None
    else:
        return int(s)

def float_or_null(s):
@@ -39,6 +39,17 @@ def longstr(s):
    return s

def set_null_on_fail(row, col, test):
""" Helper function which sets the column in a row of data to null on fail
Arguments
---------
row : dict
A dictionary of the fields
col : str
The column to check
test : func
One of the validation functions in pyrate.utils
"""
if not row[col] == None and not test(row[col]):
row[col] = None
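
A minimal usage sketch (not part of the commit; the column name and the stand-in validator below are invented, the real checks live in pyrate.utils):

    def valid_latitude(lat):          # stand-in validator, not from pyrate.utils
        return -90 <= lat <= 90

    row = {'latitude': 123.4}
    set_null_on_fail(row, 'latitude', valid_latitude)
    assert row['latitude'] is None    # 123.4 failed the check, so it was nulled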

@@ -110,8 +121,22 @@ def xml_name_to_csv(name):
    return AIS_CSV_COLUMNS[AIS_XML_COLNAMES.index(name)]

def parse_raw_row(row):
"""Parse values from row, returning a new dict with values
converted into appropriate types. Throw an exception to reject row"""
"""Parse values from row, returning a new dict with converted values
Parse values from row, returning a new dict with converted values
converted into appropriate types. Throw an exception to reject row
Arguments
---------
row : dict
A dictionary of headers and values from the csv file
Returns
-------
converted_row : dict
A dictionary of headers and values converted using the helper functions
"""
    converted_row = {}
    converted_row[MMSI] = int_or_null(row[MMSI])
    converted_row[TIME] = parse_timestamp(row[TIME])
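
For illustration only (values invented, not taken from the commit), the per-column helpers behave like this:

    int_or_null('')            # -> None
    int_or_null('235098383')   # -> 235098383
    float_or_null('51.5')      # -> 51.5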
@@ -256,6 +281,36 @@ def sqlworker(q, table):
logging.info("Finished building indices, time elapsed = %fs", time.time() - start)

def parse_file(fp, name, ext, baddata_logfile, cleanq, dirtyq, source=0):
""" Parses a file containing AIS data
Arguments
---------
fp : str
Filepath of file to be parsed
name : str
Name of file to be parsed
ext : str
Extension, either '.csv' or '.xml'
baddata_logfile : str
Name of the logfile
cleanq :
Queue for messages to be inserted into clean table
dirtyq :
Queue for messages to be inserted into dirty table
source : int, optional, default=0
0 is satellite, 1 is terrestrial
Returns
-------
invalid_ctr : int
Number of invalid rows
clean_ctr : int
Number of clean rows
dirty_ctr : int
Number of dirty rows
time_elapsed : time
The time elapsed since starting the parse_file procedure
"""
    filestart = time.time()
    logging.info("Parsing "+ name)

@@ -316,6 +371,26 @@ def parse_file(fp, name, ext, baddata_logfile, cleanq, dirtyq, source=0):
    return (invalid_ctr, clean_ctr, dirty_ctr, time.time() - filestart)
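
A hedged example of calling it (the file name, log name and queues below are invented; in the tool the driver wires these up from the repository and the sqlworker threads):

    import queue

    cleanq, dirtyq = queue.Queue(), queue.Queue()
    with open('aisdata_20160101.csv', 'r', encoding='utf-8') as fp:
        invalid, clean, dirty, elapsed = parse_file(
            fp, 'aisdata_20160101.csv', '.csv', 'baddata.log',
            cleanq, dirtyq, source=0)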

def readcsv(fp):
""" Returns a dictionary of the subset of columns required
Reads each line in CSV file, checks if all columns are available,
and returns a dictionary of the subset of columns required
(as per AIS_CSV_COLUMNS).
If row is invalid (too few columns),
returns an empty dictionary.
Arguments
---------
fp : str
File path
Yields
------
rowsubset : dict
A dictionary of the subset of columns as per `columns`
"""
    # fix for large field error. Specify max field size as the maximum convertible int value.
    # source: http://stackoverflow.com/questions/15063936/csv-error-field-larger-than-field-limit-131072
    max_int = sys.maxsize
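
The rest of the fix is not visible in this hunk; the usual pattern from the linked answer (a sketch, not necessarily the repository's exact code) keeps shrinking the value until csv accepts it:

    import csv
    import sys

    max_int = sys.maxsize
    while True:
        try:
            csv.field_size_limit(max_int)
            break
        except OverflowError:
            # value too large for the underlying C long, try a smaller one
            max_int = int(max_int / 10)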
51 changes: 27 additions & 24 deletions pyrate/algorithms/vesselimporter.py
@@ -1,3 +1,6 @@
""" Extracts a subset of clean ships into ais_extended tables
"""
import logging
import time
import threading
@@ -35,9 +38,9 @@ def filter_intervals(interval):

def filter_good_ships(aisdb):
"""Generate a set of imo numbers and (mmsi, imo) validity intervals
Generate a set of imo numbers and (mmsi, imo) validity intervals
for ships which are deemed to be 'clean'.
for ships which are deemed to be 'clean'.
A clean ship is defined as one which:
* Has valid MMSI numbers associated with it.
* For each MMSI number, the period of time it is associated with this IMO
@@ -52,15 +55,15 @@ def filter_good_ships(aisdb):
    valid_imos :
        A set of valid imo numbers
    imo_mmsi_intervals :
        A list of (mmsi, imo, start, end) tuples, describing the validity
        intervals of each (mmsi, imo) pair
    """

    with aisdb.conn.cursor() as cur:
        cur.execute("SELECT distinct imo from {}".format(aisdb.imolist.get_name()))
        imo_list = [row[0] for row in cur.fetchall() if valid_imo(row[0])]
        logging.info("Checking %d IMOs", len(imo_list))

        valid_imos = []
        imo_mmsi_intervals = []
@@ -85,9 +88,9 @@ def filter_good_ships(aisdb):
                if last_end != None and start < last_end:
                    valid = False
                    #logging.info("IMO: %s, overlapping MMSI intervals", imo)
                    break
                last_end = end

            if valid:
                # check for other users of this mmsi number
                mmsi_list = [row[0] for row in mmsi_ranges]
@@ -103,14 +106,14 @@ def filter_good_ships(aisdb):
                else:
                    pass
                    #logging.info("IMO: %s, reuse of MMSI", imo)

    return (valid_imos, imo_mmsi_intervals)
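
For orientation (values invented, not from any database), the second element of the return value has this shape:

    from datetime import datetime

    imo_mmsi_intervals = [
        (235098383, 9321483,
         datetime(2016, 1, 1, 0, 0), datetime(2016, 1, 31, 23, 59)),
    ]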

def cluster_table(aisdb, table):
"""Performs a clustering of the postgresql table on the MMSI index.
"""Performs a clustering of the postgresql table on the MMSI index.
This process significantly improves the runtime of extended table generation.
"""
    with aisdb.conn.cursor() as cur:
        index_name = table.name.lower() + "_mmsi_idx"
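
For context, the clustering boils down to a single PostgreSQL CLUSTER statement; a sketch (the table name ais_clean is invented for illustration, the index name follows the convention above):

    with aisdb.conn.cursor() as cur:
        cur.execute("CLUSTER ais_clean USING ais_clean_mmsi_idx")
    aisdb.conn.commit()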
@@ -120,8 +123,8 @@ def cluster_table(aisdb, table):

def generate_extended_table(aisdb, intervals, n_threads=2):

logging.info("Inserting %d squeaky clean MMSIs", len(intervals))
logging.info("Inserting %d squeaky clean MMSIs", len(intervals))

start = time.time()

interval_q = queue.Queue()
@@ -189,34 +192,34 @@ def insert_message_stream(aisdb, interval, msg_stream):
    aisdb.extended.insert_rows_batch(valid + artificial)

    # mark the work we've done
    aisdb.action_log.insert_row({'action': "import",
                                 'mmsi': mmsi,
                                 'ts_from': start,
                                 'ts_to': end,
                                 'count': len(valid)})
    aisdb.action_log.insert_row({'action': "outlier detection (noop)",
                                 'mmsi': mmsi,
                                 'ts_from': start,
                                 'ts_to': end,
                                 'count': len(invalid)})
    aisdb.action_log.insert_row({'action': "interpolation (noop)",
                                 'mmsi': mmsi,
                                 'ts_from': start,
                                 'ts_to': end,
                                 'count': len(artificial)})
    upsert_interval_to_imolist(aisdb, mmsi, imo, start, end)

def get_remaining_interval(aisdb, mmsi, imo, start, end):
    with aisdb.conn.cursor() as cur:
        try:
            cur.execute("SELECT tsrange(%s, %s) - tsrange(%s, %s) * tsrange(first_seen - interval '1 second', last_seen + interval '1 second') FROM {} WHERE mmsi = %s AND imo = %s".format(aisdb.clean_imolist.name),
                        [start, end, start, end, mmsi, imo])
            row = cur.fetchone()
            if not row is None:
                sub_interval = row[0]
                if sub_interval.isempty:
                    return None
                else:
                    return sub_interval.lower, sub_interval.upper
            else:
                return (start, end)
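
A standalone sketch of the tsrange arithmetic this query relies on (dates invented; psycopg2 is assumed to return a range object with .lower, .upper and .isempty, as used above). Subtracting the already-covered interval leaves the portion still to import:

    from datetime import datetime

    with aisdb.conn.cursor() as cur:
        cur.execute("SELECT tsrange(%s, %s) - tsrange(%s, %s)",
                    [datetime(2016, 1, 1), datetime(2016, 3, 1),
                     datetime(2016, 1, 1), datetime(2016, 2, 1)])
        remaining = cur.fetchone()[0]
        # remaining covers 2016-02-01 00:00 through 2016-03-01 00:00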
@@ -232,10 +235,10 @@ def upsert_interval_to_imolist(aisdb, mmsi, imo, start, end):
        count = cur.fetchone()[0]
        if count == 1:
            cur.execute("""UPDATE {} SET
                first_seen = LEAST(first_seen, %s),
                last_seen = GREATEST(last_seen, %s)
                WHERE mmsi = %s AND imo = %s""".format(aisdb.clean_imolist.name),
                [start, end, mmsi, imo])
        elif count == 0:
            aisdb.clean_imolist.insert_row({'mmsi': mmsi, 'imo': imo, 'first_seen': start,
                                            'last_seen': end})
31 changes: 15 additions & 16 deletions pyrate/cli.py
@@ -11,49 +11,48 @@ def main():
"""
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# load tool components
config = loader.DEFAULT_CONFIG
    config.read(['default.conf'])
    l = loader.Loader(config)

    def list_components(args):
        print("{} repositories:".format(len(l.get_data_repositories())))
        for repository in l.get_data_repositories():
            print("\t" + repository)

        print("{} algorithms:".format(len(l.get_algorithms())))

        for algorithm in l.get_algorithms():
            print("\t" + algorithm)

    def execute_repo_command(args):
        l.execute_repository_command(args.repo, args.cmd)

    def execute_algorithm(args):
        l.execute_algorithm_command(args.alg, args.cmd)

    # set up command line parser
    parser = argparse.ArgumentParser(description="AIS Super tool")
    subparsers = parser.add_subparsers(help='available commands')

    parser_list = subparsers.add_parser('list', help='list loaded data repositories and algorithms')
    parser_list.set_defaults(func=list_components)

    for r in l.get_data_repositories():
        repo_parser = subparsers.add_parser(r, help='commands for '+ r +' repository')
        repo_subparser = repo_parser.add_subparsers(help=r+' repository commands.')
        for cmd, desc in l.get_repository_commands(r):
            cmd_parser = repo_subparser.add_parser(cmd, help=desc)
            cmd_parser.set_defaults(func=execute_repo_command, cmd=cmd, repo=r)

    for a in l.get_algorithms():
        alg_parser = subparsers.add_parser(a, help='commands for algorithm '+ a +'')
        alg_subparser = alg_parser.add_subparsers(help=a+' algorithm commands.')
        for cmd, desc in l.get_algorithm_commands(a):
            alg_parser = alg_subparser.add_parser(cmd, help=desc)
            alg_parser.set_defaults(func=execute_algorithm, cmd=cmd, alg=a)

    args = parser.parse_args()
    if 'func' in args:
        args.func(args)
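
The dispatch relies on argparse's set_defaults(func=...) idiom; a minimal self-contained sketch of that pattern (names invented, not from the tool):

    import argparse

    parser = argparse.ArgumentParser(description="dispatch sketch")
    subparsers = parser.add_subparsers(help='available commands')
    parser_list = subparsers.add_parser('list', help='list things')
    parser_list.set_defaults(func=lambda args: print("listing"))

    args = parser.parse_args(['list'])
    if 'func' in args:
        args.func(args)   # calls the handler selected by the subcommand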
4 changes: 3 additions & 1 deletion pyrate/loader.py
@@ -33,7 +33,7 @@ def load_all_modules(paths):
DEFAULT_CONFIG.set('globals', 'algos', 'pyrate/algorithms')

class Loader:
"""The Loader joins together data repositories and algorithms,
"""The Loader joins together data repositories and algorithms,
and executes operations on them."""

    def __init__(self, config):
@@ -45,6 +45,7 @@ def __init__(self, config):

        repopaths = str(config.get('globals', 'repos'))
        repopaths = repopaths.split(',')
        repopaths.append('pyrate/repositories')

        # load repo drivers from repopaths
        repo_drivers = load_all_modules(repopaths)
@@ -65,6 +66,7 @@ def __init__(self, config):

        algopaths = str(config.get('globals', 'algos'))
        algopaths = algopaths.split(',')
        algopaths.append('pyrate/algorithms')

        # load algorithms from algopaths
        algorithms = load_all_modules(algopaths)
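
Because the built-in paths are now always added, a user config only needs to list extras; a hedged sketch of the configparser usage (section and keys as above, the extra plugin path is invented):

    import configparser

    config = configparser.ConfigParser()
    config['globals'] = {'repos': 'my_plugins/repositories',
                         'algos': 'my_plugins/algorithms'}
    repopaths = str(config.get('globals', 'repos')).split(',')
    repopaths.append('pyrate/repositories')   # mirrors the always-load change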
5 changes: 2 additions & 3 deletions pyrate/repositories/file.py
@@ -45,7 +45,7 @@ def status(self):

    def iterfiles(self):
"""
Iterate files in this file repository. Returns a generator of 3-tuples,
Iterate files in this file repository. Returns a generator of 3-tuples,
containing a handle, filename and file extension of the current opened file.
"""
logging.debug("Iterating files in "+ self.root)
@@ -55,7 +55,6 @@ def iterfiles(self):
            for filename in files:
                _, ext = os.path.splitext(filename)
                if self.allowed_extensions == None or ext in self.allowed_extensions:
                    # hitting errors with decoding the data, utf-8 seems to sort it
                    with open(os.path.join(root, filename), 'r', encoding='utf-8') as fp:
                        yield (fp, filename, ext)
                # zip file auto-extract
@@ -66,7 +65,7 @@ def iterfiles(self):
                                _, ext = os.path.splitext(zname)
                                if self.allowed_extensions == None or ext in self.allowed_extensions:
                                    with z.open(zname, 'r') as fp:
                                        # zipfile returns a binary file, so we require a
                                        # TextIOWrapper to decode it
                                        yield (io.TextIOWrapper(fp, encoding='ascii'), zname, ext)
                    except (zipfile.BadZipFile, RuntimeError) as error:
2 changes: 2 additions & 0 deletions pyrate/repositories/sql.py
@@ -118,8 +118,10 @@ def insert_rows_batch(self, rows):
        # check there are rows in insert
        if len(rows) == 0:
            return
        # logging.debug("Row to insert: {}".format(rows[0]))
        with self.db.conn.cursor() as cur:
            columnlist = '(' + ','.join([c.lower() for c in rows[0].keys()]) + ')'
            # logging.debug("Using columns: {}".format(columnlist))
            tuplestr = "(" + ",".join("%({})s".format(i) for i in rows[0]) + ")"
            # create a single query to insert list of tuples
            # note that mogrify generates a binary string which we must first decode to ascii.
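
The rest of the batch insert is not shown in the hunk, but the comments describe psycopg2's mogrify-based batching; a minimal sketch of that pattern, given an open cursor cur and the columnlist/tuplestr built above (table name invented):

    values = ','.join(cur.mogrify(tuplestr, row).decode('ascii') for row in rows)
    cur.execute("INSERT INTO ais_clean " + columnlist + " VALUES " + values)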
