In [None]:
# 1. Install requirements if needed
!pip install shap seaborn --quiet

# 2. Imports
import shap
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# 3. Synthetic data creation
def random_ip():
    """Return a random IPv4 address string inside the 192.168.x.y range.

    Uses NumPy's global random state, so seeding it beforehand makes the
    result reproducible.
    """
    # The upper bound of randint is exclusive: third octet is 0..255,
    # fourth octet is 1..254.
    third_octet = np.random.randint(0, 256)
    fourth_octet = np.random.randint(1, 255)
    return f"192.168.{third_octet}.{fourth_octet}"

# Seed NumPy's global RNG so the synthetic dataset is reproducible.
np.random.seed(42)

# Draw the columns one by one, in a fixed order, so the seeded random
# stream is consumed the same way on every run.
n_rows = 100
ip_column = [random_ip() for _ in range(n_rows)]
geo_column = np.random.choice(['US', 'CA', 'UK', 'DE', 'IN'], n_rows)
time_column = np.random.randint(0, 24, n_rows)  # hour of day, 0..23
user_column = np.random.choice(['userA', 'userB', 'userC'], n_rows)

data = pd.DataFrame({
    'IP_address': ip_column,
    'geo_location': geo_column,
    'time': time_column,
    'user_identity': user_column,
})

# 4. Define risky sets
risky_ip_set = {'192.168.1.10', '192.168.2.20', '192.168.1.100'}
risky_times = {2, 3, 4}
risky_geo = {'IN', 'DE'}  # for example

# A row is labelled risky (target = 1) as soon as its IP, its hour or its
# geo location belongs to the matching risky set; otherwise target = 0.
ip_is_risky = data['IP_address'].isin(risky_ip_set)
time_is_risky = data['time'].isin(risky_times)
geo_is_risky = data['geo_location'].isin(risky_geo)
data['target'] = (ip_is_risky | time_is_risky | geo_is_risky).astype('int64')

# Show how many rows fall into each class.
print(data.groupby("target").size())

# 5. One-hot encode categoricals
# The label column is separated first so it is never part of the features.
target = data['target']
feature_frame = data.drop(columns='target')
data_encoded = pd.get_dummies(feature_frame)

# 6. Train/test split
# 80/20 split with a fixed random_state for reproducibility.
split_parts = train_test_split(
    data_encoded,
    target,
    test_size=0.2,
    random_state=42,
)
X_train, X_test, y_train, y_test = split_parts

# 7. Train model
# fit() returns the fitted estimator, so construction and training chain.
model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# 8. SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# The expected value may be a per-class sequence or a single number; take
# the entry at index 1 when it is a sequence, then reduce it to a float.
expected = explainer.expected_value
if isinstance(expected, (list, np.ndarray)):
    expected = expected[1]
base_value = float(np.ravel(expected)[0])

# SHAP values may arrive as a list of per-class matrices, as a 3-D array
# whose last axis is indexed at 1 here, or as a plain 2-D matrix.
if isinstance(shap_values, list):
    shap_matrix = shap_values[1]
elif shap_values.ndim == 3:
    shap_matrix = shap_values[:, :, 1]
else:
    shap_matrix = shap_values

# Column labels used to locate a feature's column inside shap_matrix.
feature_names = X_test.columns

# 9. Compute risk only from SHAP of IP, time, geo
def risk_shap_ip_time_geo(i, orig_row, shap_matrix, feature_names, base_value):
    """Return a 0-100 risk score built from the SHAP values of IP, time and geo.

    For each of ``IP_address``, ``time`` and ``geo_location`` the SHAP value
    of the column matching the row's raw value is looked up and summed. The
    sum is added to ``base_value`` and squashed with a logistic function.

    Args:
        i: Positional row index into ``shap_matrix``.
        orig_row: Mapping-like row (for example a ``pandas.Series``) holding
            the raw ``'IP_address'``, ``'time'`` and ``'geo_location'`` values.
        shap_matrix: Array indexed as ``shap_matrix[i, column_position]``.
        feature_names: ``pandas.Index`` of column names; positions are used
            as the column positions of ``shap_matrix``.
        base_value: Scalar offset added to the summed SHAP values.

    Returns:
        ``100 * sigmoid(base_value + summed SHAP values)`` as a ``float``.
        Features with no matching column contribute nothing.
    """
    shap_sum = 0.0
    for feature in ("IP_address", "time", "geo_location"):
        # One-hot encoded columns are named "<feature>_<value>", while a
        # numeric column that was not one-hot encoded (such as an integer
        # ``time`` passed through ``pd.get_dummies``) keeps its bare name.
        # Try the encoded name first, then fall back to the bare name so the
        # feature's contribution is not silently dropped.
        for column in (f"{feature}_{orig_row[feature]}", feature):
            if column in feature_names:
                idx = feature_names.get_loc(column)
                # np.ravel handles both a scalar cell and an array-valued cell.
                shap_sum += float(np.ravel(shap_matrix[i, idx])[0])
                break
    # NOTE(review): the logistic transform assumes base_value and the SHAP
    # values are in log-odds space. Confirm this matches the explainer's
    # output space (tree explainers on probability-output models may already
    # yield probability-scale values).
    raw_score = float(base_value) + shap_sum
    prob = 1 / (1 + np.exp(-raw_score))
    return float(prob * 100)

