In [1]:
from dash import dcc, Dash, dash_table, html
from src.db import Database
from dash.dependencies import Input, Output, State
from itertools import chain
import pandas as pd

In [2]:
def load_database(db_path: str) -> "Database":
    """Load a serialized database from the given path.

    Args:
        db_path: Path passed straight to ``Database.load``.

    Returns:
        The populated ``Database`` object itself, not a bare DataFrame.
        Callers in this notebook read its ``.db``, ``.creation_time`` and
        ``.tag_version`` attributes.
    """
    # NOTE(review): the file is described as a pickle; if ``Database.load``
    # unpickles it, only load files from a trusted source -- confirm.
    database = Database()
    database.load(db_path)

    return database

def get_metadata(database: "Database") -> dict:
    """Collect the database's provenance attributes into a plain dict.

    Args:
        database: Loaded ``Database`` object exposing ``creation_time`` and
            ``tag_version`` attributes (not a bare DataFrame).

    Returns:
        A dict with the keys ``"creation_time"`` and ``"tag_version"``.
    """
    return {
        "creation_time": database.creation_time,
        "tag_version": database.tag_version,
    }

def get_tags(database: "Database") -> tuple:
    """Build a lower-cased tag -> DOIs mapping and the list of unique tags.

    The caller's ``database.db`` frame is left untouched: lower-casing is done
    on a derived frame rather than by overwriting the ``tag`` column.

    Args:
        database: ``Database`` object whose ``.db`` DataFrame has a ``tag``
            column holding a list of strings per row and a ``doi`` column.

    Returns:
        A ``(mapping, tags)`` tuple. ``mapping`` is a dict from lower-cased
        tag to the list of DOIs carrying that tag; ``tags`` is the list of
        its keys.
    """
    lowered = database.db["tag"].apply(lambda lst: [s.lower() for s in lst])

    # Work on a fresh two-column frame so the input DataFrame is not mutated.
    db_exploded = database.db[["doi"]].assign(tag=lowered).explode("tag")

    # Rows whose tag list was empty explode to NaN and are dropped by groupby.
    mapping = db_exploded.groupby("tag")["doi"].apply(list).to_dict()
    tags = list(mapping.keys())

    return mapping, tags

def prepare_database(database: "Database") -> pd.DataFrame:
    """Return a display-ready copy of the papers table.

    Drops the ``tag`` column from ``database.db`` and relabels the remaining
    columns as Title, Authors, Journal, Year, DOI.

    Args:
        database: ``Database`` object whose ``.db`` DataFrame contains a
            ``tag`` column plus five other columns.

    Returns:
        A new DataFrame; ``database.db`` itself is not modified.

    Raises:
        ValueError: Raised by pandas if the number of remaining columns is
            not exactly five.
    """
    table = database.db.drop(columns="tag")

    # The relabel is positional: it assumes the remaining columns already come
    # in title/authors/journal/year/doi order -- TODO confirm against the
    # schema produced by src.db.
    table.columns = ["Title", "Authors", "Journal", "Year", "DOI"]

    return table

