import pandas as pd
import numpy as np
import plotly.graph_objects as go
import math
# =========================================================
# SMART LOG SCALE
# =========================================================
def get_smart_log_range(values):
    """Return whole-decade bounds ``(lower_power, upper_power)`` for a log axis.

    The axis ``10 ** lower_power`` .. ``10 ** upper_power`` contains every
    finite, positive value in *values*.  Zero, negative, NaN and infinite
    entries cannot be drawn on a log axis and are ignored.  The lower bound is
    never below ``10 ** 0 == 1``, and the range always spans at least one
    decade so the axis is never zero-width.

    The lower bound is the decade at or below the smallest value.  It is never
    raised above that value: with whole-decade bounds, any "tightening" would
    hide the smallest points below the axis.

    Args:
        values: Iterable of numbers (consumed once).

    Returns:
        Tuple of two ints, the base-10 exponents of the axis bounds.

    Raises:
        ValueError: if *values* contains no finite, positive number.
    """
    usable = [v for v in values if math.isfinite(v) and v > 0]
    if not usable:
        raise ValueError(
            "get_smart_log_range needs at least one finite, positive value"
        )

    min_val = max(min(usable), 1)
    max_val = max(usable)

    lower_power = math.floor(math.log10(min_val))
    upper_power = math.ceil(math.log10(max_val))

    # A single value (or all values below 1) would otherwise give a
    # zero-width range.
    upper_power = max(upper_power, lower_power + 1)
    return lower_power, upper_power
# =========================================================
# MAIN CHART FUNCTION
# =========================================================
def _segment_axis_labels(df, segment_col, segments, x_map):
    """Return x tick positions and two-line tick labels for *segments*.

    Each label is the segment name (underscores shown as spaces) followed by
    "Population: start-end", where start/end form a running 1-based row
    count across *segments* in the given order.
    """
    row_counts = df[segment_col].value_counts()
    tickvals = []
    ticktext = []
    end = 0
    for seg in segments:
        start = end + 1
        end += int(row_counts.loc[seg])
        tickvals.append(x_map[seg])
        # str() so non-string segment labels do not crash on .replace().
        ticktext.append(
            f"<b>{str(seg).replace('_', ' ')}</b>"
            f"<br>"
            f"Population: "
            f"{start:,}-{end:,}"
        )
    return tickvals, ticktext


def _first_threshold_per_segment(df, segment_col, risk_col, risk, threshold_col):
    """Return a Series: segment -> first non-null *threshold_col* value for rows at *risk*.

    Segments whose threshold is entirely missing are dropped.
    """
    at_risk = df[df[risk_col] == risk]
    return at_risk.groupby(segment_col)[threshold_col].first().dropna()