# Raw (un-encoded) rows of the test split, in the same order as X_test.
plot_data = data.loc[X_test.index].copy()

# shap_matrix was computed from X_test, so its rows are addressed by position
# while the raw row is fetched from `data` by index label.
risk_percs = [
    risk_shap_ip_time_geo(
        int(position), data.loc[label], shap_matrix, feature_names, base_value
    )
    for position, label in enumerate(X_test.index)
]
plot_data['shap_ip_time_geo_risk_pct'] = risk_percs

# 10. Visualization: mean risk for each IP (top 10)
risk_by_ip = plot_data.groupby("IP_address")['shap_ip_time_geo_risk_pct'].mean()
mean_risk_ip = risk_by_ip.sort_values(ascending=False)
top_ips = mean_risk_ip.head(10)  # the ten highest mean scores
print("Top 10 Riskiest IPs (SHAP/IP+time+geo):")
print(top_ips)

# Horizontal bars: one bar per IP, bar length = mean risk score.
plt.figure(figsize=(10, 4))
sns.barplot(x=top_ips.values, y=top_ips.index, orient='h')
plt.title("Top 10 IPs by SHAP-3-feature Risk")
plt.xlabel("Mean SHAP-based Risk Score (%)")
plt.ylabel("IP Address")
plt.show()

# 11. Mean risk for each time (barplot)
risk_by_time = plot_data.groupby("time")['shap_ip_time_geo_risk_pct'].mean()
mean_risk_time = risk_by_time.sort_values(ascending=False)
print("Riskiest times (SHAP/IP+time+geo):")
print(mean_risk_time.head())

# One bar per hour present in the test split.
plt.figure(figsize=(10, 4))
sns.barplot(x=mean_risk_time.index, y=mean_risk_time.values)
plt.title("Risk vs Time (SHAP, IP & Time & Geo)")
plt.xlabel("Hour (time)")
plt.ylabel("Mean SHAP-based Risk Score (%)")
plt.show()

# 12. Mean risk for each geo location (barplot)
risk_by_geo = plot_data.groupby("geo_location")['shap_ip_time_geo_risk_pct'].mean()
mean_risk_geo = risk_by_geo.sort_values(ascending=False)
print("Riskiest geos (SHAP/IP+time+geo):")
print(mean_risk_geo)

# One bar per geo location present in the test split.
plt.figure(figsize=(8, 4))
sns.barplot(x=mean_risk_geo.index, y=mean_risk_geo.values)
plt.title("Risk vs Geo Location (SHAP, IP & Time & Geo)")
plt.xlabel("Geo Location")
plt.ylabel("Mean SHAP-based Risk Score (%)")
plt.show()

# 13. SHAP beeswarm for model explainability
# Plots the selected SHAP matrix against the test features, limited to the
# 15 features displayed by max_display.
print('\nSHAP Beeswarm Plot (top 15 features):')
shap.summary_plot(
    shap_matrix,
    features=X_test,
    max_display=15,
)

# 14. Example: Show risk scores for a known risky IP, time, or geo
# Each lookup filters the scored test rows and prints only when a match exists.
test_ip = '192.168.1.10'
rows_for_ip = plot_data.loc[plot_data['IP_address'].eq(test_ip)]
if len(rows_for_ip) > 0:
    print(f"\nRisk scores for IP {test_ip}: {rows_for_ip['shap_ip_time_geo_risk_pct'].values}")

test_time = 3
rows_for_time = plot_data.loc[plot_data['time'].eq(test_time)]
if len(rows_for_time) > 0:
    print(f"\nRisk scores for time {test_time}: {rows_for_time['shap_ip_time_geo_risk_pct'].values}")

test_geo = 'IN'
rows_for_geo = plot_data.loc[plot_data['geo_location'].eq(test_geo)]
if len(rows_for_geo) > 0:
    print(f"\nRisk scores for geo {test_geo}: {rows_for_geo['shap_ip_time_geo_risk_pct'].values}")

# 15. See data sample
# Show the first rows of the scored test split: raw inputs next to the score.
sample_columns = ['IP_address', 'geo_location', 'time', 'shap_ip_time_geo_risk_pct']
sample_rows = plot_data[sample_columns].head()
print("\nSample rows with their (SHAP/IP+time+geo) risk %:")
print(sample_rows)