Skip to content

Commit

Permalink
updated plotter to use sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
thigg committed Mar 20, 2024
1 parent 946d1f1 commit 00a033f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 117 deletions.
165 changes: 48 additions & 117 deletions plotter/main.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,99 @@
import argparse
import functools
import itertools
import json
import logging
import lzma
import os
import sqlite3
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
from typing import Dict, List

import brotli
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description='accumulate fahrpreis data')
parser.add_argument('--accufile',
help='file with latest accu data (/tmp/fahrpreise_akku) usefull for working on plotting')
parser.add_argument('start_station',
help='start station id')
parser.add_argument('end_station', help='end station id')
parser.add_argument('--dbfile', help='path to the sqlitefile with the data', required=True)
parser.add_argument('--start_station', help='start station id', required=True)
parser.add_argument('--end_station', help='end station id', required=True)
parser.add_argument('--plot_timeframe_past', help='oldest travel start date on the plot, days relative to now',
default=60)
parser.add_argument('--plot_timeframe_future', help='newest travel start date on the plot, days relative to now',
default=10)
parser.add_argument('--plot_timeframe_date', help='what now is for timeframe',
default=datetime.now())

args = parser.parse_args()


def handle_file(path: str):
"""
reads a record from the crawler and transofrms it into a list with connections and prices
:param path: the file to read
:return: a list of connections with prices and when the price was queried
"""
try:
result = list()
with open(path, 'rb') as in_file:
decompressor = brotli.Decompressor()
s: str = ''
read_chunk = functools.partial(in_file.read, )
for data in iter(read_chunk, b''):
s += bytes.decode(decompressor.process(data), 'utf-8')
dict = json.loads(s)
queried_at = dict['queried_at']
for day in dict['data']:
for travel in day:
price = travel['price']['amount']
start_station = travel['legs'][0]['origin']['id']
start_time = travel['legs'][0]['departure']
end_station = travel['legs'][-1]['origin']['id']
end_time = travel['legs'][-1]['departure']
dict_key = "$".join([start_station, start_time, end_station, end_time])
result.append((dict_key, {"queried_at": queried_at, "price": price}))
return result
except Exception as e:
print("Could not read file %s, %s, size: %d" % (path, e, os.stat(path).st_size))
return []

def accumulate_sqlite(filename: str, start_station: int, end_station: int, timeframe_start: datetime,
timeframe_end: datetime) -> dict[str, list[tuple[int, int]]]:
conn = sqlite3.connect(filename)
cursor = conn.cursor()

def accumulate_data() -> dict[str, list[dict[str, object]]]:
"""
preprocesses the raw data into the data we need for the plot
:return: a dictionary of all travels (a connection at a specific datetime) with a list of prices and when the price was queried
"""
starttime = datetime.now()
result: dict[str, list[dict[str, object]]] = defaultdict(list)
with os.scandir("/tmp/fahrpreise/") as dirIterator:
with ProcessPoolExecutor() as executor:
resultlist = list(
itertools.chain.from_iterable(
executor.map(handle_file, (str(entry.path) for entry in dirIterator if
entry.name.endswith('.brotli') and entry.is_file()))))
print(f"got resultlist {len(resultlist)}")
for item in resultlist:
if item:
key, value = item
result[key].append(value)
# Create a dictionary to store accumulated prices
prices_dict = defaultdict(lambda: [])

print(f"accumulation took {datetime.now() - starttime}")
return result
# Execute SQL query to fetch data from the table
cursor.execute("SELECT `when`, `price_cents`, `queried_at` FROM fahrpreise "
f"where `from` = {start_station} and `to` = {end_station} "
f"and `queried_at` >= {round(timeframe_start.timestamp() * 1000)} "
f"and `queried_at` <= {round(timeframe_end.timestamp() * 1000)}")

# Fetch all rows and accumulate prices
rows = cursor.fetchall()
for row in rows:
when = row[0]
price = row[1]
queried_at = row[2]
prices_dict[when].append((int(queried_at), int(price)))

isoformatstr = "%Y-%m-%dT%H:%M:%S.%fZ"
conn.close()
return prices_dict


