In [None]:
from abc import ABC, abstractmethod
import attr
from enum import Enum
import ipywidgets as widgets
import os
import pickle
import plotly.express as px
import plotly.graph_objects as go

from datetime import date, datetime, timedelta
from typing import List, Optional

In [None]:
@attr.s(auto_attribs=True)
class Activity:
    """A single calendar entry running from ``start`` to ``end``."""

    category: str  # the calendar this activity belongs to
    start: datetime
    end: datetime
    name: Optional[str] = None
    description: Optional[str] = None

    @property
    def duration(self):
        """Length of the activity in seconds (``timedelta.total_seconds``)."""
        return (self.end - self.start).total_seconds()

In [None]:
class TimePeriodType(Enum):
    """Time periods that can be analyzed."""

    DAILY = "daily"
    WEEKLY = "weekly"
    OVERALL = "overall"

    @classmethod
    def all_to_list(cls):
        """Return the string values of all time periods, in declaration order.

        Derived from the enum members so the list cannot drift out of sync
        when a new period type is added.
        """
        return [member.value for member in cls]


@attr.s(auto_attribs=True)
class TimePeriod:
    """One time span together with the granularity that produced it."""

    start: datetime  # first moment/day of the span
    end: datetime  # last moment/day of the span
    type: TimePeriodType  # granularity (daily / weekly / overall)


class TimePeriods:
    """An ordered set of consecutive time periods covering a date range."""

    def __init__(self, start, end, type):
        """Split the inclusive range ``start``..``end`` into periods of ``type``.

        Args:
            start: First day of the range.
            end: Last day of the range (inclusive).
            type: The ``TimePeriodType`` to split by. ``DAILY`` yields one
                single-day period per date; ``WEEKLY`` yields periods that
                start at ``start`` and then each run through the following
                Sunday, with the last one clipped to ``end``; ``OVERALL``
                yields one period spanning the whole range.

        Raises:
            ValueError: If ``type`` is not a supported ``TimePeriodType``.
        """
        if type == TimePeriodType.DAILY:
            date_list = [
                start + timedelta(days=x) for x in range((end - start).days + 1)
            ]
            self.periods = [TimePeriod(d, d, type) for d in date_list]
        elif type == TimePeriodType.WEEKLY:
            self.periods = []
            current_date = start
            # weekday() is 0 for Monday, so 6 - weekday() reaches Sunday.
            # Each week is clipped to ``end`` so no period extends past the
            # range and none starts after it (an empty range yields no periods).
            while current_date <= end:
                week_end = current_date + timedelta(days=6 - current_date.weekday())
                self.periods.append(
                    TimePeriod(current_date, min(week_end, end), type)
                )
                current_date = week_end + timedelta(days=1)
        elif type == TimePeriodType.OVERALL:
            self.periods = [TimePeriod(start, end, type)]
        else:
            raise ValueError(f"Unknown time period type {type}")

    def get_period(self, date):
        """Return the period containing ``date`` (bounds inclusive), or ``None``."""
        for period in self.periods:
            if period.start <= date <= period.end:
                return period

        return None

In [None]:
class AggregationType(Enum):
    """The way data should be aggregated and displayed."""

    TOTAL = "total"
    PERCENTAGE = "percentage"

    @classmethod
    def all_to_list(cls):
        """Return the string values of all aggregation types, in declaration order.

        Derived from the enum members so the list cannot drift out of sync
        when a new aggregation type is added.
        """
        return [member.value for member in cls]

