From ada634ba829b8f4c9af81f527665fd8d68b6c84a Mon Sep 17 00:00:00 2001 From: yangwenzhuo08 Date: Thu, 13 Apr 2023 16:30:09 +0800 Subject: [PATCH 1/2] Update copyright --- .pre-commit-config.yaml | 2 +- LICENSE | 2 +- pyrca/tools/__main__.py | 3 +- pyrca/tools/dashboard/callbacks/__init__.py | 5 +- pyrca/tools/dashboard/callbacks/causal.py | 131 +++++------- pyrca/tools/dashboard/callbacks/data.py | 59 ++---- pyrca/tools/dashboard/dashboard.py | 32 ++- pyrca/tools/dashboard/pages/__init__.py | 5 +- pyrca/tools/dashboard/pages/causal.py | 218 +++++++------------- pyrca/tools/dashboard/pages/data.py | 142 +++---------- pyrca/tools/dashboard/pages/utils.py | 47 ++--- pyrca/tools/dashboard/utils/__init__.py | 5 +- pyrca/tools/dashboard/utils/layout.py | 41 ++-- 13 files changed, 218 insertions(+), 474 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6a4ca57..81b6769 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: '19.3b0' + rev: '22.3.0' hooks: - id: black args: ["--line-length", "120"] diff --git a/LICENSE b/LICENSE index d3b1fa3..7043ca7 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2022, Salesforce +Copyright (c) 2023, Salesforce All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/pyrca/tools/__main__.py b/pyrca/tools/__main__.py index 354bd98..d4fcae4 100644 --- a/pyrca/tools/__main__.py +++ b/pyrca/tools/__main__.py @@ -2,8 +2,7 @@ # Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from pyrca.tools.dashboard.dashboard import app if __name__ == "__main__": diff --git a/pyrca/tools/dashboard/callbacks/__init__.py b/pyrca/tools/dashboard/callbacks/__init__.py index 599d47c..d0e4276 100644 --- a/pyrca/tools/dashboard/callbacks/__init__.py +++ b/pyrca/tools/dashboard/callbacks/__init__.py @@ -1,6 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/tools/dashboard/callbacks/causal.py b/pyrca/tools/dashboard/callbacks/causal.py index 545dc11..f5b3adb 100644 --- a/pyrca/tools/dashboard/callbacks/causal.py +++ b/pyrca/tools/dashboard/callbacks/causal.py @@ -1,19 +1,21 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import json import dash -from dash import html, Input, Output, State, \ - callback, no_update, dcc +from dash import html, Input, Output, State, callback, no_update, dcc from ..utils.file_manager import FileManager from ..pages.utils import create_param_table -from ..pages.causal import create_graph_figure, \ - create_causal_relation_table, create_cycle_table, \ - create_root_leaf_table, create_link_table +from ..pages.causal import ( + create_graph_figure, + create_causal_relation_table, + create_cycle_table, + create_root_leaf_table, + create_link_table, +) from ..models.causal import CausalDiscovery file_manager = FileManager() @@ -23,10 +25,7 @@ @callback( Output("causal-select-file", "options"), Output("select-domain", "options"), - [ - Input("causal-upload-data", "filename"), - Input("causal-upload-data", "contents") - ], + [Input("causal-upload-data", "filename"), Input("causal-upload-data", "contents")], ) def upload_file(filenames, contents): if filenames is not None and contents is not None: @@ -42,10 +41,7 @@ def upload_file(filenames, contents): return file_options, domain_options -@callback( - Output("select-causal-method", "options"), - Input("select-causal-method-parent", "n_clicks") -) +@callback(Output("select-causal-method", "options"), Input("select-causal-method-parent", "n_clicks")) def update_method_dropdown(n_clicks): options = [] ctx = dash.callback_context @@ -56,10 +52,7 @@ def update_method_dropdown(n_clicks): return options -@callback( - Output("causal-param-table", "children"), - Input("select-causal-method", "value") -) +@callback(Output("causal-param-table", "children"), Input("select-causal-method", "value")) def select_algorithm(algorithm): param_table = create_param_table(height=80) ctx = dash.callback_context @@ -113,12 +106,7 @@ def _dump_results(output_folder, graph_df, root_leaf_table, link_table): forbids.append([p["Node A"], p["Node B"]]) domain_knowledge = { - "causal-graph": { - "root-nodes": roots, - "leaf-nodes": leaves, - "forbids": forbids, - "requires": requires - } + "causal-graph": {"root-nodes": roots, "leaf-nodes": leaves, "forbids": forbids, "requires": requires} } causal_method.dump_results(output_folder, graph_df, domain_knowledge) @@ -132,7 +120,7 @@ def _dump_results(output_folder, graph_df, root_leaf_table, link_table): Input("causal-run-btn", "n_clicks"), Input("causal-exception-modal-close", "n_clicks"), Input("upload-graph", "filename"), - Input("upload-graph", "contents") + Input("upload-graph", "contents"), ], [ State("causal-select-file", "value"), @@ -142,7 +130,7 @@ def _dump_results(output_folder, graph_df, root_leaf_table, link_table): State("causal-data-state", "data"), State("cytoscape", "elements"), State("root-leaf-table", "children"), - State("link-table", "children") + State("link-table", "children"), ], running=[ (Output("causal-run-btn", "disabled"), True, False), @@ -151,29 +139,27 @@ def _dump_results(output_folder, graph_df, root_leaf_table, link_table): cancel=[Input("causal-cancel-btn", "n_clicks")], background=True, manager=file_manager.get_long_callback_manager(), - prevent_initial_call=True + prevent_initial_call=True, ) def click_train_test( - run_clicks, - modal_close, - upload_graph_file, - upload_graph_content, - filename, - algorithm, - param_table, - causal_state, - data_state, - cyto_elements, - root_leaf_table, - link_table + run_clicks, + modal_close, + upload_graph_file, + upload_graph_content, + filename, + algorithm, + param_table, + causal_state, + data_state, + cyto_elements, + root_leaf_table, + link_table, ): ctx = dash.callback_context modal_is_open = False modal_content = "" - state = json.loads(causal_state) \ - if causal_state is not None else {} - data_state = json.loads(data_state) \ - if data_state is not None else {} + state = json.loads(causal_state) if causal_state is not None else {} + data_state = json.loads(data_state) if data_state is not None else {} def _update_states(graph, graph_df, relations): causal_levels, cycles = causal_method.causal_order(graph_df) @@ -239,37 +225,29 @@ def hover_graph_node(data): Output("causal-relationship-table", "children"), Output("causal-cycle-table", "children"), Input("causal-state", "data"), - prevent_initial_call=True + prevent_initial_call=True, ) def update_view(data): - state = json.loads(data) \ - if data is not None else {} + state = json.loads(data) if data is not None else {} graph = state.get("graph", []) positions = state.get("positions", {}) if state.get("cycles", None) is not None: - cycle_table = html.Div(children=[ - html.B("Cyclic Paths"), - html.Hr(), - create_cycle_table(state["cycles"]) - ]) + cycle_table = html.Div(children=[html.B("Cyclic Paths"), html.Hr(), create_cycle_table(state["cycles"])]) else: cycle_table = None for element in graph: if "position" in element: - element["position"] = \ - positions.get(element["data"]["id"], element["position"]) + element["position"] = positions.get(element["data"]["id"], element["position"]) - return graph, \ - create_causal_relation_table(state.get("relations", None)), \ - cycle_table + return graph, create_causal_relation_table(state.get("relations", None)), cycle_table @callback( Output("add-root-leaf-node", "options"), Input("add-root-leaf-node-parent", "n_clicks"), - State("causal-data-state", "data") + State("causal-data-state", "data"), ) def update_root_leaf_dropdown(n_clicks, data_state): options = [] @@ -281,11 +259,7 @@ def update_root_leaf_dropdown(n_clicks, data_state): return options -@callback( - Output("add-node-A", "options"), - Input("add-node-A-parent", "n_clicks"), - State("causal-data-state", "data") -) +@callback(Output("add-node-A", "options"), Input("add-node-A-parent", "n_clicks"), State("causal-data-state", "data")) def update_node_a_dropdown(n_clicks, data_state): options = [] ctx = dash.callback_context @@ -296,11 +270,7 @@ def update_node_a_dropdown(n_clicks, data_state): return options -@callback( - Output("add-node-B", "options"), - Input("add-node-B-parent", "n_clicks"), - State("causal-data-state", "data") -) +@callback(Output("add-node-B", "options"), Input("add-node-B-parent", "n_clicks"), State("causal-data-state", "data")) def update_node_b_dropdown(n_clicks, data_state): options = [] ctx = dash.callback_context @@ -316,11 +286,7 @@ def update_node_b_dropdown(n_clicks, data_state): Input("select-domain", "value"), Input("add-root-leaf-btn", "n_clicks"), Input("delete-root-leaf-btn", "n_clicks"), - [ - State("add-root-leaf-node", "value"), - State("root-leaf-check", "value"), - State("root-leaf-table", "children") - ] + [State("add-root-leaf-node", "value"), State("root-leaf-check", "value"), State("root-leaf-table", "children")], ) def add_delete_root_leaf_node(domain_file, add_click, delete_click, metric, is_root, table): ctx = dash.callback_context @@ -359,8 +325,8 @@ def add_delete_root_leaf_node(domain_file, add_click, delete_click, metric, is_r State("add-node-A", "value"), State("add-node-B", "value"), State("link_radio_button", "value"), - State("link-table", "children") - ] + State("link-table", "children"), + ], ) def add_link(domain_file, add_click, delete_click, node_a, node_b, link_type, table): ctx = dash.callback_context @@ -368,8 +334,7 @@ def add_link(domain_file, add_click, delete_click, node_a, node_b, link_type, ta if table is not None: if isinstance(table, list): table = table[0] - links = {(p["Node A"], p["Node B"]): p["Type"] - for p in table["props"]["data"] if p["Node A"]} + links = {(p["Node A"], p["Node B"]): p["Type"] for p in table["props"]["data"] if p["Node A"]} if ctx.triggered: prop_id = ctx.triggered_id @@ -395,11 +360,8 @@ def add_link(domain_file, add_click, delete_click, node_a, node_b, link_type, ta Output("download-data", "data"), Output("data-download-exception-modal", "is_open"), Output("data-download-exception-modal-content", "children"), - [ - Input("causal-download-btn", "n_clicks"), - Input("data-download-exception-modal-close", "n_clicks") - ], - State("causal-select-file", "value") + [Input("causal-download-btn", "n_clicks"), Input("data-download-exception-modal-close", "n_clicks")], + State("causal-select-file", "value"), ) def download(btn_click, modal_close, filename): ctx = dash.callback_context @@ -411,8 +373,7 @@ def download(btn_click, modal_close, filename): prop_id = ctx.triggered_id if prop_id == "causal-download-btn" and btn_click > 0: try: - assert filename, "Please select the dataset name " \ - "to download the generated causal graph." + assert filename, "Please select the dataset name " "to download the generated causal graph." name = filename.split(".")[0] path = file_manager.get_model_download_path(name) data = dcc.send_file(path) diff --git a/pyrca/tools/dashboard/callbacks/data.py b/pyrca/tools/dashboard/callbacks/data.py index 8790b1c..e9eb430 100644 --- a/pyrca/tools/dashboard/callbacks/data.py +++ b/pyrca/tools/dashboard/callbacks/data.py @@ -1,9 +1,8 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import json import dash from dash import Input, Output, State, callback @@ -18,10 +17,7 @@ @callback( Output("select-file", "options"), Output("select-file", "value"), - [ - Input("upload-data", "filename"), - Input("upload-data", "contents") - ], + [Input("upload-data", "filename"), Input("upload-data", "contents")], ) def upload_file(filenames, contents): name = None @@ -46,7 +42,7 @@ def upload_file(filenames, contents): Input("data-btn", "n_clicks"), Input("thres-btn", "n_clicks"), Input("manual-btn", "n_clicks"), - Input("data-exception-modal-close", "n_clicks") + Input("data-exception-modal-close", "n_clicks"), ], [ State("select-file", "value"), @@ -54,20 +50,11 @@ def upload_file(filenames, contents): State("sigma", "value"), State("lower_bound", "value"), State("upper_bound", "value"), - State("data-state", "data") - ] + State("data-state", "data"), + ], ) def click_run( - btn_click, - thres_click, - manual_click, - modal_close, - file_name, - thres_column, - sigma, - lower_bound, - upper_bound, - data + btn_click, thres_click, manual_click, modal_close, file_name, thres_column, sigma, lower_bound, upper_bound, data ): ctx = dash.callback_context stats = json.loads(data) if data is not None else {} @@ -90,35 +77,21 @@ def click_run( data_figure = DataAnalyzer.get_data_figure(df) elif prop_id == "thres-btn" and thres_click > 0: - data_figure = data_analyzer.estimate_threshold( - column=thres_column, - sigma=float(sigma) - ) + data_figure = data_analyzer.estimate_threshold(column=thres_column, sigma=float(sigma)) elif prop_id == "manual-btn" and manual_click > 0: data_figure = data_analyzer.manual_threshold( - column=thres_column, - lower=float(lower_bound), - upper=float(upper_bound) + column=thres_column, lower=float(lower_bound), upper=float(upper_bound) ) except Exception as error: modal_is_open = True modal_content = str(error) - return stats_table, \ - json.dumps(stats), \ - data_table, \ - data_figure, \ - modal_is_open, \ - modal_content + return stats_table, json.dumps(stats), data_table, data_figure, modal_is_open, modal_content -@callback( - Output("select-column", "options"), - Input("select-column-parent", "n_clicks"), - State("data-state", "data") -) +@callback(Output("select-column", "options"), Input("select-column-parent", "n_clicks"), State("data-state", "data")) def update_metric_dropdown(n_clicks, data): options = [] ctx = dash.callback_context @@ -130,11 +103,7 @@ def update_metric_dropdown(n_clicks, data): return options -@callback( - Output("metric-stats-table", "children"), - Input("select-column", "value"), - State("data-state", "data") -) +@callback(Output("metric-stats-table", "children"), Input("select-column", "value"), State("data-state", "data")) def update_metric_table(column, data): ctx = dash.callback_context metric_stats_table = create_metric_stats_table() @@ -150,7 +119,7 @@ def update_metric_table(column, data): @callback( Output("select-thres-column", "options"), Input("select-thres-column-parent", "n_clicks"), - State("data-state", "data") + State("data-state", "data"), ) def update_thres_dropdown(n_clicks, data): options = [] @@ -166,7 +135,7 @@ def update_thres_dropdown(n_clicks, data): @callback( Output("select-manual-column", "options"), Input("select-manual-column-parent", "n_clicks"), - State("data-state", "data") + State("data-state", "data"), ) def update_manual_dropdown(n_clicks, data): options = [] diff --git a/pyrca/tools/dashboard/dashboard.py b/pyrca/tools/dashboard/dashboard.py index 4a82395..ecefc5f 100644 --- a/pyrca/tools/dashboard/dashboard.py +++ b/pyrca/tools/dashboard/dashboard.py @@ -1,9 +1,8 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import dash import dash_bootstrap_components as dbc from dash import dcc @@ -25,28 +24,23 @@ title="PyRCA Dashboard", ) app.config["suppress_callback_exceptions"] = True -app.layout = html.Div([ - dcc.Location(id="url", refresh=False), - html.Div(id="page-content"), - dcc.Store(id="data-state"), - dcc.Store(id="causal-state"), - dcc.Store(id="causal-data-state") -]) +app.layout = html.Div( + [ + dcc.Location(id="url", refresh=False), + html.Div(id="page-content"), + dcc.Store(id="data-state"), + dcc.Store(id="causal-state"), + dcc.Store(id="causal-data-state"), + ] +) server = app.server -@app.callback( - Output("page-content", "children"), - [Input("url", "pathname")] -) +@app.callback(Output("page-content", "children"), [Input("url", "pathname")]) def _display_page(pathname): return html.Div( id="app-container", - children=[ - create_banner(app), - html.Br(), - create_layout() - ], + children=[create_banner(app), html.Br(), create_layout()], ) diff --git a/pyrca/tools/dashboard/pages/__init__.py b/pyrca/tools/dashboard/pages/__init__.py index 599d47c..d0e4276 100644 --- a/pyrca/tools/dashboard/pages/__init__.py +++ b/pyrca/tools/dashboard/pages/__init__.py @@ -1,6 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/tools/dashboard/pages/causal.py b/pyrca/tools/dashboard/pages/causal.py index dfb935b..8b61532 100644 --- a/pyrca/tools/dashboard/pages/causal.py +++ b/pyrca/tools/dashboard/pages/causal.py @@ -1,9 +1,8 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import networkx as nx import dash_cytoscape as cyto @@ -15,24 +14,24 @@ default_stylesheet = [ { - 'selector': 'node', - 'style': { - 'label': 'data(label)', - 'opacity': 'data(weight)', - 'background-color': '#1f77b4', - } + "selector": "node", + "style": { + "label": "data(label)", + "opacity": "data(weight)", + "background-color": "#1f77b4", + }, }, { - 'selector': 'edge', - 'style': { - 'curve-style': 'bezier', - 'target-arrow-color': 'black', - 'target-arrow-shape': 'triangle', - 'arrow-scale': 1, - 'line-color': 'black', - 'opacity': 0.6, - 'width': 0.5 - } + "selector": "edge", + "style": { + "curve-style": "bezier", + "target-arrow-color": "black", + "target-arrow-shape": "triangle", + "arrow-scale": 1, + "line-color": "black", + "opacity": 0.6, + "width": 0.5, + }, }, ] @@ -57,8 +56,7 @@ def build_cyto_graph(graph, levels, positions, max_node_name_length=20): label = label[:max_node_name_length] + "*" data = { "data": {"id": node, "label": label}, - "position": positions.get( - node, {"x": int(node2pos[node][1]), "y": int(node2pos[node][0])}) + "position": positions.get(node, {"x": int(node2pos[node][1]), "y": int(node2pos[node][0])}), } cy_nodes.append(data) for edge in graph.edges(): @@ -87,7 +85,7 @@ def create_causal_relation_table(relations=None, height=500): columns=[ {"id": "Node A", "name": "Node A"}, {"id": "Relation", "name": "Relation"}, - {"id": "Node B", "name": "Node B"} + {"id": "Node B", "name": "Node B"}, ], editable=False, sort_action="native", @@ -95,11 +93,7 @@ def create_causal_relation_table(relations=None, height=500): style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), style_data=dict(backgroundColor=TABLE_DATA_COLOR), - style_table={ - "overflowX": "scroll", - "overflowY": "scroll", - "height": height - }, + style_table={"overflowX": "scroll", "overflowY": "scroll", "height": height}, ) return table @@ -108,8 +102,7 @@ def create_cycle_table(cycles, height=100): if cycles is None or len(cycles) == 0: data = [{"Cyclic Path": ""}] else: - data = [{"Cyclic Path": " --> ".join([str(node) for node in path])} - for path in cycles] + data = [{"Cyclic Path": " --> ".join([str(node) for node in path])} for path in cycles] table = dash_table.DataTable( id="causal-cycles", @@ -123,11 +116,7 @@ def create_cycle_table(cycles, height=100): style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), style_data=dict(backgroundColor=TABLE_DATA_COLOR), - style_table={ - "overflowX": "scroll", - "overflowY": "scroll", - "height": height - }, + style_table={"overflowX": "scroll", "overflowY": "scroll", "height": height}, ) return table @@ -136,8 +125,7 @@ def create_root_leaf_table(metrics=None, height=80): if metrics is None or len(metrics) == 0: data = [{"Metric": "", "Type": ""}] else: - data = [{"Metric": metric["name"], "Type": metric["type"]} - for metric in metrics] + data = [{"Metric": metric["name"], "Type": metric["type"]} for metric in metrics] table = dash_table.DataTable( data=data, @@ -151,11 +139,7 @@ def create_root_leaf_table(metrics=None, height=80): style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), style_data=dict(backgroundColor=TABLE_DATA_COLOR), - style_table={ - "overflowX": "scroll", - "overflowY": "scroll", - "height": height - }, + style_table={"overflowX": "scroll", "overflowY": "scroll", "height": height}, ) return table @@ -164,8 +148,7 @@ def create_link_table(links=None, height=80): if links is None or len(links) == 0: data = [{"Node A": "", "Type": "", "Node B": ""}] else: - data = [{"Node A": link["A"], "Type": link["type"], "Node B": link["B"]} - for link in links] + data = [{"Node A": link["A"], "Type": link["type"], "Node B": link["B"]} for link in links] table = dash_table.DataTable( data=data, @@ -180,11 +163,7 @@ def create_link_table(links=None, height=80): style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), style_data=dict(backgroundColor=TABLE_DATA_COLOR), - style_table={ - "overflowX": "scroll", - "overflowY": "scroll", - "height": height - }, + style_table={"overflowX": "scroll", "overflowY": "scroll", "height": height}, ) return table @@ -214,145 +193,104 @@ def create_control_panel() -> html.Div: }, multiple=True, ), - html.Br(), html.P("Select Data File"), html.Div( id="causal-select-file-parent", - children=[ - dcc.Dropdown( - id="causal-select-file", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="causal-select-file", options=[], style={"width": "100%"})], ), - html.Br(), html.P("Causal Discovery Algorithm"), html.Div( id="select-causal-method-parent", - children=[ - dcc.Dropdown( - id="select-causal-method", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="select-causal-method", options=[], style={"width": "100%"})], ), - html.Div( - id="causal-param-table", - children=[create_param_table(height=80)] - ), - + html.Div(id="causal-param-table", children=[create_param_table(height=80)]), html.Br(), html.P("Select Domain Knowledge File"), html.Div( id="select-domain-parent", - children=[ - dcc.Dropdown( - id="select-domain", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="select-domain", options=[], style={"width": "100%"})], ), - html.Br(), html.Div( children=[ html.Button(id="causal-run-btn", children="Run", n_clicks=0), html.Button(id="causal-cancel-btn", children="Cancel", style={"margin-left": "10px"}), html.Button(id="causal-download-btn", children="Download", style={"margin-left": "10px"}), - dcc.Download(id="download-data") + dcc.Download(id="download-data"), ], - style={"textAlign": "center"} + style={"textAlign": "center"}, ), - html.Br(), html.Hr(), html.P("Edit Domain Knowledge"), html.Label("Root or Leaf"), - html.Div(children=[ - html.Div( - id="add-root-leaf-node-parent", - children=[dcc.Dropdown(id="add-root-leaf-node", options=[])], - style={"width": "80%"} - ), - html.Div( - id="add-root-leaf-node-parent", - children=[ - dcc.Checklist( - id="root-leaf-check", - options=[ - {"label": " Is Root", "value": "root"}, - ], - value=["root"] - ) - ], - style={"width": "20%", "margin-left": "15px"} - )], - style=dict(display="flex") - ), html.Div( - id="root-leaf-table", - children=[create_root_leaf_table()] + children=[ + html.Div( + id="add-root-leaf-node-parent", + children=[dcc.Dropdown(id="add-root-leaf-node", options=[])], + style={"width": "80%"}, + ), + html.Div( + id="add-root-leaf-node-parent", + children=[ + dcc.Checklist( + id="root-leaf-check", + options=[ + {"label": " Is Root", "value": "root"}, + ], + value=["root"], + ) + ], + style={"width": "20%", "margin-left": "15px"}, + ), + ], + style=dict(display="flex"), ), + html.Div(id="root-leaf-table", children=[create_root_leaf_table()]), html.Br(), html.Div( children=[ html.Button(id="add-root-leaf-btn", children="Add", n_clicks=0), html.Button(id="delete-root-leaf-btn", children="Delete", style={"margin-left": "15px"}), ], - style={"textAlign": "center"} + style={"textAlign": "center"}, ), - html.Br(), html.Label("Forbidden or Required Links"), html.Div( id="add-node-A-parent", - children=[ - dcc.Dropdown( - id="add-node-A", - options=[], - placeholder="Node A", - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="add-node-A", options=[], placeholder="Node A", style={"width": "100%"})], ), html.Div( children=dcc.RadioItems( id="link_radio_button", options=[ {"label": " ⇒ (Required)", "value": "Required"}, - {"label": " ⇏ (Forbidden)", "value": "Forbidden"} + {"label": " ⇏ (Forbidden)", "value": "Forbidden"}, ], value="Required", inline=True, - inputStyle={"margin-left": "20px"} + inputStyle={"margin-left": "20px"}, ) ), html.Div( id="add-node-B-parent", - children=[ - dcc.Dropdown( - id="add-node-B", - options=[], - placeholder="Node B", - style={"width": "100%"} - )] - ), - html.Div( - id="link-table", - children=[create_link_table()] + children=[dcc.Dropdown(id="add-node-B", options=[], placeholder="Node B", style={"width": "100%"})], ), + html.Div(id="link-table", children=[create_link_table()]), html.Br(), html.Div( children=[ html.Button(id="add-link-btn", children="Add", n_clicks=0), html.Button(id="delete-link-btn", children="Delete", style={"margin-left": "15px"}), ], - style={"textAlign": "center"} + style={"textAlign": "center"}, ), html.Br(), html.Hr(), - html.P(id="label", children="Open Causal Graph"), dcc.Upload( id="upload-graph", @@ -373,13 +311,12 @@ def create_control_panel() -> html.Div: }, multiple=True, ), - create_modal( modal_id="causal-exception-modal", header="An Exception Occurred", content="An exception occurred. Please click OK to continue.", content_id="causal-exception-modal-content", - button_id="causal-exception-modal-close" + button_id="causal-exception-modal-close", ), create_modal( modal_id="data-download-exception-modal", @@ -411,12 +348,12 @@ def create_right_column() -> html.Div: style={"height": "60vh", "width": "100%"}, minZoom=0.5, maxZoom=4.0, - layout={"name": "preset"} + layout={"name": "preset"}, ), - html.B(id="cytoscape-hover-output") - ] - ) - ] + html.B(id="cytoscape-hover-output"), + ], + ), + ], ), html.Div( id="result_table_card", @@ -424,10 +361,10 @@ def create_right_column() -> html.Div: html.Div(id="causal-cycle-table"), html.B("Causal Relationships"), html.Hr(), - html.Div(id="causal-relationship-table") - ] - ) - ] + html.Div(id="causal-relationship-table"), + ], + ), + ], ) @@ -439,14 +376,9 @@ def create_causal_layout() -> html.Div: html.Div( id="left-column-data", className="three columns", - children=[ - create_control_panel() - ], + children=[create_control_panel()], ), # Right column - html.Div( - className="nine columns", - children=create_right_column() - ) - ] + html.Div(className="nine columns", children=create_right_column()), + ], ) diff --git a/pyrca/tools/dashboard/pages/data.py b/pyrca/tools/dashboard/pages/data.py index 5e2e71a..376e7e0 100644 --- a/pyrca/tools/dashboard/pages/data.py +++ b/pyrca/tools/dashboard/pages/data.py @@ -1,9 +1,8 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from dash import dcc from dash import html, dash_table from .utils import create_modal, create_emtpy_figure @@ -14,45 +13,36 @@ def create_stats_table(data_stats=None): if data_stats is None or len(data_stats) == 0: data = [{"Stats": "", "Value": ""}] else: - data = [{"Stats": key, "Value": value} - for key, value in data_stats["@global"].items()] + data = [{"Stats": key, "Value": value} for key, value in data_stats["@global"].items()] table = dash_table.DataTable( id="data-stats", data=data, - columns=[ - {"id": "Stats", "name": "Stats"}, - {"id": "Value", "name": "Value"} - ], + columns=[{"id": "Stats", "name": "Stats"}, {"id": "Value", "name": "Value"}], editable=False, style_header_conditional=[{"textAlign": "center"}], style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), - style_data=dict(backgroundColor=TABLE_DATA_COLOR) + style_data=dict(backgroundColor=TABLE_DATA_COLOR), ) return table def create_metric_stats_table(metric_stats=None, column=None): - if metric_stats is None or len(metric_stats) == 0 \ - or column not in metric_stats: + if metric_stats is None or len(metric_stats) == 0 or column not in metric_stats: data = [{"Stats": "", "Value": ""}] else: - data = [{"Stats": key, "Value": value} - for key, value in metric_stats[column].items()] + data = [{"Stats": key, "Value": value} for key, value in metric_stats[column].items()] table = dash_table.DataTable( id="metric-stats", data=data, - columns=[ - {"id": "Stats", "name": "Stats"}, - {"id": "Value", "name": "Value"} - ], + columns=[{"id": "Stats", "name": "Stats"}, {"id": "Value", "name": "Value"}], editable=False, style_header_conditional=[{"textAlign": "center"}], style_cell_conditional=[{"textAlign": "center"}], style_header=dict(backgroundColor=TABLE_HEADER_COLOR), - style_data=dict(backgroundColor=TABLE_DATA_COLOR) + style_data=dict(backgroundColor=TABLE_DATA_COLOR), ) return table @@ -82,139 +72,86 @@ def create_control_panel() -> html.Div: }, multiple=True, ), - html.Br(), html.P("Select Data File"), html.Div( - id="select-file-parent", - children=[ - dcc.Dropdown( - id="select-file", - options=[], - style={"width": "100%"} - )] + id="select-file-parent", children=[dcc.Dropdown(id="select-file", options=[], style={"width": "100%"})] ), - html.Br(), html.P("Overall Stats"), html.Div( id="data-stats-table", children=[create_metric_stats_table()], ), - html.Br(), html.P("Metric/Variable Stats"), html.Div( id="select-column-parent", - children=[ - dcc.Dropdown( - id="select-column", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="select-column", options=[], style={"width": "100%"})], ), html.Div( id="metric-stats-table", children=[create_stats_table()], ), - html.Br(), - html.Div( - children=[ - html.Button(id="data-btn", children="Load", n_clicks=0) - ], - style={"textAlign": "center"} - ), - + html.Div(children=[html.Button(id="data-btn", children="Load", n_clicks=0)], style={"textAlign": "center"}), html.Br(), html.P("Estimate Threshold"), html.Label("Metric"), html.Div( id="select-thres-column-parent", - children=[ - dcc.Dropdown( - id="select-thres-column", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="select-thres-column", options=[], style={"width": "100%"})], ), html.Label("Sigma"), html.Div( id="sigma_input", children=[ - dcc.Input( - id="sigma", - type="number", - placeholder="Sigma value", - value=4, - style={"width": "100%"} - )] + dcc.Input(id="sigma", type="number", placeholder="Sigma value", value=4, style={"width": "100%"}) + ], ), html.Br(), html.Div( - children=[ - html.Button(id="thres-btn", children="Estimate", n_clicks=0) - ], - style={"textAlign": "center"} + children=[html.Button(id="thres-btn", children="Estimate", n_clicks=0)], style={"textAlign": "center"} ), - html.Br(), html.P("Manual Threshold"), html.Label("Metric"), html.Div( id="select-manual-column-parent", - children=[ - dcc.Dropdown( - id="select-manual-column", - options=[], - style={"width": "100%"} - )] + children=[dcc.Dropdown(id="select-manual-column", options=[], style={"width": "100%"})], ), html.Label("Lower"), html.Div( id="lower_input", children=[ - dcc.Input( - id="lower_bound", - type="number", - placeholder="Lower bound", - style={"width": "100%"} - )] + dcc.Input(id="lower_bound", type="number", placeholder="Lower bound", style={"width": "100%"}) + ], ), html.Label("Upper"), html.Div( id="upper_input", children=[ - dcc.Input( - id="upper_bound", - type="number", - placeholder="Upper bound", - style={"width": "100%"} - )] + dcc.Input(id="upper_bound", type="number", placeholder="Upper bound", style={"width": "100%"}) + ], ), html.Br(), html.Div( - children=[ - html.Button(id="manual-btn", children="Test", n_clicks=0) - ], - style={"textAlign": "center"} + children=[html.Button(id="manual-btn", children="Test", n_clicks=0)], style={"textAlign": "center"} ), - create_modal( modal_id="data-exception-modal", header="An Exception Occurred", content="An exception occurred. Please click OK to continue.", content_id="data-exception-modal-content", - button_id="data-exception-modal-close" + button_id="data-exception-modal-close", ), - create_modal( modal_id="data-download-exception-modal", header="An Exception Occurred", content="An exception occurred. Please click OK to continue.", content_id="data-download-exception-modal-content", - button_id="data-download-exception-modal-close" - ) + button_id="data-download-exception-modal-close", + ), ], ) @@ -228,21 +165,13 @@ def create_right_column() -> html.Div: children=[ html.B("Time Series Plots"), html.Hr(), - html.Div( - id="data-plots", - children=[create_emtpy_figure()] - ) - ] + html.Div(id="data-plots", children=[create_emtpy_figure()]), + ], ), html.Div( - id="result_table_card", - children=[ - html.B("Time Series Samples"), - html.Hr(), - html.Div(id="data-table") - ] - ) - ] + id="result_table_card", children=[html.B("Time Series Samples"), html.Hr(), html.Div(id="data-table")] + ), + ], ) @@ -254,14 +183,9 @@ def create_data_layout() -> html.Div: html.Div( id="left-column-data", className="three columns", - children=[ - create_control_panel() - ], + children=[create_control_panel()], ), # Right column - html.Div( - className="nine columns", - children=create_right_column() - ) - ] + html.Div(className="nine columns", children=create_right_column()), + ], ) diff --git a/pyrca/tools/dashboard/pages/utils.py b/pyrca/tools/dashboard/pages/utils.py index 8df8318..48d315f 100644 --- a/pyrca/tools/dashboard/pages/utils.py +++ b/pyrca/tools/dashboard/pages/utils.py @@ -1,26 +1,21 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import dash_bootstrap_components as dbc from dash import html, dash_table from ..settings import * styles = { - 'json-output': { - 'overflow-y': 'scroll', - 'height': 'calc(90% - 25px)', - 'border': 'thin lightgrey solid' + "json-output": {"overflow-y": "scroll", "height": "calc(90% - 25px)", "border": "thin lightgrey solid"}, + "tab": {"height": "calc(98vh - 80px)"}, + "log-output": { + "overflow-y": "scroll", + "height": "calc(90% - 25px)", + "border": "thin lightgrey solid", + "white-space": "pre-wrap", }, - 'tab': {'height': 'calc(98vh - 80px)'}, - 'log-output': { - 'overflow-y': 'scroll', - 'height': 'calc(90% - 25px)', - 'border': 'thin lightgrey solid', - 'white-space': 'pre-wrap' - }, } @@ -31,11 +26,7 @@ def create_modal(modal_id, header, content, content_id, button_id): [ dbc.ModalHeader(dbc.ModalTitle(header)), dbc.ModalBody(content, id=content_id), - dbc.ModalFooter( - dbc.Button( - "Close", id=button_id, className="ml-auto", n_clicks=0 - ) - ), + dbc.ModalFooter(dbc.Button("Close", id=button_id, className="ml-auto", n_clicks=0)), ], id=modal_id, is_open=False, @@ -49,25 +40,17 @@ def create_param_table(params=None, height=100): if params is None or len(params) == 0: data = [{"Parameter": "", "Value": ""}] else: - data = [{"Parameter": key, "Value": str(value["default"])} - for key, value in params.items()] + data = [{"Parameter": key, "Value": str(value["default"])} for key, value in params.items()] table = dash_table.DataTable( data=data, - columns=[ - {"id": "Parameter", "name": "Parameter"}, - {"id": "Value", "name": "Value"} - ], + columns=[{"id": "Parameter", "name": "Parameter"}, {"id": "Value", "name": "Value"}], editable=True, style_header_conditional=[{"textAlign": "center"}], style_cell_conditional=[{"textAlign": "center"}], - style_table={ - "overflowX": "scroll", - "overflowY": "scroll", - "height": height - }, + style_table={"overflowX": "scroll", "overflowY": "scroll", "height": height}, style_header=dict(backgroundColor=TABLE_HEADER_COLOR), - style_data=dict(backgroundColor=TABLE_DATA_COLOR) + style_data=dict(backgroundColor=TABLE_DATA_COLOR), ) return table @@ -93,7 +76,7 @@ def create_metric_table(metrics=None): style_cell_conditional=[{"textAlign": "center"}], style_table={"overflowX": "scroll"}, style_header=dict(backgroundColor=TABLE_HEADER_COLOR), - style_data=dict(backgroundColor=TABLE_DATA_COLOR) + style_data=dict(backgroundColor=TABLE_DATA_COLOR), ) return table diff --git a/pyrca/tools/dashboard/utils/__init__.py b/pyrca/tools/dashboard/utils/__init__.py index 599d47c..d0e4276 100644 --- a/pyrca/tools/dashboard/utils/__init__.py +++ b/pyrca/tools/dashboard/utils/__init__.py @@ -1,6 +1,5 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/tools/dashboard/utils/layout.py b/pyrca/tools/dashboard/utils/layout.py index 9fc81f6..5c9c735 100644 --- a/pyrca/tools/dashboard/utils/layout.py +++ b/pyrca/tools/dashboard/utils/layout.py @@ -1,26 +1,21 @@ # -# Copyright (c) 2022 salesforce.com, inc. +# Copyright (c) 2023 salesforce.com, inc. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -# +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from dash import dcc from dash import html -tab_style = { - 'borderBottom': '1px solid #d6d6d6', - 'padding': '6px', - 'fontWeight': 'bold' -} +tab_style = {"borderBottom": "1px solid #d6d6d6", "padding": "6px", "fontWeight": "bold"} tab_selected_style = { - 'borderTop': '1px solid #d6d6d6', - 'borderBottom': '1px solid #d6d6d6', - 'backgroundColor': '#119DFF', - 'color': 'white', - 'padding': '6px', - 'fontWeight': 'bold' + "borderTop": "1px solid #d6d6d6", + "borderBottom": "1px solid #d6d6d6", + "backgroundColor": "#119DFF", + "color": "white", + "padding": "6px", + "fontWeight": "bold", } @@ -28,8 +23,7 @@ def create_banner(app): return html.Div( id="banner", className="banner", - children=[html.Img(src=app.get_asset_url("logo.png")), - html.Plaintext(" Powered by Salesforce AI Research")], + children=[html.Img(src=app.get_asset_url("logo.png")), html.Plaintext(" Powered by Salesforce AI Research")], ) @@ -37,26 +31,17 @@ def create_layout() -> html.Div: children, values = [], [] # Data analysis tab children.append( - dcc.Tab(label="Data Analysis", value="file-manager", - style=tab_style, selected_style=tab_selected_style) + dcc.Tab(label="Data Analysis", value="file-manager", style=tab_style, selected_style=tab_selected_style) ) values.append("file-manager") # Causal graph tab children.append( - dcc.Tab(label="Causal Discovery", value="causal-graph", - style=tab_style, selected_style=tab_selected_style) + dcc.Tab(label="Causal Discovery", value="causal-graph", style=tab_style, selected_style=tab_selected_style) ) values.append("causal-graph") layout = html.Div( id="app-content", - children=[ - dcc.Tabs( - id="tabs", - value=values[0] if values else "none", - children=children - ), - html.Div(id="plots") - ], + children=[dcc.Tabs(id="tabs", value=values[0] if values else "none", children=children), html.Div(id="plots")], ) return layout From a42381ac3b63af3bdcd83880e322b0717270d13f Mon Sep 17 00:00:00 2001 From: yangwenzhuo08 Date: Thu, 13 Apr 2023 16:35:34 +0800 Subject: [PATCH 2/2] Update copyright --- .pre-commit-config.yaml | 2 +- pyrca/__init__.py | 5 + pyrca/analyzers/__init__.py | 5 + pyrca/analyzers/base.py | 12 +- pyrca/analyzers/bayesian.py | 126 ++++++++++---------- pyrca/analyzers/epsilon_diagnosis.py | 38 +++--- pyrca/analyzers/psi_pc.py | 23 ++-- pyrca/analyzers/random_walk.py | 39 +++--- pyrca/analyzers/rht.py | 29 ++--- pyrca/applications/example/dataset.py | 5 + pyrca/applications/example/rca.py | 98 ++++++++------- pyrca/base.py | 22 ++-- pyrca/graphs/__init__.py | 5 + pyrca/graphs/causal/__init__.py | 16 +-- pyrca/graphs/causal/base.py | 76 ++++-------- pyrca/graphs/causal/fges.py | 39 +++--- pyrca/graphs/causal/ges.py | 30 +++-- pyrca/graphs/causal/lingam.py | 20 ++-- pyrca/graphs/causal/pc.py | 35 +++--- pyrca/outliers/__init__.py | 5 + pyrca/outliers/base.py | 39 +++--- pyrca/outliers/stats.py | 64 +++++----- pyrca/tools/__init__.py | 5 + pyrca/tools/app.py | 5 + pyrca/tools/dashboard/models/__init__.py | 5 + pyrca/tools/dashboard/models/causal.py | 10 +- pyrca/tools/dashboard/models/data.py | 19 ++- pyrca/tools/dashboard/settings.py | 5 + pyrca/tools/dashboard/utils/file_manager.py | 12 +- pyrca/tools/dashboard/utils/log.py | 5 + pyrca/tools/dashboard/utils/plot.py | 22 ++-- pyrca/utils/__init__.py | 5 + pyrca/utils/domain.py | 57 ++++----- pyrca/utils/logger.py | 5 + pyrca/utils/misc.py | 8 +- pyrca/utils/plot.py | 50 ++++---- pyrca/utils/utils.py | 35 ++---- setup.py | 15 ++- tests/analyzers/test_bayesian.py | 15 +-- tests/analyzers/test_epsilon_diagnosis.py | 6 +- tests/analyzers/test_psi_pc.py | 10 +- tests/analyzers/test_random_walk.py | 15 +-- tests/analyzers/test_rht.py | 12 +- tests/applications/example/run_rca.py | 35 +++--- tests/graphs/test_domain.py | 20 +++- tests/graphs/test_fges.py | 7 +- tests/graphs/test_ges.py | 7 +- tests/graphs/test_lingam.py | 14 ++- tests/graphs/test_pc.py | 7 +- tests/outliers/test_config.py | 34 +++--- tests/outliers/test_stats.py | 28 ++--- tests/tools/test_causal.py | 12 +- 52 files changed, 656 insertions(+), 562 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81b6769..f70b5f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: '22.3.0' + rev: '23.3.0' hooks: - id: black args: ["--line-length", "120"] diff --git a/pyrca/__init__.py b/pyrca/__init__.py index ad78568..f97c759 100644 --- a/pyrca/__init__.py +++ b/pyrca/__init__.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from pkg_resources import get_distribution, DistributionNotFound try: diff --git a/pyrca/analyzers/__init__.py b/pyrca/analyzers/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/analyzers/__init__.py +++ b/pyrca/analyzers/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/analyzers/base.py b/pyrca/analyzers/base.py index 6fe9721..607b7b3 100644 --- a/pyrca/analyzers/base.py +++ b/pyrca/analyzers/base.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """Base classes for all RCA algorithms""" from abc import abstractmethod from dataclasses import dataclass, field, asdict @@ -18,6 +23,7 @@ class RCAResults: following format: (path_score, [(path_node_a, score_a), (path_node_b, score_b), ...]). If ``path_node_a`` has no score, ``score_a`` is set to None. """ + root_cause_nodes: list = field(default_factory=lambda: []) root_cause_paths: dict = field(default_factory=lambda: {}) @@ -33,11 +39,7 @@ def to_list(self) -> list: """ results = [] for node, score in self.root_cause_nodes: - results.append({ - "root_cause": node, - "score": score, - "paths": self.root_cause_paths.get(node, None) - }) + results.append({"root_cause": node, "score": score, "paths": self.root_cause_paths.get(node, None)}) return results diff --git a/pyrca/analyzers/bayesian.py b/pyrca/analyzers/bayesian.py index e02cf57..66f5609 100644 --- a/pyrca/analyzers/bayesian.py +++ b/pyrca/analyzers/bayesian.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The RCA method based on Bayesian inference. """ @@ -40,6 +45,7 @@ class BayesianNetworkConfig(BaseConfig): :param infer_method: Use "posterior" or "likelihood" when doing Bayesian inference. :param root_cause_top_k: The maximum number of root causes in the results. """ + graph: Union[pd.DataFrame, str] = None default_sigma: float = 4.0 thres_win_size: int = 5 @@ -53,6 +59,7 @@ class BayesianNetwork(BaseRCA): """ The RCA method based on Bayesian inference. """ + config_class = BayesianNetworkConfig def __init__(self, config: BayesianNetworkConfig): @@ -65,9 +72,7 @@ def __init__(self, config: BayesianNetworkConfig): with open(config.graph, "rb") as f: self.graph = pickle.load(f) else: - raise RuntimeError( - "The graph file format is not supported, " - "please choose a csv or pickle file.") + raise RuntimeError("The graph file format is not supported, " "please choose a csv or pickle file.") else: self.graph = config.graph self.bayesian_model = self._build_bayesian_network(self.graph) @@ -87,12 +92,7 @@ def _build_bayesian_network(graph): return BayesianModel(ebunch=edges) return None - def train( - self, - dfs: Union[pd.DataFrame, List[pd.DataFrame]], - detector: BaseModel = None, - **kwargs - ): + def train(self, dfs: Union[pd.DataFrame, List[pd.DataFrame]], detector: BaseModel = None, **kwargs): """ Estimates Bayesian network parameters given the training time series. @@ -106,8 +106,7 @@ def train( dfs = [dfs] if detector is None: - sigmas = {} if self.config.sigmas is None else \ - self.config.sigmas + sigmas = {} if self.config.sigmas is None else self.config.sigmas all_scores = [] for df in dfs: lowers, uppers = estimate_thresholds( @@ -115,7 +114,7 @@ def train( sigmas=sigmas, default_sigma=self.config.default_sigma, win_size=self.config.thres_win_size, - reduce=self.config.thres_reduce_func + reduce=self.config.thres_reduce_func, ) scores = (df.values > uppers).astype(int) + (df.values < lowers).astype(int) all_scores.append(scores) @@ -136,8 +135,7 @@ def _refine_parameters(self, lower_bound): # Set lower and upper bounds for node in self.bayesian_model.nodes(): cpd = self.bayesian_model.get_cpds(node) - assert cpd.values.shape[0] == 2, \ - f"The cardinality of the variable {cpd.variable} should be = 2." + assert cpd.values.shape[0] == 2, f"The cardinality of the variable {cpd.variable} should be = 2." cpd.values = np.clip(cpd.values, lower_bound, 1.0 - lower_bound) # Make sure that P(m=1|Sa) >= P(m=1|Sb) when Sa >= Sb @@ -158,17 +156,19 @@ def _refine(_index, _values, _num_vars, _mem, _inc=0.1): values = cpd.values.reshape((2, -1)) if values.shape[1] > 2: num_vars, mem = len(cpd.variables) - 1, {} - _refine(2 ** num_vars - 1, values, num_vars, mem) + _refine(2**num_vars - 1, values, num_vars, mem) new_values = np.zeros_like(values) new_values[1, :] = [mem[k] for k in range(values.shape[1])] new_values[0, :] = 1.0 - new_values[1, :] - self.bayesian_model.add_cpds(TabularCPD( - variable=cpd.variable, - variable_card=2, - values=new_values, - evidence=cpd.variables[1:], - evidence_card=cpd.cardinality[1:] - )) + self.bayesian_model.add_cpds( + TabularCPD( + variable=cpd.variable, + variable_card=2, + values=new_values, + evidence=cpd.variables[1:], + evidence_card=cpd.cardinality[1:], + ) + ) def _infer(self, variables, evidence): model_infer = VariableElimination(self.bayesian_model) @@ -188,8 +188,9 @@ def _add_root_cause(self, root_cause_name, metric_name, root_cause_probs, root_p :param root_cause_probs: [P(metric=0 | root=0), P(metric=0 | root=1)] :param root_prob: P(root=1) """ - assert len(root_cause_probs), \ - "root_cause_probs should contain two values: P(metric=0 | root=0), P(metric=0 | root=1)" + assert len( + root_cause_probs + ), "root_cause_probs should contain two values: P(metric=0 | root=0), P(metric=0 | root=1)" if metric_name not in self.bayesian_model.nodes(): print(f"WARNING: Metric {metric_name} is not in the Bayesian network.") self.bayesian_model.add_node(metric_name) @@ -197,28 +198,36 @@ def _add_root_cause(self, root_cause_name, metric_name, root_cause_probs, root_p self.bayesian_model.add_node(root_cause_name) self.root_nodes.append(root_cause_name) self.bayesian_model.add_edge(root_cause_name, metric_name) - self.bayesian_model.add_cpds(TabularCPD( - variable=root_cause_name, variable_card=2, values=[[1 - root_prob], [root_prob]] - )) + self.bayesian_model.add_cpds( + TabularCPD(variable=root_cause_name, variable_card=2, values=[[1 - root_prob], [root_prob]]) + ) cpd = self.bayesian_model.get_cpds(metric_name) if cpd is None or cpd.values.size == 2: - self.bayesian_model.add_cpds(TabularCPD( - variable=metric_name, variable_card=2, - values=[root_cause_probs, [1 - root_cause_probs[0], 1 - root_cause_probs[1]]], - evidence=[root_cause_name], evidence_card=[2] - )) + self.bayesian_model.add_cpds( + TabularCPD( + variable=metric_name, + variable_card=2, + values=[root_cause_probs, [1 - root_cause_probs[0], 1 - root_cause_probs[1]]], + evidence=[root_cause_name], + evidence_card=[2], + ) + ) else: v = cpd.values.reshape((2, -1)) u = np.zeros(v.shape, dtype=float) u[0, :] = root_cause_probs[1] u[1, :] = 1 - root_cause_probs[1] evidence = [root_cause_name] + cpd.variables[1:] - self.bayesian_model.add_cpds(TabularCPD( - variable=metric_name, variable_card=2, - values=np.concatenate([v, u], axis=1), - evidence=evidence, evidence_card=[2] * len(evidence) - )) + self.bayesian_model.add_cpds( + TabularCPD( + variable=metric_name, + variable_card=2, + values=np.concatenate([v, u], axis=1), + evidence=evidence, + evidence_card=[2] * len(evidence), + ) + ) def add_root_causes(self, root_causes: List): """ @@ -233,11 +242,9 @@ def add_root_causes(self, root_causes: List): self._add_root_cause( root_cause_name=r["name"], metric_name=metric["name"], - root_cause_probs=[ - metric.get("P(m=0|r=0)", 0.99), - metric.get("P(m=0|r=1)", 0.01) - ], - root_prob=r["P(r=1)"]) + root_cause_probs=[metric.get("P(m=0|r=0)", 0.99), metric.get("P(m=0|r=1)", 0.01)], + root_prob=r["P(r=1)"], + ) def update_probability(self, target_node: str, parent_nodes: List, prob: float): """ @@ -288,7 +295,7 @@ def _get_all_paths(self, node): paths, flags = [], {} for path in all_paths: - p = '_'.join(path) + p = "_".join(path) if p not in flags: paths.append(path) flags[p] = True @@ -318,8 +325,7 @@ def _get_path_root_cause_scores(self, paths, evidence, node_scores, overwrite_sc return score_paths def _argument_root_nodes(self): - existing_roots = [ - str(node).replace("ROOT_", "") for node in self.root_nodes] + existing_roots = [str(node).replace("ROOT_", "") for node in self.root_nodes] nodes = [] for i, values in enumerate(self.graph.values.T): @@ -333,8 +339,9 @@ def _argument_root_nodes(self): { "name": f"ROOT_{node}", "P(r=1)": 0.5, - "metrics": [{"name": node, "P(m=0|r=0)": 0.99, "P(m=0|r=1)": 0.01}] - } for node in nodes + "metrics": [{"name": node, "P(m=0|r=0)": 0.99, "P(m=0|r=1)": 0.01}], + } + for node in nodes ] self.add_root_causes(root_nodes) @@ -353,11 +360,11 @@ def _post_process(self, all_paths): return paths def find_root_causes( - self, - anomalous_metrics: Union[List, Dict], - set_zero_path_score_for_normal_metrics: bool = False, - remove_zero_score_node_in_path: bool = True, - **kwargs + self, + anomalous_metrics: Union[List, Dict], + set_zero_path_score_for_normal_metrics: bool = False, + remove_zero_score_node_in_path: bool = True, + **kwargs, ) -> List: """ Finds the root causes given the observed anomalous metrics. @@ -372,11 +379,9 @@ def find_root_causes( self._argument_root_nodes() if isinstance(anomalous_metrics, Dict): - evidence = {metric: v for metric, v in anomalous_metrics.items() - if metric in self.bayesian_model.nodes()} + evidence = {metric: v for metric, v in anomalous_metrics.items() if metric in self.bayesian_model.nodes()} else: - evidence = {metric: 1 for metric in anomalous_metrics - if metric in self.bayesian_model.nodes()} + evidence = {metric: 1 for metric in anomalous_metrics if metric in self.bayesian_model.nodes()} # Pick the paths which contain anomalous node valid_paths = {} @@ -410,8 +415,9 @@ def find_root_causes( for root, score in root_scores: res = {"root_cause": root, "score": score, "paths": []} paths = valid_paths[root] - res["paths"] = self._get_path_root_cause_scores( - paths, evidence, node_scores)[:self.config.root_cause_top_k] + res["paths"] = self._get_path_root_cause_scores(paths, evidence, node_scores)[ + : self.config.root_cause_top_k + ] results.append(res) results = sorted(results, key=lambda r: (r["score"], r["paths"][0][0]), reverse=True) @@ -420,9 +426,7 @@ def find_root_causes( for entry in results: root_cause_nodes.append((entry["root_cause"], entry["score"])) root_cause_paths[entry["root_cause"]] = entry["paths"] - return RCAResults( - root_cause_nodes=root_cause_nodes, - root_cause_paths=root_cause_paths) + return RCAResults(root_cause_nodes=root_cause_nodes, root_cause_paths=root_cause_paths) def save(self, directory, filename="bn", **kwargs): writer = BIFWriter(self.bayesian_model) diff --git a/pyrca/analyzers/epsilon_diagnosis.py b/pyrca/analyzers/epsilon_diagnosis.py index f437fb6..feea2a9 100644 --- a/pyrca/analyzers/epsilon_diagnosis.py +++ b/pyrca/analyzers/epsilon_diagnosis.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The epsilon-Diagnosis algorithm. """ @@ -19,6 +24,7 @@ class EpsilonDiagnosisConfig(BaseConfig): :param bootstrap_time: Bootstrap times. :param root_cause_top_k: The maximum number of root causes in the results. """ + alpha: float = 0.05 bootstrap_time: int = 200 root_cause_top_k: int = 3 @@ -30,17 +36,14 @@ class EpsilonDiagnosis(BaseRCA): epsilon-Diagnosis: Unsupervised and Real-time Diagnosis of Small window Long-tail Latency in Large-scale Microservice Platforms. """ + config_class = EpsilonDiagnosisConfig def __init__(self, config: EpsilonDiagnosisConfig): super().__init__() self.config = config - def train( - self, - normal_df: pd.DataFrame, - **kwargs - ): + def train(self, normal_df: pd.DataFrame, **kwargs): """ Two variable correlation analysis given the training time series. @@ -53,8 +56,9 @@ def _samples(array, times=50): # bootstrapping to calculate the p-value normal_df_msample = np.apply_along_axis(_samples, 0, normal_df.values, times=self.config.bootstrap_time) - normal_correlations = np.empty((int((normal_df_msample.shape[1] * (normal_df_msample.shape[1] - 1) / 2)), - normal_df_msample.shape[2])) + normal_correlations = np.empty( + (int((normal_df_msample.shape[1] * (normal_df_msample.shape[1] - 1) / 2)), normal_df_msample.shape[2]) + ) for k in range(normal_df_msample.shape[2]): cov_matrix = np.cov(normal_df_msample[:, :, k].T) idx = 0 @@ -64,13 +68,11 @@ def _samples(array, times=50): else: normal_correlations[idx, k] = np.square(cov_matrix[i, j]) / (cov_matrix[i, i] * cov_matrix[j, j]) idx += 1 - self.statistics = dict(zip(normal_df.columns, np.apply_along_axis(np.quantile, 0, normal_correlations, q=1-self.config.alpha))) + self.statistics = dict( + zip(normal_df.columns, np.apply_along_axis(np.quantile, 0, normal_correlations, q=1 - self.config.alpha)) + ) - def find_root_causes( - self, - abnormal_df: pd.DataFrame, - **kwargs - ): + def find_root_causes(self, abnormal_df: pd.DataFrame, **kwargs): """ Finds the root causes given the abnormal dataset. @@ -83,10 +85,10 @@ def find_root_causes( if np.var(self.normal_df[colname].values) == 0 or np.var(abnormal_df[colname].values) == 0: self.correlations[colname] = 0 else: - self.correlations[colname] = np.square(np.cov(self.normal_df[colname].values, abnormal_df[colname].values)[0,1]) \ - / (np.var(self.normal_df[colname].values) * np.var(abnormal_df[colname].values)) + self.correlations[colname] = np.square( + np.cov(self.normal_df[colname].values, abnormal_df[colname].values)[0, 1] + ) / (np.var(self.normal_df[colname].values) * np.var(abnormal_df[colname].values)) if self.correlations[colname] > self.statistics[colname]: root_cause_nodes.append((colname, self.correlations[colname])) - root_cause_nodes = sorted(root_cause_nodes, key=lambda r: r[1], reverse=True)[:self.config.root_cause_top_k] - return RCAResults( - root_cause_nodes=root_cause_nodes) + root_cause_nodes = sorted(root_cause_nodes, key=lambda r: r[1], reverse=True)[: self.config.root_cause_top_k] + return RCAResults(root_cause_nodes=root_cause_nodes) diff --git a/pyrca/analyzers/psi_pc.py b/pyrca/analyzers/psi_pc.py index 24ded68..8e7a38e 100644 --- a/pyrca/analyzers/psi_pc.py +++ b/pyrca/analyzers/psi_pc.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The Phi-PC algorithm """ @@ -14,6 +19,7 @@ from pyrca.thirdparty.causallearn.utils.cit import CIT from pyrca.thirdparty.rcd import rcd + @dataclass class PsiPCConfig(BaseConfig): """ @@ -29,6 +35,7 @@ class PsiPCConfig(BaseConfig): :param f_node: name of anomaly variable. :param verbose: True iff verbose output should be printed. Default: False. """ + start_alpha: float = 0.01 alpha_step: float = 0.1 alpha_limit: float = 1 @@ -36,7 +43,7 @@ class PsiPCConfig(BaseConfig): gamma: int = 5 bins: int = 5 k: int = None - f_node: str = 'F-node' + f_node: str = "F-node" verbose: bool = False ci_test: CIT = chisq @@ -47,6 +54,7 @@ class PsiPC(BaseRCA): Root Cause Analysis of Failures in Microservices through Causal Discovery """ + config_class = PsiPCConfig def __init__(self, config: PsiPCConfig): @@ -59,18 +67,11 @@ def train(self, **kwargs): """ pass - def find_root_causes( - self, - normal_df: pd.DataFrame, - abnormal_df: pd.DataFrame, - **kwargs - ): + def find_root_causes(self, normal_df: pd.DataFrame, abnormal_df: pd.DataFrame, **kwargs): """ Finds the root causes given the abnormal dataset. :return: A list of the found root causes. """ result, _ = rcd.run_multi_phase(normal_df, abnormal_df, self.config.to_dict()) - root_cause_nodes =[(key, None) for key in result[:self.config.k]] - return RCAResults( - root_cause_nodes=root_cause_nodes) - + root_cause_nodes = [(key, None) for key in result[: self.config.k]] + return RCAResults(root_cause_nodes=root_cause_nodes) diff --git a/pyrca/analyzers/random_walk.py b/pyrca/analyzers/random_walk.py index 9f124ec..388b347 100644 --- a/pyrca/analyzers/random_walk.py +++ b/pyrca/analyzers/random_walk.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The RCA method based on random walk """ @@ -26,6 +31,7 @@ class RandomWalkConfig(BaseConfig): :param num_repeats: The number of random walk runs. :param root_cause_top_k: The maximum number of root causes in the results. """ + graph: Union[pd.DataFrame, str] = None use_partial_corr: bool = False rho: float = 0.1 @@ -38,6 +44,7 @@ class RandomWalk(BaseRCA): """ The RCA method based on random walk on the topology/causal graph. """ + config_class = RandomWalkConfig def __init__(self, config: RandomWalkConfig): @@ -50,9 +57,7 @@ def __init__(self, config: RandomWalkConfig): with open(config.graph, "rb") as f: graph = pickle.load(f) else: - raise RuntimeError( - "The graph file format is not supported, " - "please choose a csv or pickle file.") + raise RuntimeError("The graph file format is not supported, " "please choose a csv or pickle file.") else: graph = config.graph self.adjacency_mat = graph @@ -62,6 +67,7 @@ def __init__(self, config: RandomWalkConfig): @staticmethod def _partial_correlation(df: pd.DataFrame, x, y, z, add_noise=True): import pingouin as pg + if add_noise: noise = np.random.normal(0, 1e-8, size=df.shape) df = df.add(pd.DataFrame(noise, columns=df.columns, index=df.index)) @@ -70,8 +76,7 @@ def _partial_correlation(df: pd.DataFrame, x, y, z, add_noise=True): if not np.isnan(p_val): return corr["r"].values[0] else: - raise RuntimeError("The p-value for partial correlation is NaN, " - "please add more data points.") + raise RuntimeError("The p-value for partial correlation is NaN, " "please add more data points.") @staticmethod def _correlation(df: pd.DataFrame, x, y): @@ -86,8 +91,11 @@ def _compute_weight(self, df, anomaly, metric): if self.use_partial_corr: z = list(self.graph.predecessors(anomaly)) + list(self.graph.predecessors(metric)) ps = list(set([p for p in z if p != metric and p != anomaly])) - weight = abs(self._partial_correlation(df, anomaly, metric, ps)) if len(ps) > 0 \ + weight = ( + abs(self._partial_correlation(df, anomaly, metric, ps)) + if len(ps) > 0 else abs(self._correlation(df, anomaly, metric)) + ) else: weight = abs(self._correlation(df, anomaly, metric)) return weight @@ -155,6 +163,7 @@ def _random_walk(graph, start, num_steps, num_repeats, random_seed=0): @staticmethod def _find_all_paths(graph, u, v): from collections import deque + q, paths = deque([(u, [])]), [] while q: node, path = q.popleft() @@ -197,12 +206,7 @@ def train(self, **kwargs): """ pass - def find_root_causes( - self, - anomalous_metrics: Union[List, Dict], - df: pd.DataFrame, - **kwargs - ) -> RCAResults: + def find_root_causes(self, anomalous_metrics: Union[List, Dict], df: pd.DataFrame, **kwargs) -> RCAResults: """ Finds the root causes given the observed anomalous metrics. @@ -217,8 +221,11 @@ def find_root_causes( graph = self._build_weighted_graph(df, anomalous_metrics, self.config.rho) counts = { anomaly: self._random_walk( - graph, anomaly, self.config.num_steps, self.config.num_repeats, - random_seed=kwargs.get("random_seed", None) + graph, + anomaly, + self.config.num_steps, + self.config.num_repeats, + random_seed=kwargs.get("random_seed", None), ) for anomaly in anomalous_metrics } @@ -244,6 +251,4 @@ def find_root_causes( paths.append((path_score, [(node, None) for node in nodes])) root_cause_paths[root] = paths - return RCAResults( - root_cause_nodes=root_cause_nodes, - root_cause_paths=root_cause_paths) + return RCAResults(root_cause_nodes=root_cause_nodes, root_cause_paths=root_cause_paths) diff --git a/pyrca/analyzers/rht.py b/pyrca/analyzers/rht.py index 0c2085d..fcb5faa 100644 --- a/pyrca/analyzers/rht.py +++ b/pyrca/analyzers/rht.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The RHT algorithm """ @@ -24,6 +29,7 @@ class RHTConfig(BaseConfig): :param aggregator: The function for aggregating the node score from all the abnormal data. :param root_cause_top_k: The maximum number of root causes in the results. """ + graph: Union[pd.DataFrame, str] = None aggregator: str = "max" root_cause_top_k: int = 3 @@ -35,6 +41,7 @@ class RHT(BaseRCA): Causal Inference-Based Root Cause Analysis for Online Service Systems with Intervention Recognition. """ + config_class = RHTConfig def __init__(self, config: RHTConfig): @@ -47,9 +54,7 @@ def __init__(self, config: RHTConfig): with open(config.graph, "rb") as f: graph = pickle.load(f) else: - raise RuntimeError( - "The graphs file format is not supported, " - "please choose a csv or pickle file.") + raise RuntimeError("The graphs file format is not supported, " "please choose a csv or pickle file.") else: graph = config.graph self.adjacency_mat = graph @@ -67,11 +72,7 @@ def _get_aggregator(name): else: raise f"Unknown aggregator {name}" - def train( - self, - normal_df: pd.DataFrame, - **kwargs - ): + def train(self, normal_df: pd.DataFrame, **kwargs): """ Train regression model for each node based on its parents. Build the score functions. @@ -94,11 +95,7 @@ def train( self.regressors_dict[node] = [None, scaler] def find_root_causes( - self, - abnormal_df: pd.DataFrame, - anomalous_metrics: str = None, - adjustment: bool =False, - **kwargs + self, abnormal_df: pd.DataFrame, anomalous_metrics: str = None, adjustment: bool = False, **kwargs ): """ Finds the root causes given the abnormal dataset. @@ -140,7 +137,7 @@ def find_root_causes( # node_scores[key][1] indicates the confidence root_cause_nodes = [(key, node_scores[key][0]) for key in node_scores] - root_cause_nodes = sorted(root_cause_nodes, key=lambda r: r[1], reverse=True)[:self.config.root_cause_top_k] + root_cause_nodes = sorted(root_cause_nodes, key=lambda r: r[1], reverse=True)[: self.config.root_cause_top_k] root_cause_paths = {} if anomalous_metrics is not None: @@ -150,6 +147,4 @@ def find_root_causes( except nx.exception.NetworkXNoPath: path = None root_cause_paths[root_cause_nodes[idx][0]] = path - return RCAResults( - root_cause_nodes=root_cause_nodes, - root_cause_paths=root_cause_paths) + return RCAResults(root_cause_nodes=root_cause_nodes, root_cause_paths=root_cause_paths) diff --git a/pyrca/applications/example/dataset.py b/pyrca/applications/example/dataset.py index c60c9d9..316daf6 100644 --- a/pyrca/applications/example/dataset.py +++ b/pyrca/applications/example/dataset.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import pandas as pd diff --git a/pyrca/applications/example/rca.py b/pyrca/applications/example/rca.py index 6534c3c..867eca2 100644 --- a/pyrca/applications/example/rca.py +++ b/pyrca/applications/example/rca.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import yaml import pickle @@ -8,7 +13,6 @@ class ConfigParser: - def __init__(self, file_path): directory = os.path.dirname(os.path.abspath(__file__)) if file_path is None: @@ -25,14 +29,12 @@ def get_parameters(self, name): class RCAEngine: - def __init__(self, model_dir=None, logger=None): self.adjacency_df_filename = "adjacency_df.pkl" self.bn_filename = "bn" if model_dir is None: - model_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "models") + model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") if not os.path.exists(model_dir): os.makedirs(model_dir) self.model_dir = model_dir @@ -47,18 +49,18 @@ def __init__(self, model_dir=None, logger=None): "bn": self._find_root_causes_bn, "bayesian": self._find_root_causes_bn, "rw": self._find_root_causes_rw, - "random_walk": self._find_root_causes_rw + "random_walk": self._find_root_causes_rw, } def build_causal_graph( - self, - df, - domain_knowledge_file=None, - run_pdag2dag=True, - max_num_points=5000000, - method_class=None, - verbose=False, - **kwargs + self, + df, + domain_knowledge_file=None, + run_pdag2dag=True, + max_num_points=5000000, + method_class=None, + verbose=False, + **kwargs, ): from pyrca.graphs.causal.pc import PC @@ -66,30 +68,29 @@ def build_causal_graph( if verbose: self.logger.info(f"The shape of the training data for infer_causal_graph: {df.shape}") if domain_knowledge_file is None: - domain_knowledge_file = os.path.join(os.path.dirname( - os.path.abspath(__file__)), "configs/domain_knowledge.yaml") + domain_knowledge_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "configs/domain_knowledge.yaml" + ) if method_class is None: method_class = PC - model = method_class(method_class.config_class( - domain_knowledge_file=domain_knowledge_file, - run_pdag2dag=run_pdag2dag, - max_num_points=max_num_points, - **kwargs)) + model = method_class( + method_class.config_class( + domain_knowledge_file=domain_knowledge_file, + run_pdag2dag=run_pdag2dag, + max_num_points=max_num_points, + **kwargs, + ) + ) adjacency_df = model.train(df) adjacency_df.to_pickle(os.path.join(self.model_dir, self.adjacency_df_filename)) PC.dump_to_tetrad_json(adjacency_df, self.model_dir) return adjacency_df - def train_bayesian_network( - self, - dfs, - domain_knowledge_file=None, - config_file=None, - verbose=False - ): + def train_bayesian_network(self, dfs, domain_knowledge_file=None, config_file=None, verbose=False): from pyrca.utils.domain import DomainParser from pyrca.analyzers.bayesian import BayesianNetwork, BayesianNetworkConfig + if isinstance(dfs, pd.DataFrame): assert dfs.shape[0] > 10000, "The length of df is less than 10000." if verbose: @@ -101,8 +102,9 @@ def train_bayesian_network( self.logger.info(f"The training data shape for the Bayesian network: {(n, dfs[0].shape[1])}") if domain_knowledge_file is None: - domain_knowledge_file = os.path.join(os.path.dirname( - os.path.abspath(__file__)), "configs/domain_knowledge.yaml") + domain_knowledge_file = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "configs/domain_knowledge.yaml" + ) domain = DomainParser(domain_knowledge_file) config = ConfigParser(config_file) @@ -117,19 +119,20 @@ def train_bayesian_network( default_sigma=params["default_sigma"], thres_win_size=params["thres_win_size"], thres_reduce_func=params["thres_reduce_func"], - sigmas=params.get("sigmas", {}) - )) + sigmas=params.get("sigmas", {}), + ) + ) bayesian_network.train(dfs=dfs) bayesian_network.add_root_causes(domain.get_root_causes()) bayesian_network.save(self.model_dir, name=self.bn_filename) return bayesian_network def train_detector( - self, - df: Union[pd.DataFrame, Dict], - config_file: Optional[str] = None, - use_separate_models=True, - additional_config: Dict = None + self, + df: Union[pd.DataFrame, Dict], + config_file: Optional[str] = None, + use_separate_models=True, + additional_config: Dict = None, ) -> Dict: """ Trains the detector(s) given the time series data. @@ -142,6 +145,7 @@ def train_detector( :return: The train model for each metric. """ from pyrca.outliers.stats import StatsDetector, StatsDetectorConfig + if isinstance(df, dict): df = pd.DataFrame.from_dict(df) assert df.shape[0] > 5000, "The length of df is less than 5000." @@ -175,12 +179,14 @@ def train_detector( def _find_root_causes_rw(self, df, anomalies, **kwargs): from pyrca.analyzers.random_walk import RandomWalk, RandomWalkConfig + graph = pd.read_pickle(os.path.join(self.model_dir, self.adjacency_df_filename)) model = RandomWalk(RandomWalkConfig(graph=graph)) return model.find_root_causes(anomalies, df=df, **kwargs) def _find_root_causes_bn(self, df, anomalies, **kwargs): from pyrca.analyzers.bayesian import BayesianNetwork + try: model = BayesianNetwork.load(self.model_dir, self.bn_filename) except: @@ -197,12 +203,12 @@ def find_root_causes_bn(self, anomalies: List, **kwargs): return self._find_root_causes_bn(None, anomalies, **kwargs).to_list() def find_root_causes( - self, - df: Union[pd.DataFrame, Dict], - detector: Union[Dict, BaseDetector], - rca_method: Optional[str] = None, - known_anomalies: List = None, - **kwargs + self, + df: Union[pd.DataFrame, Dict], + detector: Union[Dict, BaseDetector], + rca_method: Optional[str] = None, + known_anomalies: List = None, + **kwargs, ): """ Finds the potential root causes given an incident window. @@ -232,8 +238,10 @@ def find_root_causes( anomaly_info = DetectionResults.merge(results).to_dict() if rca_method is not None: - anomalies = anomaly_info["anomalous_metrics"] if known_anomalies is None \ + anomalies = ( + anomaly_info["anomalous_metrics"] + if known_anomalies is None else list(set(anomaly_info["anomalous_metrics"] + known_anomalies)) - anomaly_info["root_causes"] = self._rca_methods[rca_method]( - df=df, anomalies=anomalies, **kwargs) + ) + anomaly_info["root_causes"] = self._rca_methods[rca_method](df=df, anomalies=anomalies, **kwargs) return anomaly_info diff --git a/pyrca/base.py b/pyrca/base.py index 6d5ddc9..8ad1831 100644 --- a/pyrca/base.py +++ b/pyrca/base.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ Base class for all Models. """ @@ -45,8 +50,7 @@ def from_dict(cls, d: Dict): """ config = cls() for key, value in d.items(): - assert key in config.__dict__, \ - f"Class {cls.__name__} has no field named {key}." + assert key in config.__dict__, f"Class {cls.__name__} has no field named {key}." setattr(config, key, value) return config @@ -104,12 +108,7 @@ def __setstate__(self, state): for name, value in state.items(): setattr(self, name, value) - def save( - self, - directory: str, - filename: str = None, - **kwargs - ): + def save(self, directory: str, filename: str = None, **kwargs): """ Saves the initialized model. @@ -127,12 +126,7 @@ def save( dill.dump(state, f) @classmethod - def load( - cls, - directory: str, - filename: str = None, - **kwargs - ): + def load(cls, directory: str, filename: str = None, **kwargs): """ Loads the dumped model. diff --git a/pyrca/graphs/__init__.py b/pyrca/graphs/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/graphs/__init__.py +++ b/pyrca/graphs/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/graphs/causal/__init__.py b/pyrca/graphs/causal/__init__.py index 7f40c26..8d959f3 100644 --- a/pyrca/graphs/causal/__init__.py +++ b/pyrca/graphs/causal/__init__.py @@ -1,16 +1,12 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from .pc import PC, PCConfig from .ges import GES, GESConfig from .fges import FGES, FGESConfig from .lingam import LiNGAM, LiNGAMConfig -__all__ = [ - "PC", - "PCConfig", - "GES", - "GESConfig", - "FGES", - "FGESConfig", - "LiNGAM", - "LiNGAMConfig" -] +__all__ = ["PC", "PCConfig", "GES", "GESConfig", "FGES", "FGESConfig", "LiNGAM", "LiNGAMConfig"] diff --git a/pyrca/graphs/causal/base.py b/pyrca/graphs/causal/base.py index 27dd99f..2877cc1 100644 --- a/pyrca/graphs/causal/base.py +++ b/pyrca/graphs/causal/base.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The base class for causal discovery methods. """ @@ -23,13 +28,13 @@ class CausalModelConfig(BaseConfig): :param run_pdag2dag: Whether to convert a partial DAG to a DAG. :param max_num_points: The maximum number of data points in causal discovery. """ + domain_knowledge_file: str = None run_pdag2dag: bool = True max_num_points: int = 5000000 class CausalModel(BaseModel): - @staticmethod def initialize(): pass @@ -39,45 +44,28 @@ def finish(): pass @abstractmethod - def _train( - self, - df: pd.DataFrame, - forbids: List, - requires: List, - **kwargs - ): + def _train(self, df: pd.DataFrame, forbids: List, requires: List, **kwargs): raise NotImplementedError - def train( - self, - df: pd.DataFrame, - **kwargs - ) -> pd.DataFrame: + def train(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame: """ Builds the causal graph given the training dataset. :param df: The training dataset. :return: The adjacency matrix. """ - df = df.iloc[:self.config.max_num_points] \ - if self.config.max_num_points is not None else df + df = df.iloc[: self.config.max_num_points] if self.config.max_num_points is not None else df parser = DomainParser(self.config.domain_knowledge_file) adjacency_df = self._train( - df=df, - forbids=parser.get_forbid_links(df.columns), - requires=parser.get_require_links(), - **kwargs + df=df, forbids=parser.get_forbid_links(df.columns), requires=parser.get_require_links(), **kwargs ) var_names = adjacency_df.columns if self.config.run_pdag2dag: dag, flag = CausalModel.pdag2dag(adjacency_df.values) if flag is False: raise RuntimeError("Orientation of the undirected edges failed.") - adjacency_df = pd.DataFrame( - {var_names[i]: dag[:, i] for i in range(len(var_names))}, - index=var_names - ) + adjacency_df = pd.DataFrame({var_names[i]: dag[:, i] for i in range(len(var_names))}, index=var_names) return adjacency_df @staticmethod @@ -131,10 +119,7 @@ def find_sink(g): return r, True @staticmethod - def check_cycles( - adjacency_df: pd.DataFrame, - direct_only: bool = False - ) -> List: + def check_cycles(adjacency_df: pd.DataFrame, direct_only: bool = False) -> List: """ Checks if the generated causal graph has cycles. @@ -170,8 +155,7 @@ def get_parents(adjacency_df: pd.DataFrame) -> Dict: """ var_names = adjacency_df.columns graph = adjacency_df.values - parents = {name: [var_names[j] for j, v in enumerate(graph[:, i]) if v > 0] - for i, name in enumerate(var_names)} + parents = {name: [var_names[j] for j, v in enumerate(graph[:, i]) if v > 0] for i, name in enumerate(var_names)} return parents @staticmethod @@ -184,16 +168,13 @@ def get_children(adjacency_df: pd.DataFrame) -> Dict: """ var_names = adjacency_df.columns graph = adjacency_df.values - children = {name: [var_names[j] for j, v in enumerate(graph[i, :]) if v > 0] - for i, name in enumerate(var_names)} + children = { + name: [var_names[j] for j, v in enumerate(graph[i, :]) if v > 0] for i, name in enumerate(var_names) + } return children @staticmethod - def dump_to_tetrad_json( - adjacency_df: pd.DataFrame, - output_dir: str, - filename: str = "graph.json" - ): + def dump_to_tetrad_json(adjacency_df: pd.DataFrame, output_dir: str, filename: str = "graph.json"): """ Dumps the graph into a Tetrad format. @@ -225,9 +206,7 @@ def dump_to_tetrad_json( "namesHash": {}, "pattern": False, "pag": False, - "attributes": { - "BIC": 0.0 - } + "attributes": {"BIC": 0.0}, } for node in var_names: r = { @@ -236,7 +215,7 @@ def dump_to_tetrad_json( "centerX": 100, "centerY": 100, "attributes": {}, - "name": node + "name": node, } graph["nodes"].append(r) graph["namesHash"][node] = r @@ -251,7 +230,7 @@ def dump_to_tetrad_json( "centerX": 100, "centerY": 100, "attributes": {}, - "name": a + "name": a, }, "node2": { "nodeType": {"ordinal": 0}, @@ -259,13 +238,13 @@ def dump_to_tetrad_json( "centerX": 100, "centerY": 100, "attributes": {}, - "name": b + "name": b, }, "endpoint1": {"ordinal": 0}, "endpoint2": {"ordinal": e}, "bold": False, "properties": [], - "edgeTypeProbabilities": [] + "edgeTypeProbabilities": [], } graph["edgesSet"].append(r) graph["edgeLists"][a].append(r) @@ -299,10 +278,7 @@ def load_from_tetrad_json(filepath: str) -> pd.DataFrame: if endpoint2 == 0: mat[name_to_index[b]][name_to_index[a]] = 1 - return pd.DataFrame( - {var_names[i]: mat[:, i] for i in range(len(var_names))}, - index=var_names - ) + return pd.DataFrame({var_names[i]: mat[:, i] for i in range(len(var_names))}, index=var_names) @staticmethod def plot_causal_graph_networkx(adjacency_df): @@ -319,7 +295,7 @@ def plot_causal_graph_networkx(adjacency_df): for x in sorted(nx.simple_cycles(graph)): flag = True for i in range(1, len(x)): - a, b = node2idx[x[i-1]], node2idx[x[i]] + a, b = node2idx[x[i - 1]], node2idx[x[i]] if adjacency_mat[a][b] == 1 and adjacency_mat[b][a] == 1: flag = False break @@ -331,9 +307,9 @@ def plot_causal_graph_networkx(adjacency_df): nx.draw_networkx_edges( graph, pos, - arrowstyle='->', + arrowstyle="->", arrowsize=15, - edge_color='c', + edge_color="c", width=1.5, ) nx.draw_networkx_labels(graph, pos, labels={c: c for c in adjacency_df.columns}) diff --git a/pyrca/graphs/causal/fges.py b/pyrca/graphs/causal/fges.py index 752ad45..0a49da9 100644 --- a/pyrca/graphs/causal/fges.py +++ b/pyrca/graphs/causal/fges.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The fast greedy equivalence search (FGES) algorithm. """ @@ -20,6 +25,7 @@ class FGESConfig(CausalModelConfig): :param penalty_discount: The penalty discount (a regularization parameter). :param score_id: The score function name, e.g., "sem-bic-score". """ + domain_knowledge_file: str = None run_pdag2dag: bool = True max_num_points: int = 5000000 @@ -32,6 +38,7 @@ class FGES(CausalModel): """ The fast greedy equivalence search (FGES) algorithm for causal discovery. """ + config_class = FGESConfig causal = None @@ -41,6 +48,7 @@ def __init__(self, config: FGESConfig): @staticmethod def initialize(): from pycausal.pycausal import pycausal as pc + FGES.causal = pc() FGES.causal.start_vm() @@ -48,18 +56,12 @@ def initialize(): def finish(): FGES.causal.stop_vm() - def _train( - self, - df: pd.DataFrame, - forbids: List, - requires: List, - start_vm: bool = True, - **kwargs - ): + def _train(self, df: pd.DataFrame, forbids: List, requires: List, start_vm: bool = True, **kwargs): from ...utils.misc import is_pycausal_available - assert is_pycausal_available(), \ - "pycausal is not installed. Please install it from github repo: " \ - "https://github.com/bd2kccd/py-causal." + + assert is_pycausal_available(), ( + "pycausal is not installed. Please install it from github repo: " "https://github.com/bd2kccd/py-causal." + ) from pycausal import search, prior from pycausal.pycausal import pycausal as pc @@ -84,28 +86,25 @@ def _train( faithfulnessAssumed=True, symmetricFirstStep=False, penaltyDiscount=self.config.penalty_discount, - verbose=False + verbose=False, ) for edge in tetrad.getEdges(): - if edge == '': + if edge == "": continue items = edge.split() assert len(items) == 3 a, b = str(items[0]), str(items[2]) - if items[1] == '-->': + if items[1] == "-->": graph[column_name2idx[a], column_name2idx[b]] = 1 - elif items[1] == '---': + elif items[1] == "---": graph[column_name2idx[a], column_name2idx[b]] = 1 graph[column_name2idx[b], column_name2idx[a]] = 1 else: - raise ValueError('Unknown direction: {}'.format(items[1])) + raise ValueError("Unknown direction: {}".format(items[1])) if start_vm and FGES.causal is None: causal.stop_vm() adjacency_mat = graph.astype(int) np.fill_diagonal(adjacency_mat, 0) - adjacency_df = pd.DataFrame( - {var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, - index=var_names - ) + adjacency_df = pd.DataFrame({var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, index=var_names) return adjacency_df diff --git a/pyrca/graphs/causal/ges.py b/pyrca/graphs/causal/ges.py index aa29c08..5462ece 100644 --- a/pyrca/graphs/causal/ges.py +++ b/pyrca/graphs/causal/ges.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The greedy equivalence search (GES) algorithm. """ @@ -19,6 +24,7 @@ class GESConfig(CausalModelConfig): :param max_degree: The allowed maximum number of parents when searching the graph. :param penalty_discount: The penalty discount (a regularization parameter). """ + domain_knowledge_file: str = None run_pdag2dag: bool = True max_num_points: int = 5000000 @@ -30,18 +36,13 @@ class GES(CausalModel): """ The greedy equivalence search (GES) algorithm for causal discovery. """ + config_class = GESConfig def __init__(self, config: GESConfig): self.config = config - def _train( - self, - df: pd.DataFrame, - forbids: List, - requires: List, - **kwargs - ): + def _train(self, df: pd.DataFrame, forbids: List, requires: List, **kwargs): from pyrca.thirdparty.causallearn.search.ScoreBased.GES import ges from pyrca.thirdparty.causallearn.utils.BackgroundKnowledge import BackgroundKnowledge @@ -64,28 +65,25 @@ def _train( maxP=self.config.max_degree, parameters={"kfold": 10, "lambda": self.config.penalty_discount}, background_knowledge=prior, - verbose=False + verbose=False, ) for edge in res["G"].get_graph_edges(): edge = str(edge) - if edge == '': + if edge == "": continue items = edge.split() assert len(items) == 3 a = int(str(items[0]).lower().replace("x", "")) - 1 b = int(str(items[2]).lower().replace("x", "")) - 1 - if items[1] == '-->': + if items[1] == "-->": graph[a, b] = 1 - elif items[1] == '---': + elif items[1] == "---": graph[a, b] = 1 graph[b, a] = 1 else: - raise ValueError('Unknown direction: {}'.format(items[1])) + raise ValueError("Unknown direction: {}".format(items[1])) adjacency_mat = graph.astype(int) np.fill_diagonal(adjacency_mat, 0) - adjacency_df = pd.DataFrame( - {var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, - index=var_names - ) + adjacency_df = pd.DataFrame({var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, index=var_names) return adjacency_df diff --git a/pyrca/graphs/causal/lingam.py b/pyrca/graphs/causal/lingam.py index 47377e5..b85199b 100644 --- a/pyrca/graphs/causal/lingam.py +++ b/pyrca/graphs/causal/lingam.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The non-gaussian linear causal models (LiNGAM). """ @@ -21,6 +26,7 @@ class LiNGAMConfig(CausalModelConfig): bootstrapping will be applied. :param min_causal_effect: The threshold for detecting causal direction (for bootstrapping only). """ + domain_knowledge_file: str = None run_pdag2dag: bool = True max_num_points: int = 5000000 @@ -33,18 +39,13 @@ class LiNGAM(CausalModel): """ The non-gaussian linear causal models (LiNGAM): https://github.com/cdt15/lingam. """ + config_class = LiNGAMConfig def __init__(self, config: LiNGAMConfig): self.config = config - def _train( - self, - df: pd.DataFrame, - forbids: List, - requires: List, - **kwargs - ): + def _train(self, df: pd.DataFrame, forbids: List, requires: List, **kwargs): import lingam from lingam.utils import make_prior_knowledge @@ -69,8 +70,5 @@ def _train( adjacency_mat = (prob >= self.config.lower_limit).astype(int).T np.fill_diagonal(adjacency_mat, 0) - adjacency_df = pd.DataFrame( - {var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, - index=var_names - ) + adjacency_df = pd.DataFrame({var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, index=var_names) return adjacency_df diff --git a/pyrca/graphs/causal/pc.py b/pyrca/graphs/causal/pc.py index f5c2de5..f7408f7 100644 --- a/pyrca/graphs/causal/pc.py +++ b/pyrca/graphs/causal/pc.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The PC algorithm. """ @@ -18,6 +23,7 @@ class PCConfig(CausalModelConfig): :param max_num_points: The maximum number of data points in causal discovery. :param alpha: The p-value threshold for independent test. """ + domain_knowledge_file: str = None run_pdag2dag: bool = True max_num_points: int = 5000000 @@ -28,18 +34,13 @@ class PC(CausalModel): """ The standard PC algorithm. """ + config_class = PCConfig def __init__(self, config: PCConfig): self.config = config - def _train( - self, - df: pd.DataFrame, - forbids: List, - requires: List, - **kwargs - ): + def _train(self, df: pd.DataFrame, forbids: List, requires: List, **kwargs): from pyrca.thirdparty.causallearn.search.ConstraintBased.PC import pc from pyrca.thirdparty.causallearn.utils.BackgroundKnowledge import BackgroundKnowledge @@ -55,32 +56,24 @@ def _train( for a, b in requires: prior.add_required_by_node(a, b) - res = pc( - df.values, - alpha=self.config.alpha, - node_names=list(df.columns), - background_knowledge=prior - ) + res = pc(df.values, alpha=self.config.alpha, node_names=list(df.columns), background_knowledge=prior) for edge in res.G.get_graph_edges(): edge = str(edge) - if edge == '': + if edge == "": continue items = edge.split() assert len(items) == 3 a, b = str(items[0]), str(items[2]) - if items[1] == '-->': + if items[1] == "-->": graph[column_name2index[a], column_name2index[b]] = 1 - elif items[1] == '---': + elif items[1] == "---": graph[column_name2index[a], column_name2index[b]] = 1 graph[column_name2index[b], column_name2index[a]] = 1 else: - raise ValueError('Unknown direction: {}'.format(items[1])) + raise ValueError("Unknown direction: {}".format(items[1])) adjacency_mat = graph.astype(int) np.fill_diagonal(adjacency_mat, 0) - adjacency_df = pd.DataFrame( - {var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, - index=var_names - ) + adjacency_df = pd.DataFrame({var_names[i]: adjacency_mat[:, i] for i in range(len(var_names))}, index=var_names) return adjacency_df diff --git a/pyrca/outliers/__init__.py b/pyrca/outliers/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/outliers/__init__.py +++ b/pyrca/outliers/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/outliers/base.py b/pyrca/outliers/base.py index 0c302d6..b123391 100644 --- a/pyrca/outliers/base.py +++ b/pyrca/outliers/base.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ Base classes for all outliers. """ @@ -17,6 +22,7 @@ class BaseDetector(BaseModel): Base class for Outlier (Anomaly) Detectors. This class should not be used directly, Use dervied class instead. """ + config_class = None def __init__(self): @@ -87,29 +93,33 @@ class DetectorMixin: """ Check data quality and train """ + @staticmethod def _check_nan(df, **kwargs): - assert not bool(df.isnull().values.any()), \ - "The input dataframe contains NaNs." + assert not bool(df.isnull().values.any()), "The input dataframe contains NaNs." @staticmethod def _check_column_names(df, **kwargs): for col in df.columns: - assert isinstance(col, str), \ - f"The column name must be a string instead of {type(col)}." + assert isinstance(col, str), f"The column name must be a string instead of {type(col)}." assert " " not in col, f"The column name cannot contains a SPACE: {col}." assert "#" not in col, f"The column name cannot contains #: {col}." @staticmethod def _check_length(df, min_length=3000, **kwargs): - assert len(df) >= min_length, \ - f"The number of data points is less than {min_length}." + assert len(df) >= min_length, f"The number of data points is less than {min_length}." @staticmethod def _check_data_type(df, **kwargs): for t in df.dtypes: - assert t in [np.int, np.int32, np.int64, np.float, np.float32, np.float64], \ - f"The data type {t} is not int or float." + assert t in [ + np.int, + np.int32, + np.int64, + np.float, + np.float32, + np.float64, + ], f"The data type {t} is not int or float." def check_data_and_train(self, df, **kwargs): """ @@ -131,6 +141,7 @@ class DetectionResults: """ The class for storing anomaly detection results. """ + anomalous_metrics: list = field(default_factory=lambda: []) anomaly_timestamps: dict = field(default_factory=lambda: {}) anomaly_labels: dict = field(default_factory=lambda: {}) @@ -151,12 +162,8 @@ def merge(cls, results: list): :return: The merged ``DetectionResults`` object. """ res = DetectionResults() - res.anomalous_metrics = list(itertools.chain( - *[r.anomalous_metrics for r in results])) - res.anomaly_timestamps = dict(itertools.chain( - *[r.anomaly_timestamps.items() for r in results])) - res.anomaly_labels = dict(itertools.chain( - *[r.anomaly_labels.items() for r in results])) - res.anomaly_info = dict(itertools.chain( - *[r.anomaly_info.items() for r in results])) + res.anomalous_metrics = list(itertools.chain(*[r.anomalous_metrics for r in results])) + res.anomaly_timestamps = dict(itertools.chain(*[r.anomaly_timestamps.items() for r in results])) + res.anomaly_labels = dict(itertools.chain(*[r.anomaly_labels.items() for r in results])) + res.anomaly_info = dict(itertools.chain(*[r.anomaly_info.items() for r in results])) return res diff --git a/pyrca/outliers/stats.py b/pyrca/outliers/stats.py index 563885e..660bb81 100644 --- a/pyrca/outliers/stats.py +++ b/pyrca/outliers/stats.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# """ The statistical-based anomaly detector. """ @@ -10,8 +15,7 @@ from dataclasses import dataclass from pyrca.base import BaseConfig -from pyrca.outliers.base import BaseDetector, \ - DetectorMixin, DetectionResults +from pyrca.outliers.base import BaseDetector, DetectorMixin, DetectionResults from pyrca.utils.utils import estimate_thresholds @@ -36,6 +40,7 @@ class StatsDetectorConfig(BaseConfig): :param custom_anomaly_thresholds: Variable-specific anomaly detection thresholds other than default for certain variables. """ + default_sigma: float = 4.0 thres_win_size: int = 5 thres_reduce_func: str = "mean" @@ -56,6 +61,7 @@ class StatsDetector(BaseDetector, DetectorMixin): that abs(x - mean) > sigma * std. If this percentage is greater than a certain threshold, the timestamp t is considered as an anomaly. """ + config_class = StatsDetectorConfig def __init__(self, config: StatsDetectorConfig): @@ -68,11 +74,7 @@ def to_dict(self) -> Dict: """ Converts a trained detector into a python dictionary. """ - return { - "config": self.config.to_dict(), - "bounds": deepcopy(self.bounds), - "mean_stds": deepcopy(self.mean_stds) - } + return {"config": self.config.to_dict(), "bounds": deepcopy(self.bounds), "mean_stds": deepcopy(self.mean_stds)} @classmethod def from_dict(cls, d: Dict) -> StatsDetector: @@ -89,17 +91,15 @@ def _train(self, df, **kwargs): self.logger.warning("The training data contains NaN.") df = df.dropna() - sigmas = {} if self.config.sigmas is None else \ - self.config.sigmas - manual_thresholds = {} if self.config.manual_thresholds is None else \ - self.config.manual_thresholds + sigmas = {} if self.config.sigmas is None else self.config.sigmas + manual_thresholds = {} if self.config.manual_thresholds is None else self.config.manual_thresholds lowers, uppers, means, stds = estimate_thresholds( df=df, sigmas=sigmas, default_sigma=self.config.default_sigma, win_size=self.config.thres_win_size, reduce=self.config.thres_reduce_func, - return_mean_std=True + return_mean_std=True, ) for i, col in enumerate(df.columns): lower_bound = lowers[i] @@ -118,9 +118,12 @@ def _get_anomaly_scores(self, df): for i in range(len(df)): scores = [] for col in df.columns: - w = self.config.custom_win_sizes.get(col, self.config.score_win_size) \ - if self.config.custom_win_sizes else self.config.score_win_size - x = df[col].values[max(0, i - w): i + w] + w = ( + self.config.custom_win_sizes.get(col, self.config.score_win_size) + if self.config.custom_win_sizes + else self.config.score_win_size + ) + x = df[col].values[max(0, i - w) : i + w] y = (x < self.bounds[col][0]).astype(int) + (x > self.bounds[col][1]).astype(int) scores.append(y.sum() / len(y)) all_scores.append(scores) @@ -137,8 +140,11 @@ def _predict(self, df, **kwargs): anomalous_metrics = [] for metric, score in zip(df.columns, max_scores): - thres = self.config.custom_anomaly_thresholds.get(metric, self.config.anomaly_threshold) \ - if self.config.custom_anomaly_thresholds else self.config.anomaly_threshold + thres = ( + self.config.custom_anomaly_thresholds.get(metric, self.config.anomaly_threshold) + if self.config.custom_anomaly_thresholds + else self.config.anomaly_threshold + ) if score > thres: anomalous_metrics.append(metric) @@ -158,24 +164,28 @@ def _predict(self, df, **kwargs): anomaly_info[col] = { "normal_range": self.bounds[col], "mean_std": self.mean_stds.get(col, None), - "anomalies": [] + "anomalies": [], } for t in range(len(y)): if y[t] > 0: - anomaly_info[col]["anomalies"].append({ - "timestamp": timestamps[t], - "value": x[t], - "absolute_deviation": np.abs(x[t] - self.mean_stds[col][0]) - if col in self.mean_stds else None, - "z_score": np.abs(x[t] - self.mean_stds[col][0]) / (self.mean_stds[col][1] + 1e-5) - if col in self.mean_stds else None - }) + anomaly_info[col]["anomalies"].append( + { + "timestamp": timestamps[t], + "value": x[t], + "absolute_deviation": np.abs(x[t] - self.mean_stds[col][0]) + if col in self.mean_stds + else None, + "z_score": np.abs(x[t] - self.mean_stds[col][0]) / (self.mean_stds[col][1] + 1e-5) + if col in self.mean_stds + else None, + } + ) return DetectionResults( anomalous_metrics=anomalous_metrics, anomaly_timestamps=anomaly_timestamps, anomaly_labels=anomaly_labels, - anomaly_info=anomaly_info + anomaly_info=anomaly_info, ) def update_bounds(self, d: Dict): diff --git a/pyrca/tools/__init__.py b/pyrca/tools/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/tools/__init__.py +++ b/pyrca/tools/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/tools/app.py b/pyrca/tools/app.py index d6f65cf..2e90529 100644 --- a/pyrca/tools/app.py +++ b/pyrca/tools/app.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from pyrca.tools.dashboard.dashboard import app server = app.server diff --git a/pyrca/tools/dashboard/models/__init__.py b/pyrca/tools/dashboard/models/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/tools/dashboard/models/__init__.py +++ b/pyrca/tools/dashboard/models/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/tools/dashboard/models/causal.py b/pyrca/tools/dashboard/models/causal.py index 9002532..b9862f5 100644 --- a/pyrca/tools/dashboard/models/causal.py +++ b/pyrca/tools/dashboard/models/causal.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import sys import json @@ -19,7 +24,6 @@ class CausalDiscovery: - def __init__(self, folder): self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.DEBUG) @@ -141,9 +145,7 @@ def run(self, df, algorithm, params, constraints=None): config_class = self.get_supported_methods()[algorithm]["config_class"] method = method_class(config_class.from_dict(params)) graph_df = method.train( - df=df, - forbids=constraints.get("forbidden", []), - requires=constraints.get("required", []) + df=df, forbids=constraints.get("forbidden", []), requires=constraints.get("required", []) ) relations = self._extract_relations(graph_df) nx_graph = nx.from_pandas_adjacency(graph_df, create_using=nx.DiGraph()) diff --git a/pyrca/tools/dashboard/models/data.py b/pyrca/tools/dashboard/models/data.py index 6753b64..301d645 100644 --- a/pyrca/tools/dashboard/models/data.py +++ b/pyrca/tools/dashboard/models/data.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import sys import logging @@ -13,7 +18,6 @@ class DataAnalyzer: - def __init__(self, folder): self.logger = logging.getLogger(__name__) self.logger.setLevel(logging.DEBUG) @@ -30,11 +34,14 @@ def load_data(self, file_name): @staticmethod def get_stats(df): stats = { - "@global": OrderedDict({ - "NO. of Variables": len(df.columns), - "Time Series Length": df.shape[0], - "Has NaNs": bool(df.isnull().values.any())}), - "@columns": list(df.columns) + "@global": OrderedDict( + { + "NO. of Variables": len(df.columns), + "Time Series Length": df.shape[0], + "Has NaNs": bool(df.isnull().values.any()), + } + ), + "@columns": list(df.columns), } for col in df.columns: info = df[col].describe() diff --git a/pyrca/tools/dashboard/settings.py b/pyrca/tools/dashboard/settings.py index 05bea47..97afbc7 100644 --- a/pyrca/tools/dashboard/settings.py +++ b/pyrca/tools/dashboard/settings.py @@ -1,2 +1,7 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# TABLE_HEADER_COLOR = "lightskyblue" TABLE_DATA_COLOR = "rgb(239, 243, 255)" diff --git a/pyrca/tools/dashboard/utils/file_manager.py b/pyrca/tools/dashboard/utils/file_manager.py index 22c2d15..685a205 100644 --- a/pyrca/tools/dashboard/utils/file_manager.py +++ b/pyrca/tools/dashboard/utils/file_manager.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import base64 import zipfile @@ -14,10 +19,8 @@ def __new__(cls): class FileManager(SingletonClass): - def __init__(self, directory=None): - self.directory = os.path.join(str(Path.home()), "pyrca") \ - if directory is None else directory + self.directory = os.path.join(str(Path.home()), "pyrca") if directory is None else directory if not os.path.exists(self.directory): os.makedirs(self.directory) @@ -30,8 +33,7 @@ def __init__(self, directory=None): os.makedirs(self.model_folder) self.cache_folder = os.path.join(self.directory, "cache") - self.long_callback_manager = DiskcacheLongCallbackManager( - diskcache.Cache(self.cache_folder)) + self.long_callback_manager = DiskcacheLongCallbackManager(diskcache.Cache(self.cache_folder)) def save_file(self, name, content): data = content.encode("utf8").split(b";base64,")[1] diff --git a/pyrca/tools/dashboard/utils/log.py b/pyrca/tools/dashboard/utils/log.py index b735786..79f009b 100644 --- a/pyrca/tools/dashboard/utils/log.py +++ b/pyrca/tools/dashboard/utils/log.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import logging diff --git a/pyrca/tools/dashboard/utils/plot.py b/pyrca/tools/dashboard/utils/plot.py index 9760199..bc57566 100644 --- a/pyrca/tools/dashboard/utils/plot.py +++ b/pyrca/tools/dashboard/utils/plot.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import plotly import plotly.graph_objects as go import pandas as pd @@ -9,8 +14,7 @@ def data_table(df, n=1000, page_size=10): if df is not None: df = df.head(n) - columns = [{"name": "Index", "id": "Index"}] + \ - [{"name": c, "id": c} for c in df.columns] + columns = [{"name": "Index", "id": "Index"}] + [{"name": c, "id": c} for c in df.columns] data = [] for i in range(df.shape[0]): d = {c: v for c, v in zip(df.columns, df.values[i])} @@ -29,7 +33,7 @@ def data_table(df, n=1000, page_size=10): page_size=page_size, page_current=0, style_header=dict(backgroundColor=TABLE_HEADER_COLOR), - style_data=dict(backgroundColor=TABLE_DATA_COLOR) + style_data=dict(backgroundColor=TABLE_DATA_COLOR), ) return table else: @@ -46,13 +50,9 @@ def plot_timeseries(ts, figure_height=750): for i in range(ts.shape[1]): v = ts[[ts.columns[i]]] color = color_list[index % len(color_list)] - traces.append(go.Scatter( - name=ts.columns[i], - x=v.index, - y=v.values.flatten(), - mode="lines", - line=dict(color=color) - )) + traces.append( + go.Scatter(name=ts.columns[i], x=v.index, y=v.values.flatten(), mode="lines", line=dict(color=color)) + ) index += 1 layout = dict( @@ -70,7 +70,7 @@ def plot_timeseries(ts, figure_height=750): dict(step="all"), ] ) - ) + ), ), ) fig = make_subplots(figure=go.Figure(layout=layout)) diff --git a/pyrca/utils/__init__.py b/pyrca/utils/__init__.py index e69de29..d0e4276 100644 --- a/pyrca/utils/__init__.py +++ b/pyrca/utils/__init__.py @@ -0,0 +1,5 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# diff --git a/pyrca/utils/domain.py b/pyrca/utils/domain.py index 6918d3b..efe463f 100644 --- a/pyrca/utils/domain.py +++ b/pyrca/utils/domain.py @@ -1,10 +1,15 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import yaml import pprint from schema import Schema, And, Or, Optional, SchemaError -_larger_than = (lambda target: lambda val: val > target) +_larger_than = lambda target: lambda val: val > target _larger_than_zero = _larger_than(0) -_less_than = (lambda target: lambda val: val < target) +_less_than = lambda target: lambda val: val < target _less_than_one = _less_than(1) SCHEMA = Schema( @@ -14,38 +19,36 @@ "root-nodes": Or(list, None), "leaf-nodes": Or(list, None), "forbids": Or(list, None), - "requires": Or(list, None) + "requires": Or(list, None), }, - - Optional("root-causes"): [{ - "name": str, - "P(r=1)": And(float, _larger_than_zero, _less_than_one), - "metrics": [ - { - "name": str, - Optional("P(m=0|r=0)"): And(float, _larger_than_zero, _less_than_one), - Optional("P(m=0|r=1)"): And(float, _larger_than_zero, _less_than_one) - } - ] - }] + Optional("root-causes"): [ + { + "name": str, + "P(r=1)": And(float, _larger_than_zero, _less_than_one), + "metrics": [ + { + "name": str, + Optional("P(m=0|r=0)"): And(float, _larger_than_zero, _less_than_one), + Optional("P(m=0|r=1)"): And(float, _larger_than_zero, _less_than_one), + } + ], + } + ], } ) class DomainParser: - def __init__(self, file_path): if file_path is None: self.config = None else: - with open(file_path, 'r') as f: + with open(file_path, "r") as f: self.config = yaml.safe_load(f) try: SCHEMA.validate(self.config) except SchemaError as e: - raise RuntimeError( - "The domain knowledge config does not fit the required schema." - ) from e + raise RuntimeError("The domain knowledge config does not fit the required schema.") from e def get_forbid_links(self, graph_nodes=None): if self.config is None or "causal-graph" not in self.config: @@ -77,20 +80,20 @@ def get_forbid_links(self, graph_nodes=None): if len(other_forbids) == 0: return forbids else: - return other_forbids if forbids is None \ - else forbids + other_forbids + return other_forbids if forbids is None else forbids + other_forbids def get_require_links(self): - return None if self.config is None or "causal-graph" not in self.config \ + return ( + None + if self.config is None or "causal-graph" not in self.config else self.config["causal-graph"]["requires"] + ) def get_root_causes(self): - return [] if self.config is None or "root-causes" not in self.config \ - else self.config["root-causes"] + return [] if self.config is None or "root-causes" not in self.config else self.config["root-causes"] def get_metrics(self): - return None if self.config is None or "metrics" not in self.config \ - else self.config["metrics"] + return None if self.config is None or "metrics" not in self.config else self.config["metrics"] def print(self): pprint.pprint(self.config) diff --git a/pyrca/utils/logger.py b/pyrca/utils/logger.py index 9036926..4a33b72 100644 --- a/pyrca/utils/logger.py +++ b/pyrca/utils/logger.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import logging diff --git a/pyrca/utils/misc.py b/pyrca/utils/misc.py index 104d654..83653a1 100644 --- a/pyrca/utils/misc.py +++ b/pyrca/utils/misc.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import sys import importlib.util from abc import ABCMeta @@ -34,8 +39,7 @@ def is_pycausal_available(): if importlib.util.find_spec("pycausal") is not None: _version = importlib_metadata.version("pycausal") if version.parse(_version) != version.parse("1.1.1"): - raise EnvironmentError(f"pycausal found but with version {_version}. " - f"The require version is 1.1.1.") + raise EnvironmentError(f"pycausal found but with version {_version}. " f"The require version is 1.1.1.") return True else: return False diff --git a/pyrca/utils/plot.py b/pyrca/utils/plot.py index b6e9329..10b806b 100644 --- a/pyrca/utils/plot.py +++ b/pyrca/utils/plot.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import plotly import networkx as nx import matplotlib.pyplot as plt @@ -11,26 +16,30 @@ def plotly_plot(df, extra_df=None): for i in range(df.shape[1]): v = df[[df.columns[i]]] color = color_list[index % len(color_list)] - traces.append(go.Scatter( - name=f"{df.columns[i]}", - x=v.index, - y=v.values.flatten().astype(float), - mode="lines", - line=dict(color=color) - )) + traces.append( + go.Scatter( + name=f"{df.columns[i]}", + x=v.index, + y=v.values.flatten().astype(float), + mode="lines", + line=dict(color=color), + ) + ) index += 1 if extra_df is not None: for i in range(extra_df.shape[1]): v = extra_df[[extra_df.columns[i]]] color = color_list[index % len(color_list)] - traces.append(go.Scatter( - name=f"{extra_df.columns[i]}_extra", - x=v.index, - y=v.values.flatten().astype(float), - mode="lines", - line=dict(color=color) - )) + traces.append( + go.Scatter( + name=f"{extra_df.columns[i]}_extra", + x=v.index, + y=v.values.flatten().astype(float), + mode="lines", + line=dict(color=color), + ) + ) index += 1 layout = dict( @@ -48,7 +57,7 @@ def plotly_plot(df, extra_df=None): dict(step="all"), ] ) - ) + ), ), ) fig = make_subplots(figure=go.Figure(layout=layout)) @@ -61,18 +70,13 @@ def plotly_plot(df, extra_df=None): def plot_causal_graph_networkx(adjacency_df, node_sizes): graph = nx.from_pandas_adjacency(adjacency_df, create_using=nx.DiGraph) pos = nx.layout.circular_layout(graph) - nx.draw_networkx_nodes( - graph, - pos, - nodelist=list(node_sizes.keys()), - node_size=list(node_sizes.values()) - ) + nx.draw_networkx_nodes(graph, pos, nodelist=list(node_sizes.keys()), node_size=list(node_sizes.values())) nx.draw_networkx_edges( graph, pos, - arrowstyle='->', + arrowstyle="->", arrowsize=15, - edge_color='c', + edge_color="c", width=1.5, ) nx.draw_networkx_labels(graph, pos, labels={c: c for c in adjacency_df.columns}) diff --git a/pyrca/utils/utils.py b/pyrca/utils/utils.py index 2779b8c..8500ff3 100644 --- a/pyrca/utils/utils.py +++ b/pyrca/utils/utils.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler @@ -6,18 +11,12 @@ class Scaler: - - scalars = { - "minmax": MinMaxScaler, - "robust": RobustScaler, - "standard": StandardScaler, - "none": None - } + scalars = {"minmax": MinMaxScaler, "robust": RobustScaler, "standard": StandardScaler, "none": None} def __init__(self, scalar_type="standard"): - assert scalar_type in Scaler.scalars, \ - f"The scalar type {scalar_type} is not supported. " \ - f"Please choose from {Scaler.scalars.keys()}." + assert scalar_type in Scaler.scalars, ( + f"The scalar type {scalar_type} is not supported. " f"Please choose from {Scaler.scalars.keys()}." + ) scaler_class = Scaler.scalars[scalar_type] self.scaler = None if scaler_class is None else scaler_class() @@ -46,8 +45,8 @@ def remove_outliers(df, scale=5.0): medians = np.median(data, axis=0) a = np.percentile(data, 99, axis=0) b = np.percentile(data, 1, axis=0) - max_value = ((a - medians) * scale + medians) - min_value = ((b - medians) * scale + medians) + max_value = (a - medians) * scale + medians + min_value = (b - medians) * scale + medians indices = [] for i in range(data.shape[0]): @@ -83,17 +82,9 @@ def timeseries_window(df, begin_date, end_date): return df -def estimate_thresholds( - df, - sigmas, - default_sigma=4, - win_size=5, - reduce="mean", - return_mean_std=False -): +def estimate_thresholds(df, sigmas, default_sigma=4, win_size=5, reduce="mean", return_mean_std=False): x = df.values - x = np.array([np.mean(x[max(0, i - win_size):i + 1, :], axis=0) - for i in range(x.shape[0])]) + x = np.array([np.mean(x[max(0, i - win_size) : i + 1, :], axis=0) for i in range(x.shape[0])]) a = np.percentile(x, 0.1, axis=0) b = np.percentile(x, 99.9, axis=0) x = np.maximum(np.minimum(x, b), a) diff --git a/setup.py b/setup.py index d6aeb74..cfee337 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,12 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# from setuptools import setup, find_namespace_packages extras_require = { - "plot": [ - "plotly>=4", - "dash>=2.0", - "dash_bootstrap_components>=1.0", - "jupyter-dash>=0.4", - "dash[diskcache]" - ] + "plot": ["plotly>=4", "dash>=2.0", "dash_bootstrap_components>=1.0", "jupyter-dash>=0.4", "dash[diskcache]"] } extras_require["all"] = sum(extras_require.values(), []) @@ -37,7 +36,7 @@ "tqdm", "wheel", "packaging", - "javabridge>=1.0.11" + "javabridge>=1.0.11", ], extras_require=extras_require, python_requires=">=3.7,<4", diff --git a/tests/analyzers/test_bayesian.py b/tests/analyzers/test_bayesian.py index 377379a..102289d 100644 --- a/tests/analyzers/test_bayesian.py +++ b/tests/analyzers/test_bayesian.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import numpy as np import pandas as pd @@ -5,18 +10,10 @@ class TestBayesianNetwork(unittest.TestCase): - def setUp(self) -> None: columns = ["a", "b", "c", "d"] self.graph = pd.DataFrame( - [ - [0, 1, 0, 0], - [0, 0, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 0] - ], - columns=columns, - index=columns + [[0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]], columns=columns, index=columns ) np.random.seed(0) self.df = pd.DataFrame(np.random.randn(100, 4), columns=columns) diff --git a/tests/analyzers/test_epsilon_diagnosis.py b/tests/analyzers/test_epsilon_diagnosis.py index ea1cd26..f76a1e9 100644 --- a/tests/analyzers/test_epsilon_diagnosis.py +++ b/tests/analyzers/test_epsilon_diagnosis.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import pandas as pd import numpy as np @@ -6,7 +11,6 @@ class TestEpsilonDiagnosis(unittest.TestCase): - def gen_random(self, n: int, d: int, covar: float) -> np.ndarray: cov_mat = np.ones((d, d)) * covar np.fill_diagonal(cov_mat, 1) diff --git a/tests/analyzers/test_psi_pc.py b/tests/analyzers/test_psi_pc.py index 6b22a35..a2347de 100644 --- a/tests/analyzers/test_psi_pc.py +++ b/tests/analyzers/test_psi_pc.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import pytest import unittest @@ -7,7 +12,6 @@ class TestEplisonDiagnosis(unittest.TestCase): - @pytest.mark.skip(reason="pickle issue") def test(self): # SRC_DIR = '../data/n-10-d-3-an-1-nor-s-1000-an-s-1000/' @@ -30,8 +34,8 @@ def test(self): data = pkl.load(input_file) # get normal and abnormal dataset in pd.DataFrame - training_samples = data['data']['num_samples'] - tot_data = data['data']['data'] + training_samples = data["data"]["num_samples"] + tot_data = data["data"]["data"] names = [("A%d" % (i + 1)) for i in range(tot_data.shape[1])] normal_data = tot_data[:training_samples] diff --git a/tests/analyzers/test_random_walk.py b/tests/analyzers/test_random_walk.py index dd9ece7..bcf61da 100644 --- a/tests/analyzers/test_random_walk.py +++ b/tests/analyzers/test_random_walk.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import numpy as np import pandas as pd @@ -5,18 +10,10 @@ class TestRandomWalk(unittest.TestCase): - def setUp(self) -> None: columns = ["a", "b", "c", "d"] self.graph = pd.DataFrame( - [ - [0, 1, 0, 0], - [0, 0, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 0] - ], - columns=columns, - index=columns + [[0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]], columns=columns, index=columns ) np.random.seed(0) self.df = pd.DataFrame(np.random.randn(100, 4), columns=columns) diff --git a/tests/analyzers/test_rht.py b/tests/analyzers/test_rht.py index 48773c1..9f7fb67 100644 --- a/tests/analyzers/test_rht.py +++ b/tests/analyzers/test_rht.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import pytest import unittest @@ -8,7 +13,6 @@ class TestRHT(unittest.TestCase): - @pytest.mark.skip(reason="pickle issue") def test(self): directory = os.path.dirname(os.path.abspath(__file__)) @@ -21,8 +25,8 @@ def test(self): data = pkl.load(input_file) # get normal and abnormal dataset in pd.DataFrame - training_samples = data['data']['num_samples'] - tot_data = data['data']['data'] + training_samples = data["data"]["num_samples"] + tot_data = data["data"]["data"] names = [("X%d" % (i + 1)) for i in range(tot_data.shape[1])] normal_data = tot_data[:training_samples] @@ -33,7 +37,7 @@ def test(self): model = RHT(config=RHTConfig(graph=graph)) model.train(normal_data_pd) - results = model.find_root_causes(abnormal_data_pd, 'X1', True).to_list() + results = model.find_root_causes(abnormal_data_pd, "X1", True).to_list() print(results) diff --git a/tests/applications/example/run_rca.py b/tests/applications/example/run_rca.py index 1ee43fc..321d7c9 100644 --- a/tests/applications/example/run_rca.py +++ b/tests/applications/example/run_rca.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import pprint import numpy as np @@ -8,33 +13,25 @@ def build_bayesian_network(): directory = os.path.dirname(os.path.abspath(__file__)) - df = load_data( - directory=os.path.join(directory, "../../data"), - filename="example.csv" - ) + df = load_data(directory=os.path.join(directory, "../../data"), filename="example.csv") engine = RCAEngine() - engine.build_causal_graph( - df=df, - run_pdag2dag=True, - max_num_points=5000000, - verbose=True + engine.build_causal_graph(df=df, run_pdag2dag=True, max_num_points=5000000, verbose=True) + bn = engine.train_bayesian_network( + dfs=[ + pd.DataFrame( + df.values + np.random.randn(*df.shape) * 1e-5, # To avoid constant values + columns=df.columns, + index=df.index, + ) + ] ) - bn = engine.train_bayesian_network(dfs=[ - pd.DataFrame( - df.values + np.random.randn(*df.shape) * 1e-5, # To avoid constant values - columns=df.columns, - index=df.index - ) - ]) bn.print_probabilities() def test_root_causes(): engine = RCAEngine() - result = engine.find_root_causes_bn( - anomalies=["conn_pool", "apt"] - ) + result = engine.find_root_causes_bn(anomalies=["conn_pool", "apt"]) pprint.pprint(result) diff --git a/tests/graphs/test_domain.py b/tests/graphs/test_domain.py index b1ba1ed..d684dc9 100644 --- a/tests/graphs/test_domain.py +++ b/tests/graphs/test_domain.py @@ -1,17 +1,27 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import unittest from pyrca.utils.domain import DomainParser class TestDomain(unittest.TestCase): - def test(self): directory = os.path.dirname(os.path.abspath(__file__)) parser = DomainParser(os.path.join(directory, "../data/domain_knowledge.yaml")) - results = [['A', 'B'], ['APT', 'DB_CPU'], - ['APT', 'APP_GC'], ['APP_GC', 'DB_CPU'], ['APT', 'DB_CPU'], - ['DB_CPU', 'APP_GC'], ['APT', 'APP_GC']] + results = [ + ["A", "B"], + ["APT", "DB_CPU"], + ["APT", "APP_GC"], + ["APP_GC", "DB_CPU"], + ["APT", "DB_CPU"], + ["DB_CPU", "APP_GC"], + ["APT", "APP_GC"], + ] forbids = parser.get_forbid_links(graph_nodes=["DB_CPU", "APP_GC", "APT"]) self.assertEqual(len(results), len(forbids)) for a, b in zip(results, forbids): @@ -21,7 +31,7 @@ def test(self): causes = parser.get_root_causes()[0] self.assertEqual(causes["name"], "Root_APP_GC") self.assertEqual(causes["P(r=1)"], 0.5) - self.assertEqual(causes["metrics"][0], {'name': 'APP_GC', 'P(m=0|r=0)': 0.99, 'P(m=0|r=1)': 0.01}) + self.assertEqual(causes["metrics"][0], {"name": "APP_GC", "P(m=0|r=0)": 0.99, "P(m=0|r=1)": 0.01}) if __name__ == "__main__": diff --git a/tests/graphs/test_fges.py b/tests/graphs/test_fges.py index 9a3cf09..f42cc14 100644 --- a/tests/graphs/test_fges.py +++ b/tests/graphs/test_fges.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import unittest import numpy as np @@ -6,12 +11,12 @@ class TestFGES(unittest.TestCase): - def test(self): directory = os.path.dirname(os.path.abspath(__file__)) data = np.loadtxt(os.path.join(directory, "../data/data_linear.txt"), skiprows=1) try: from pyrca.thirdparty.causallearn.utils.TXT2GeneralGraph import txt2generalgraph + graph = txt2generalgraph(os.path.join(directory, "../data/graph.txt")) df = pd.DataFrame(data, columns=[f"X{i}" for i in range(1, 21)]) graph = pd.DataFrame((graph.graph < 0).astype(int), columns=df.columns, index=df.columns) diff --git a/tests/graphs/test_ges.py b/tests/graphs/test_ges.py index a90d08d..04917b6 100644 --- a/tests/graphs/test_ges.py +++ b/tests/graphs/test_ges.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import unittest import numpy as np @@ -6,12 +11,12 @@ class TestGES(unittest.TestCase): - def test(self): directory = os.path.dirname(os.path.abspath(__file__)) data = np.loadtxt(os.path.join(directory, "../data/data_linear.txt"), skiprows=1) try: from pyrca.thirdparty.causallearn.utils.TXT2GeneralGraph import txt2generalgraph + graph = txt2generalgraph(os.path.join(directory, "../data/graph.txt")) df = pd.DataFrame(data, columns=[f"X{i}" for i in range(1, 21)]) graph = pd.DataFrame((graph.graph < 0).astype(int), columns=df.columns, index=df.columns) diff --git a/tests/graphs/test_lingam.py b/tests/graphs/test_lingam.py index 5adc927..fd3a0c3 100644 --- a/tests/graphs/test_lingam.py +++ b/tests/graphs/test_lingam.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import numpy as np import pandas as pd @@ -5,11 +10,10 @@ class TestLINGAM(unittest.TestCase): - def setUp(self) -> None: np.random.seed(0) sample_size = 1000 - columns = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5'] + columns = ["x0", "x1", "x2", "x3", "x4", "x5"] x3 = np.random.uniform(size=sample_size) x0 = 3.0 * x3 + np.random.uniform(size=sample_size) x2 = 6.0 * x3 + np.random.uniform(size=sample_size) @@ -24,8 +28,10 @@ def setUp(self) -> None: [0, 1, 0, 0, 1, 0], [1, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0] - ], columns=columns, index=columns + [0, 0, 0, 0, 0, 0], + ], + columns=columns, + index=columns, ) def test(self): diff --git a/tests/graphs/test_pc.py b/tests/graphs/test_pc.py index 0a480ee..32dd97b 100644 --- a/tests/graphs/test_pc.py +++ b/tests/graphs/test_pc.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import unittest import numpy as np @@ -6,12 +11,12 @@ class TestPC(unittest.TestCase): - def test(self): directory = os.path.dirname(os.path.abspath(__file__)) data = np.loadtxt(os.path.join(directory, "../data/data_linear.txt"), skiprows=1) try: from pyrca.thirdparty.causallearn.utils.TXT2GeneralGraph import txt2generalgraph + graph = txt2generalgraph(os.path.join(directory, "../data/graph.txt")) df = pd.DataFrame(data, columns=[f"X{i}" for i in range(1, 21)]) graph = pd.DataFrame((graph.graph < 0).astype(int), columns=df.columns, index=df.columns) diff --git a/tests/outliers/test_config.py b/tests/outliers/test_config.py index 880f30f..f5dc76a 100644 --- a/tests/outliers/test_config.py +++ b/tests/outliers/test_config.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import os import unittest from typing import Dict, List @@ -14,18 +19,17 @@ class TestConfig(BaseConfig): class TestBaseConfig(unittest.TestCase): - def test_dict(self): config = TestConfig() - self.assertDictEqual(config.to_dict(), {'A': 1, 'B': [0, 1, 2], 'C': {'a': 1}}) - config = TestConfig.from_dict({'A': 2, 'B': [1, 2], 'C': {'b': 1}}) - self.assertDictEqual(config.to_dict(), {'A': 2, 'B': [1, 2], 'C': {'b': 1}}) + self.assertDictEqual(config.to_dict(), {"A": 1, "B": [0, 1, 2], "C": {"a": 1}}) + config = TestConfig.from_dict({"A": 2, "B": [1, 2], "C": {"b": 1}}) + self.assertDictEqual(config.to_dict(), {"A": 2, "B": [1, 2], "C": {"b": 1}}) def test_json(self): config = TestConfig() self.assertEqual(config.to_json(), '{"A": 1, "B": [0, 1, 2], "C": {"a": 1}}') config = TestConfig.from_json('{"A": 2, "B": [1, 2], "C": {"b": 1}}') - self.assertDictEqual(config.to_dict(), {'A': 2, 'B': [1, 2], 'C': {'b': 1}}) + self.assertDictEqual(config.to_dict(), {"A": 2, "B": [1, 2], "C": {"b": 1}}) def test_yaml(self): filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/outlier_configs.yaml") @@ -33,16 +37,16 @@ def test_yaml(self): self.assertDictEqual( config.to_dict(), { - 'default_sigma': 4, - 'thres_win_size': 5, - 'thres_reduce_func': 'mean', - 'score_win_size': 3, - 'anomaly_threshold': 0.5, - 'sigmas': None, - 'manual_thresholds': None, - 'custom_win_sizes': None, - 'custom_anomaly_thresholds': None - } + "default_sigma": 4, + "thres_win_size": 5, + "thres_reduce_func": "mean", + "score_win_size": 3, + "anomaly_threshold": 0.5, + "sigmas": None, + "manual_thresholds": None, + "custom_win_sizes": None, + "custom_anomaly_thresholds": None, + }, ) diff --git a/tests/outliers/test_stats.py b/tests/outliers/test_stats.py index 26b6407..aedcf46 100644 --- a/tests/outliers/test_stats.py +++ b/tests/outliers/test_stats.py @@ -1,3 +1,8 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import numpy as np import pandas as pd @@ -5,7 +10,6 @@ class TestStatsDetector(unittest.TestCase): - def test(self): np.random.seed(0) x = np.random.randn(50) * 0.1 @@ -25,12 +29,12 @@ def test(self): results = detector.predict(df).to_dict() self.assertListEqual(results["anomalous_metrics"], [0]) - self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64('1970-01-01T00:20:00.000000000')) + self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64("1970-01-01T00:20:00.000000000")) detector = StatsDetector.from_dict(detector.to_dict()) results = detector.predict(df).to_dict() self.assertListEqual(results["anomalous_metrics"], [0]) - self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64('1970-01-01T00:20:00.000000000')) + self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64("1970-01-01T00:20:00.000000000")) def test_update_config(self): config = { @@ -39,22 +43,10 @@ def test_update_config(self): "thres_reduce_func": "mean", "score_win_size": 3, "anomaly_threshold": 0.5, - "manual_thresholds": { - "Connection_Pool_Errors": { - "lower": 0.0, - "upper": 10.0 - } - } + "manual_thresholds": {"Connection_Pool_Errors": {"lower": 0.0, "upper": 10.0}}, } detector = StatsDetector(StatsDetector.config_class.from_dict(config)) - detector.update_config({ - "manual_thresholds": { - "Connection_Pool_Errors": { - "lower": 1.0, - "upper": 9.0 - } - } - }) + detector.update_config({"manual_thresholds": {"Connection_Pool_Errors": {"lower": 1.0, "upper": 9.0}}}) d = detector.config.to_dict() self.assertEqual(d["manual_thresholds"]["Connection_Pool_Errors"]["lower"], 1.0) self.assertEqual(d["manual_thresholds"]["Connection_Pool_Errors"]["upper"], 9.0) @@ -72,7 +64,7 @@ def test_update_bounds(self): results = detector.predict(df).to_dict() self.assertListEqual(results["anomalous_metrics"], [0]) - self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64('1970-01-01T00:20:00.000000000')) + self.assertEqual(results["anomaly_timestamps"][0][0], np.datetime64("1970-01-01T00:20:00.000000000")) if __name__ == "__main__": diff --git a/tests/tools/test_causal.py b/tests/tools/test_causal.py index a866fd6..0eefeb2 100644 --- a/tests/tools/test_causal.py +++ b/tests/tools/test_causal.py @@ -1,15 +1,19 @@ +# +# Copyright (c) 2023 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause# import unittest import pandas as pd from pyrca.tools.dashboard.models.causal import CausalDiscovery class TestCausal(unittest.TestCase): - def test_1(self): graph = pd.DataFrame( [[0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], columns=["a", "b", "c", "d"], - index=["a", "b", "c", "d"] + index=["a", "b", "c", "d"], ) levels, cycles = CausalDiscovery.causal_order(graph) self.assertEqual(levels, None) @@ -19,7 +23,7 @@ def test_2(self): graph = pd.DataFrame( [[0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0]], columns=["a", "b", "c", "d"], - index=["a", "b", "c", "d"] + index=["a", "b", "c", "d"], ) levels, cycles = CausalDiscovery.causal_order(graph) self.assertEqual(cycles, None) @@ -30,7 +34,7 @@ def test_3(self): graph = pd.DataFrame( [[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 1, 0, 0]], columns=["a", "b", "c", "d", "e"], - index=["a", "b", "c", "d", "e"] + index=["a", "b", "c", "d", "e"], ) levels, cycles = CausalDiscovery.causal_order(graph) self.assertEqual(levels, None)