def plot(result: Dict[str, List[Tuple[str, str]]], start_station_filter: str, end_station_filter: str,
starttime_after: datetime, starttime_before):
def plot(result: Dict[str, List[tuple[int, int]]]):
"""
:param result: the data to ingest for the plot
:param start_station_filter: which start station to consider
:param end_station_filter: which end station to consider
:param starttime_after: timeframe to plot lower limit
:param starttime_before: timeframe to plot upper limit
:return: shows the plot
"""
print(f"creating filtered plot for stations: {start_station_filter}, {end_station_filter}")
time_to_departure = []
# booking_date = []
departure_date = []
# y_axis2 = []
travel_price = []
travel_price_euro = []
# z_axis2 = []
recorded_connections: dict[tuple[str, str], int] = defaultdict(lambda: 0)
for i, travelprices in enumerate(result.items()):
try:
keystr = travelprices[0]
keystr_split = keystr.split("$")
start_station = keystr_split[0]
end_station = keystr_split[2]
recorded_connections[(start_station, end_station)] += 1
if start_station != start_station_filter or end_station != end_station_filter:
# print(f"skipping '{start_station}'({type(start_station)}) '{end_station}' because filters '{start_station_filter}'({type(start_station_filter)}) '{end_station_filter}'")
continue
starttime = datetime.strptime(keystr_split[1], isoformatstr)
endtime = datetime.strptime(keystr_split[3], isoformatstr)
when, price_records = travelprices
starttime = datetime.fromtimestamp(int(when) / 1000)
# z_axis2.append((endtime-starttime).total_seconds())
# y_axis2.append(starttime)
for data in travelprices[1]:
if starttime > starttime_after and starttime < starttime_before:
time_to_departure.append(
-(starttime - datetime.strptime(data["queried_at"], isoformatstr)).total_seconds() / (
60 * 60 * 24))
# booking_date.append(datetime.strptime(data["queried_at"], isoformatstr))
departure_date.append(starttime)
travel_price.append(data["price"])
price_record: Dict[str, int]
for queried_at, price in price_records:
queried_at_date = datetime.fromtimestamp(queried_at / 1000)
days_to_departure = (starttime - queried_at_date).total_seconds() / (60 * 60 * 24)
time_to_departure.append(-days_to_departure)
departure_date.append(starttime)
travel_price_euro.append(price / 100)
except:
logging.exception("exception while prepping plot %d %s", i, travelprices)
print("recorded connections: " + str(recorded_connections.items()))
print("resulting datapoints: %d" % len(travel_price))
print("resulting datapoints for filter: %d" % recorded_connections[(start_station_filter, end_station_filter)])
color_map = plt.cm.plasma
plt.rcParams['figure.figsize'] = [30, 50]
print("resulting datapoints: %d" % len(travel_price_euro))
plt.rcParams['figure.figsize'] = [12, 20]
fig, ax = plt.subplots(1)
prices = ax
price_records = ax
# plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%a %d.%m'))
plt.gca().yaxis.set_major_formatter(mdates.DateFormatter('%a %d.%m'))
# plt.gca().xaxis.set_major_locator(mdates.DayLocator())
plt.gca().yaxis.set_major_locator(mdates.DayLocator())
pcm = prices.scatter(time_to_departure, departure_date, c=travel_price, cmap=color_map, marker=".", s=1, vmax=75)
pcm = price_records.scatter(time_to_departure, departure_date, c=travel_price_euro, cmap=plt.colormaps["plasma"], marker=".", s=1,
vmax=75)
# fig.autofmt_xdate()
# prices.twinx().barh(y_axis2,z_axis2,height=0.1)
fig.colorbar(pcm, label="price", ax=prices)
fig.colorbar(pcm, label="price (Euro)", ax=price_records)
plt.gca().xaxis.set_label_text("tage bis abfahrt / wann ich buche")
plt.gca().yaxis.set_label_text("datum der reise / wann ich fahren möchte")
plt.show()


result = {}
if args.accufile:
print("reading data from tmpfile")
with lzma.open(args.accufile, "rt") as infile:
result = json.load(infile)
print("done reading infile")
else:
starttime = datetime.now()
result = accumulate_data()
startcompresstime = datetime.now()
with lzma.open("/tmp/fahrpreise_akku", "wt", preset=4) as outfile:
print("writing accufile")
json.dump(result, outfile)
print(
f"wrote accufile. whole process took total={datetime.now() - starttime} akku={startcompresstime - starttime} write={datetime.now() - startcompresstime}")

plot(result, args.start_station, args.end_station, datetime.now() - timedelta(days=args.plot_timeframe_past),
datetime.now() + timedelta(days=args.plot_timeframe_future))
timeframe_start: datetime = datetime.now() - timedelta(days=args.plot_timeframe_past)
timeframe_end: datetime = datetime.now() + timedelta(days=args.plot_timeframe_future)
result = accumulate_sqlite(args.dbfile, int(args.start_station), int(args.end_station), timeframe_start, timeframe_end)
plot(result)
1 change: 1 addition & 0 deletions plotter/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
matplotlib

0 comments on commit 00a033f

Please sign in to comment.