In [None]:
@attr.s
class ActivityVisualization(ABC):
    """Base class for a plotly visualization of calendar data.

    Subclasses implement :meth:`process`, which must fill ``self.data`` with a
    mapping of series name -> sequence of values, one value per period from
    :meth:`get_periods`. :meth:`refresh` stores the current settings and then
    runs ``process`` -> ``aggregate_data`` -> ``plot``.
    """

    # NOTE(review): these annotated defaults are not ``attr.ib()`` fields, so
    # ``@attr.s`` does not turn them into constructor arguments; they are plain
    # class attributes that ``refresh`` shadows on the instance. ``date.today()``
    # is evaluated once, when the class is defined.
    # NOTE(review): ``activities`` is annotated as a list but may be a mapping
    # (e.g. per calendar) at runtime; only its truthiness is used here — confirm.
    activities: List[Activity] = list()
    start: datetime = date.today()
    end: datetime = date.today()
    calendars: List[str] = list()
    period: TimePeriodType = TimePeriodType.DAILY
    aggregation: AggregationType = AggregationType.TOTAL

    @property
    def title(self):
        return "Visualization"

    @property
    def description(self):
        return None

    def get_periods(self):
        """Return the ``TimePeriods`` for the current range and period type.

        ``self.period`` may be a ``TimePeriodType`` member or its string value.
        """
        return TimePeriods(self.start, self.end, TimePeriodType(self.period))

    def refresh(self, activities, start, end, calendars, period, aggregation):
        """Store new settings, then recompute, aggregate and plot the data.

        ``period`` and ``aggregation`` may be enum members or their string
        values.
        """
        self.activities = activities
        self.start = start
        self.end = end
        self.calendars = calendars
        self.period = period
        self.aggregation = aggregation

        self.process()
        self.aggregate_data()
        self.plot()

    def aggregate_data(self):
        """Transform ``self.data`` in place according to ``self.aggregation``.

        ``TOTAL`` leaves the data unchanged. ``PERCENTAGE`` replaces every value
        with its share (0-100) of the total across all series for the same
        period; periods whose total is zero get 0.
        """
        # Normalise so both the enum member and its string value are accepted,
        # matching how get_periods() handles ``self.period``.
        if AggregationType(self.aggregation) is not AggregationType.PERCENTAGE:
            return

        # Total across all series for each period, keyed by period index.
        totals = {}
        for values in self.data.values():
            for i, value in enumerate(values):
                totals[i] = totals.get(i, 0) + value

        self.data = {
            series: [
                value / totals[i] * 100.0 if totals[i] else 0
                for i, value in enumerate(values)
            ]
            for series, values in self.data.items()
        }

    def plot(self):
        """
        Plot the visualization.

        Uses a stacked bar chart by default: one bar per period, one segment
        per series in ``self.data``. When there are no activities, shows a
        small placeholder figure with a "No data to display" message instead.
        """
        if self.activities:
            dates = [period.start for period in self.get_periods().periods]
            # Plot every series except the x-axis column itself, and build the
            # frame as a copy so ``self.data`` is not mutated by plotting.
            series_names = [name for name in self.data if name != "Date"]
            plot_data = {**self.data, "Date": dates}
            # Only add the subtitle when there is a description to show.
            subtitle = (
                f" <br><sup>{self.description}</sup>" if self.description else ""
            )

            fig = px.bar(
                plot_data,
                x="Date",
                y=series_names,
                labels={"value": "Time spent"},
                color_discrete_sequence=px.colors.qualitative.Dark24,
            )
            fig.update_layout(
                title=go.layout.Title(
                    text=f"{self.title}{subtitle}",
                    xref="paper",
                    x=0,
                )
            )
            fig.layout.template = "plotly_dark"
            fig.show()
        else:
            layout = go.Layout(
                height=100,
                width=300,
                # Hide the empty axes so only the message is visible.
                xaxis={"visible": False},
                yaxis={"visible": False},
                annotations=[
                    go.layout.Annotation(
                        text="No data to display",
                        xref="paper",
                        yref="paper",
                        showarrow=False,
                        font={"family": "Courier"},
                    )
                ],
            )
            # No traces: the figure carries only the annotation.
            fig = go.FigureWidget(data=[], layout=layout)
            fig.show()

    @abstractmethod
    def process(self):
        """Compute ``self.data`` for the current settings; subclasses implement this."""
        raise NotImplementedError