In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy.stats import chi2
from sklearn import metrics

In [None]:
import warnings
# Install a blanket "ignore" filter for every Warning subclass so the cells
# below run without warning noise.
# NOTE(review): this also hides any warnings raised while fitting the model;
# confirm that is intended.
warnings.filterwarnings(action="ignore", category=Warning)

In [None]:
# Apply seaborn's default theme first, then set the default figure size
# (8 x 8) on top of it. The order is kept from the original cell so the size
# setting is the last rcParams change made here.
sns.set_theme()
plt.rcParams.update({"figure.figsize": [8, 8]})

In [None]:
# Read the shark dataset from a path relative to the notebook's working
# directory, then show the resulting frame as the cell output.
sharks = pd.read_csv(filepath_or_buffer="../datasets/sharks.csv")
sharks

In [None]:
# Category labels that are treated as "threatened" when the binary response
# column is built in the next cell.
threatened = [
    "Critically Endangered",
    "Endangered",
    "Vulnerable",
]

In [None]:
# Build the binary response: 1 when the row's "Category" is one of the labels
# in `threatened`, 0 otherwise, then drop the raw "Category" column.
#
# The guard keeps this cell re-runnable. After the first execution "Category"
# no longer exists, so reading it again would raise KeyError; on a re-run the
# already-derived frame is left untouched instead.
if "Category" in sharks.columns:
    sharks["Threatened"] = sharks["Category"].isin(threatened).astype("int")
    sharks = sharks.drop(columns="Category")
elif "Threatened" not in sharks.columns:
    # Neither the source column nor the derived one is present: fail with the
    # same exception type the unguarded column lookup would have raised.
    raise KeyError("Category")
sharks

In [None]:
# Add the natural log of "Weight" as a predictor column.
sharks["LogWeight"] = np.log(sharks["Weight"])

# Binomial-family GLM of the 0/1 "Threatened" response on log weight.
sharks_model = smf.glm(
    formula="Threatened ~ LogWeight",
    data=sharks,
    family=sm.families.Binomial(),
)
sharks_fit = sharks_model.fit()

# Show the fitted coefficients as the cell output.
sharks_fit.params

In [None]:
# Likelihood-ratio test of the fitted model against the intercept-only model:
# the drop in deviance is compared with a chi-squared distribution on 1 degree
# of freedom (the single added coefficient, LogWeight).
# chi2.sf evaluates the upper tail directly; 1 - chi2.cdf(...) loses precision
# through cancellation and rounds to exactly 0 for very small p-values.
chi2.sf(sharks_fit.null_deviance - sharks_fit.deviance, df=1)

In [None]:
# Scatter of the 0/1 response against log weight with a logistic curve
# overlaid (logistic=True); ci=None turns off the confidence band.
sns.regplot(
    x="LogWeight",
    y="Threatened",
    data=sharks,
    logistic=True,
    ci=None,
)
# Optional export of the figure, left disabled:
# plt.savefig("sharks_fit.png")

In [None]:
# Turn fitted probabilities into hard 0/1 labels: a row is labelled 1 when its
# fitted value is strictly above the 0.50 cut-off.
sharks["Class"] = (0.50 < sharks_fit.fittedvalues).astype(int)
sharks

In [None]:
# Confusion table: observed "Threatened" labels as rows, predicted "Class"
# labels as columns.
pd.crosstab(index=sharks["Threatened"], columns=sharks["Class"])

In [None]:
# Fraction of rows where the predicted class equals the observed label.
(sharks["Threatened"] == sharks["Class"]).mean()

In [None]:
# ROC points computed from the observed labels and the fitted probabilities.
fpr, tpr, thresholds = metrics.roc_curve(
    y_true=sharks["Threatened"],
    y_score=sharks_fit.fittedvalues,
)

# Area under the ROC curve. The variable name is kept as-is because the
# plotting cell below reads `chronic_auc`.
chronic_auc = metrics.auc(x=fpr, y=tpr)
chronic_auc

In [None]:
# Draw the ROC curve on an explicit figure/axes pair instead of going through
# pyplot's implicit "current axes".
fig, ax = plt.subplots()

# Classifier curve, with the AUC value embedded in the legend label.
ax.plot(fpr, tpr, label='ROC curve    AUC: %0.2f' % chronic_auc)
# Dashed red diagonal as the reference line.
ax.plot([0, 1], [0, 1], 'r--', label='Random classification')

ax.set_xlabel('False Positive Rate (1-Specificity)')
ax.set_ylabel('True Positive Rate (Sensitivity)')
ax.set_title('ROC curve for shark extinction risk classifier')
ax.legend(loc="lower right")
# Optional export of the figure, left disabled:
# plt.savefig("sharks_roc.png")