def create_enterprise_risk_chart(
    df,
    segment_col="SEGMENT",
    risk_col="RISK_LEVEL",
    amount_col="TXN_AMOUNT",
    count_col="TXN_COUNT",
    amount_threshold_col="AMOUNT_THRESHOLD",
    count_threshold_col="COUNT_THRESHOLD",
    title="Enterprise AML Dual Threshold Chart",
    width=1200,
    height=700
):
    """Build a per-segment scatter of transaction amounts with risk thresholds.

    Each distinct (non-null) value of *segment_col* gets its own shaded band
    on the x axis.  Rows are jittered horizontally inside their band and
    plotted at *amount_col* on a log-scaled y axis, coloured and shaped by
    *risk_col*.  For every (segment, risk) pair:

    - the first *amount_threshold_col* value is drawn as a horizontal line
      across the band;
    - the first *count_threshold_col* value is written as a label at the foot
      of the band.

    Only the risk levels "RR", "MR" and "HR" are drawn; rows with any other
    risk value are ignored.  Rows with a missing segment get no x position.

    Args:
        df: Input data.  It is not modified; helper columns are added to a copy.
        segment_col, risk_col, amount_col, amount_threshold_col,
        count_threshold_col: Column names in *df*.
        count_col: Optional column name.  When present in *df*, its value is
            shown in each point's hover text.
        title: Figure title.
        width, height: Figure size in pixels.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    risk_styles = {
        "RR": {"color": "#4472C4", "symbol": "circle"},
        "MR": {"color": "#ED7D31", "symbol": "diamond"},
        "HR": {"color": "#C00000", "symbol": "square"},
    }
    risks = list(risk_styles)

    # Work on a copy so the helper columns below never leak into (or raise
    # SettingWithCopyWarning on) the caller's DataFrame.
    df = df.copy()

    # ---- Segment layout: one band of width 2*half_band every `gap` units.
    segments = list(df[segment_col].dropna().unique())
    gap = 16
    half_band = 7
    x_map = {seg: idx * gap for idx, seg in enumerate(segments)}
    df["BASE_X"] = df[segment_col].map(x_map)

    # Jitter points inside their band.  A private RandomState(42) yields the
    # same offsets as np.random.seed(42) + np.random.uniform, without
    # resetting the global NumPy RNG the caller may be relying on.
    rng = np.random.RandomState(42)
    df["X"] = df["BASE_X"] + rng.uniform(-3.5, 3.5, len(df))

    # ---- Y axis: whole decades, labelled with thousands separators.
    all_values = list(df[amount_col]) + list(df[amount_threshold_col])
    lower_power, upper_power = get_smart_log_range(all_values)
    tickvals = [10 ** p for p in range(lower_power, upper_power + 1)]
    ticktext = [f"{int(v):,}" for v in tickvals]

    fig = go.Figure()

    # ---- Alternating band shading behind the data.
    shades = ["rgba(180,180,180,0.05)", "rgba(120,120,120,0.03)"]
    for idx, seg in enumerate(segments):
        center = x_map[seg]
        fig.add_vrect(
            x0=center - half_band,
            x1=center + half_band,
            fillcolor=shades[idx % len(shades)],
            line_width=0,
            layer="below",
        )

    # ---- Scatter points, one trace per risk level.
    show_count = count_col in df.columns
    for risk in risks:
        temp = df[df[risk_col] == risk]
        if temp.empty:
            continue
        style = risk_styles[risk]
        hover_lines = [
            "<b>Segment:</b> %{text}",
            "<b>Risk:</b> " + risk,
            "<b>Amount:</b> %{y:,.0f}",
        ]
        if show_count:
            hover_lines.append("<b>Count:</b> %{customdata:,.0f}")
        hovertemplate = "<br>".join(hover_lines) + "<extra></extra>"
        fig.add_trace(
            go.Scatter(
                x=temp["X"],
                y=temp[amount_col],
                mode="markers",
                marker=dict(
                    size=8,
                    color=style["color"],
                    symbol=style["symbol"],
                    opacity=0.82,
                    line=dict(width=0.5, color="black"),
                ),
                text=temp[segment_col],
                customdata=temp[count_col] if show_count else None,
                hovertemplate=hovertemplate,
                showlegend=False,
            )
        )

    # ---- Amount thresholds: a horizontal line across the band per (segment, risk).
    for risk in risks:
        color = risk_styles[risk]["color"]
        thresholds = _first_threshold_per_segment(
            df, segment_col, risk_col, risk, amount_threshold_col
        )
        for seg, threshold in thresholds.items():
            center = x_map[seg]
            fig.add_shape(
                type="line",
                x0=center - half_band,
                x1=center + half_band,
                y0=threshold,
                y1=threshold,
                line=dict(color=color, width=3, dash="solid"),
            )

    # ---- Count thresholds: one label per (segment, risk) at the foot of the band.
    # Labels are spread evenly across the band so one segment's labels do not
    # sit on top of each other.  y uses paper coordinates (0 = bottom of the
    # plot area).  On a log axis Plotly reads an annotation's data-space y in
    # log10 units, so a data-space y of 10**lower_power would land off-screen.
    label_step = 2 * half_band / len(risks)
    for i, risk in enumerate(risks):
        offset = (i - (len(risks) - 1) / 2) * label_step
        color = risk_styles[risk]["color"]
        thresholds = _first_threshold_per_segment(
            df, segment_col, risk_col, risk, count_threshold_col
        )
        for seg, threshold in thresholds.items():
            fig.add_annotation(
                x=x_map[seg] + offset,
                xref="x",
                y=0,
                yref="paper",
                yanchor="bottom",
                yshift=4,
                text=(
                    f"<b>{risk}</b><br>"
                    f"Count Th: "
                    f"{int(threshold):,}"
                ),
                showarrow=False,
                bgcolor="rgba(255,255,255,0.7)",
                font=dict(size=9, color=color),
            )

    # ---- X axis labels: segment name + running population range.
    tickvals_x, ticktext_x = _segment_axis_labels(df, segment_col, segments, x_map)

    # ---- Legend: marker-only placeholder traces.  The data traces above set
    # showlegend=False, so these provide one legend entry per risk level.
    for risk in risks:
        style = risk_styles[risk]
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(
                    size=10,
                    color=style["color"],
                    symbol=style["symbol"],
                    line=dict(width=0.5, color="black"),
                ),
                name=risk,
            )
        )

    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=20)),
        width=width,
        height=height,
        template="plotly_white",
        paper_bgcolor="white",
        plot_bgcolor="white",
        hovermode="closest",
        margin=dict(l=80, r=40, t=70, b=180),
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.18,
            xanchor="center",
            x=0.5,
        ),
        xaxis=dict(
            title="Segments",
            tickmode="array",
            tickvals=tickvals_x,
            ticktext=ticktext_x,
            showgrid=False,
            zeroline=False,
        ),
        yaxis=dict(
            title="Transaction Amount",
            type="log",
            tickvals=tickvals,
            ticktext=ticktext,
            # Log-axis range is given in log10 units, i.e. the exponents.
            range=[lower_power, upper_power],
            gridcolor="rgba(220,220,220,0.6)",
            zeroline=False,
        ),
    )
    return fig
# =========================================================
# SAMPLE DATA
# =========================================================
# Synthetic demo-data settings.  For each risk level:
# ((amount-threshold low, high), (count-threshold low, high)), passed straight
# to np.random.randint, so each high bound is exclusive.
SAMPLE_THRESHOLD_BOUNDS = {
    "RR": ((500_000, 2_000_000), (10, 25)),
    "MR": ((5_000_000, 15_000_000), (30, 60)),
    "HR": ((20_000_000, 50_000_000), (70, 120)),
}
# Number of synthetic transactions generated per (segment, risk level) pair.
SAMPLE_ROWS_PER_GROUP = 40

np.random.seed(22)
segments = [
    "BD_EXTERNAL_ENTITY",
    "BD_EXTERNAL_OTHERS",
    "BD_INTERNAL",
    "BD_HIGH_RISK",
]
risks = ["RR", "MR", "HR"]

rows = []
for seg in segments:
    for risk in risks:
        (amount_lo, amount_hi), (count_lo, count_hi) = SAMPLE_THRESHOLD_BOUNDS[risk]
        # Draw order is fixed (amount threshold, then count threshold, then per
        # row: amount, count) so the data is reproducible under seed 22.
        amount_threshold = np.random.randint(amount_lo, amount_hi)
        count_threshold = np.random.randint(count_lo, count_hi)
        # Amounts are log-normal, centred (in log space) on 1.5x the threshold.
        log_centre = np.log(amount_threshold * 1.5)
        rows.extend(
            {
                "SEGMENT": seg,
                "RISK_LEVEL": risk,
                "TXN_AMOUNT": np.random.lognormal(mean=log_centre, sigma=1.0),
                "TXN_COUNT": np.random.randint(1, count_threshold * 2),
                "AMOUNT_THRESHOLD": amount_threshold,
                "COUNT_THRESHOLD": count_threshold,
            }
            for _ in range(SAMPLE_ROWS_PER_GROUP)
        )

# Assemble the frame and render the chart.
df = pd.DataFrame(rows)
fig = create_enterprise_risk_chart(df=df)
fig.show()