In [3]:
def create_layout(app, database: pd.DataFrame, metadata: dict, mapping: dict, tags: list):
    """Attach the page layout and its two callbacks to a Dash app.

    Args:
        app: Dash application; its ``layout`` is assigned and two callbacks
            are registered on it.
        database: Display table with (at least) a ``DOI`` column; its columns
            become the DataTable columns.
        metadata: Dict with ``"creation_time"`` and ``"tag_version"`` keys.
        mapping: Dict from tag to the list of DOIs carrying that tag.
        tags: Tags offered in (and initially all selected in) the dropdown.

    Returns:
        The same ``app`` object that was passed in.
    """
    creation_time = metadata["creation_time"]
    tag_version = metadata["tag_version"]

    app.layout = html.Div([

        # Heading
        html.H1(children='A tagged database of active inference, predictive processing, and free energy principle papers'),

        dcc.Tabs([
            dcc.Tab(label="Database Table", children=[

                # Holds the tag-filtered rows so the download callback can read them
                dcc.Store(id="filtered_data_store"),

                html.Br(),

                html.P(children=f"Database creation time: {creation_time}"),
                html.P(children=f"Tag version: {tag_version}"),

                html.Br(),

                # Download button
                html.Button("Download CSV", id="btn_csv"),
                dcc.Download(id="download-dataframe-csv"),

                # Tag filter menu
                dcc.Dropdown(
                    id="filter_dropdown",
                    options=tags,
                    placeholder="Select a tag",
                    multi=True,
                    value=tags,
                ),

                # Main table
                dash_table.DataTable(
                    id="database",
                    columns=[
                        {"name": i, "id": i} for i in database.columns],
                    data=database.to_dict("records"),
                    row_deletable=True,
                    filter_action="native",
                    sort_mode="multi",
                    sort_action="native",
                    style_header = {
                        'text_align': 'left',
                        'backgroundColor' : 'rgb(30, 144, 255)'
                    },
                    style_data={
                        'whiteSpace': 'normal',
                        'height': 'auto',
                        'text_align': 'left'},
                    style_data_conditional=[
                        {'if': {'row_index': 'odd'},
                        'backgroundColor': 'rgb(240, 240, 240)'}
                    ],
                    style_cell_conditional=[
                        {'if': {'column_id': 'Title'},
                        'width': '350px'},
                        {'if': {'column_id': 'Authors'},
                        'width': '300px'},
                        {'if': {'column_id': 'Year'},
                        'width': '100px'}],
                    style_as_list_view=True
                )
            ]),

            dcc.Tab(label="Tag reference", children=[
                html.Div([
                    html.H3("Tag reference"),
                ])]),

            dcc.Tab(label="About This App", children=[
                html.Div([
                    html.H3("What is this?"),
                    html.P("This app lets you explore and download a curated list of papers related to active inference, predictive processing, and the free energy principle."),
                    html.P("Use the dropdown to filter by tags, and download the current view using the button."),
                    html.Br(),
                    html.P("For feature requests and bug reporting please visit the Github page: https://github.com/snamjoshi/aif-fep-db")
                ])])
        ])]
    )

    @app.callback(
        Output("database", "data"),
        Output("filtered_data_store", "data"),
        Input("filter_dropdown", "value")
    )
    def display_table(selected_tags):
        """Return the rows whose DOI carries at least one selected tag."""
        # The dropdown value may be None (nothing selected); treat it like an
        # empty selection instead of raising TypeError while iterating.
        selected_tags = selected_tags or []

        # Unknown tags contribute no DOIs rather than raising KeyError.
        doi_list = list(set(chain.from_iterable(
            mapping.get(tag, []) for tag in selected_tags
        )))

        filtered = database[database["DOI"].isin(doi_list)]
        filtered_dict = filtered.to_dict("records")
        return filtered_dict, filtered_dict

    @app.callback(
        Output("download-dataframe-csv", "data"),
        Input("btn_csv", "n_clicks"),
        State("filtered_data_store", "data"),
        prevent_initial_call=True,
    )
    def download_csv(n_clicks, data):
        """Send the stored tag-filtered rows to the browser as a CSV file."""
        # NOTE(review): only display_table writes to the store, so rows hidden
        # by the table's native filter or removed with the row-delete control
        # are still included in the download.
        # Passing the columns explicitly keeps the header row when the
        # selection is empty.
        df = pd.DataFrame(data, columns=database.columns)
        return dcc.send_data_frame(df.to_csv, "filtered_papers.csv")

    return app

In [4]:
# Load the 16:16:45 snapshot and derive everything the dashboard needs from it.
database = load_database("data/databases/database__2025-04-15__16:16:45.445422.pkl")
metadata = get_metadata(database)
mapping, tags = get_tags(database)

# From here on `database` is the display DataFrame, no longer the Database wrapper.
database = prepare_database(database)

# create_layout returns the app it was given, so build and wire it in one step.
app = create_layout(Dash(), database, metadata, mapping, tags)
app.run(debug=True)

INFO:src.db:Database loaded from data/databases/database__2025-04-15__16:16:45.445422.pkl.


In [8]:
def get_tags(database: "Database") -> tuple:
    """Build a lower-cased tag -> DOIs mapping and the list of unique tags.

    NOTE(review): this cell redefines ``get_tags``, which is already defined
    in an earlier cell of this notebook; the later definition shadows the
    earlier one. Keep only one of them.

    The caller's ``database.db`` frame is left untouched: lower-casing is done
    on a derived frame rather than by overwriting the ``tag`` column.

    Args:
        database: ``Database`` object whose ``.db`` DataFrame has a ``tag``
            column holding a list of strings per row and a ``doi`` column.

    Returns:
        A ``(mapping, tags)`` tuple. ``mapping`` is a dict from lower-cased
        tag to the list of DOIs carrying that tag; ``tags`` is the list of
        its keys.
    """
    lowered = database.db["tag"].apply(lambda lst: [s.lower() for s in lst])

    # Work on a fresh two-column frame so the input DataFrame is not mutated.
    db_exploded = database.db[["doi"]].assign(tag=lowered).explode("tag")

    # Rows whose tag list was empty explode to NaN and are dropped by groupby.
    mapping = db_exploded.groupby("tag")["doi"].apply(list).to_dict()
    tags = list(mapping.keys())

    return mapping, tags

In [13]:
# Switch to the later 16:45:04 snapshot and refresh its metadata.
# Only `database` and `metadata` are rebound here; `mapping`, `tags` and `app`
# are not recomputed by this cell.
database = load_database("data/databases/database__2025-04-15__16:45:04.248461.pkl")
metadata = get_metadata(database)

INFO:src.db:Database loaded from data/databases/database__2025-04-15__16:45:04.248461.pkl.
