In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.formula.api as smf

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Apply seaborn's default theme, then set the default figure size to 8 x 8.
sns.set_theme()
plt.rcParams.update({'figure.figsize': [8, 8]})

In [None]:
# Load the fish table from the datasets folder (path is relative to the
# notebook's working directory).
fish_csv_path = "../datasets/fish.csv"
fish = pd.read_csv(fish_csv_path)
# Display (number of rows, number of columns).
fish.shape

In [None]:
# Display the column labels of the loaded fish table.
fish.columns

In [None]:
# Keep only the three columns used by the rest of the notebook.
# Explicit label selection is used instead of DataFrame.filter(items=...):
# filter silently skips labels that are not present, whereas this raises a
# KeyError right here if the CSV is missing (or has renamed) one of them.
fish = fish[["Species", "Length", "Weight"]]
fish

In [None]:
# Training data: keep only the rows that have no missing value in any column.
train = fish.dropna(axis=0, how='any')
train

In [None]:
# Scatter of weight against length on the original scale, drawn through the
# Axes object that seaborn returns.
ax = sns.scatterplot(x="Length", y="Weight", data=train)
ax.set_title("Fish weights versus fish lengths")
# plt.savefig("fish_lw.png")

In [None]:
# Add natural-log versions of Length and Weight as new columns.
train = train.assign(
    LogLength=lambda df: np.log(df["Length"]),
    LogWeight=lambda df: np.log(df["Weight"]),
)

In [None]:
# Same scatter on the log-log scale, drawn through the returned Axes object.
ax = sns.scatterplot(x="LogLength", y="LogWeight", data=train)
ax.set_title("Log(Weight) versus Log(Length)")
# plt.savefig("fish_logs.png")

In [None]:
# Pearson correlation between the two log-transformed columns.
train["LogLength"].corr(train["LogWeight"], method="pearson")

In [None]:
# Ordinary least squares on the log-log scale:
#   LogWeight = Intercept + slope * LogLength
train_model = smf.ols(formula="LogWeight ~ LogLength", data=train)
train_fit = train_model.fit()
# Display the fitted coefficients (Intercept and LogLength).
train_fit.params

In [None]:
# Display the R-squared of the fitted log-log regression.
train_fit.rsquared

In [None]:
# Load the chondrichthyes table and preview its first five rows.
sharks_csv_path = "../datasets/chondrichthyes.csv"
sharks = pd.read_csv(sharks_csv_path)
sharks.head(5)

In [None]:
# Attach the fish table's remaining columns to each shark row, matching on the
# species name.  Note: `sharks` is overwritten, so this cell is meant to be run
# once per fresh load of the CSV.
fish_by_species = fish.set_index("Species")
sharks = sharks.join(fish_by_species, on="Species")
sharks.head(5)

In [None]:
# Keep the rows that have a known Length and whose Category is anything other
# than "Data Deficient".
has_length = sharks["Length"].notna()
not_data_deficient = sharks["Category"] != "Data Deficient"
sharks = sharks[has_length & not_data_deficient]
sharks.head(5)

In [None]:
# Impute missing shark weights from the log-log regression.  Back-transforming
#   LogWeight = Intercept + slope * LogLength
# gives
#   Weight = exp(Intercept) * Length ** slope
fit_intercept = train_fit.params["Intercept"]
fit_slope = train_fit.params["LogLength"]
imp = np.exp(fit_intercept) * np.power(sharks["Length"], fit_slope)

# Fill only the missing weights; observed weights are left untouched.
# assign/fillna builds a new frame instead of setting the column through
# attribute access on a frame that came from boolean filtering (the pattern
# that triggers pandas' SettingWithCopyWarning).
sharks = sharks.assign(Weight=sharks["Weight"].fillna(imp))
sharks

In [None]:
# Add natural-log versions of Length and Weight as new columns.
sharks = sharks.assign(
    LogLength=lambda df: np.log(df["Length"]),
    LogWeight=lambda df: np.log(df["Weight"]),
)

In [None]:
# Log-log scatter for the shark rows, drawn through the returned Axes object.
ax = sns.scatterplot(x="LogLength", y="LogWeight", data=sharks)
ax.set_title("Log(Weight) versus Log(Length) for sharks")
# plt.savefig("sharks_lw.png")