In [None]:
import calendar

import arviz as az
import matplotlib.pyplot as plt
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import pymc3 as pm
from scipy.optimize import curve_fit

# Global plot styling used by every figure below: 10x10-inch figures on the
# ArviZ dark-grid theme.
plt.rcParams.update({"figure.figsize": (10, 10)})
az.style.use('arviz-darkgrid')

In [None]:
# Load the raw wildlife-strike incident records (one row per incident) and
# drop every row whose 'Incident Year' is 2015.
df = (
    pd.read_csv('/kaggle/input/wildlife-strikes/database.csv')
    .loc[lambda frame: frame['Incident Year'] != 2015]
)

def monthly_strike_counts(strikes, first_year=1990, last_year=2014):
    """Tabulate wildlife-strike incidents per calendar month.

    Parameters
    ----------
    strikes : pd.DataFrame
        One row per incident, with 'Incident Year' and 'Incident Month' columns.
    first_year, last_year : int
        Inclusive range of years to tabulate.

    Returns
    -------
    pd.DataFrame
        One row per (year, month), year-major and in chronological order
        (12 rows per year), with columns:
        year, month, incidents_in_month, days_in_month, incidents_in_year.
        incidents_in_year is the year's total, repeated on each of its 12 rows.
    """
    records = []
    for year in range(first_year, last_year + 1):
        # Filter once per year rather than scanning the full frame for every month.
        strikes_in_year = strikes[strikes['Incident Year'] == year]
        year_total = strikes_in_year.shape[0]
        for month in range(1, 13):
            records.append({
                'year': year,
                'month': month,
                'incidents_in_month': int((strikes_in_year['Incident Month'] == month).sum()),
                # calendar.monthrange applies the Gregorian leap-year rule. The
                # previous `year % 4 == True` test was `year % 4 == 1`, which
                # flagged 1993, 1997, ... as leap years instead of 1992, 1996, ...
                'days_in_month': calendar.monthrange(year, month)[1],
                'incidents_in_year': year_total,
            })
    return pd.DataFrame(
        records,
        columns=['year', 'month', 'incidents_in_month', 'days_in_month', 'incidents_in_year'],
    )


data = monthly_strike_counts(df)
             
                    

In [None]:
# Raw monthly strike counts in chronological order; the x-axis is the row
# index of `data`, i.e. months counted from January of the first year.
plt.plot(data.incidents_in_month)
plt.xlabel('Month Number')
plt.ylabel('Wildlife Strikes per month')
plt.title('Wildlife strikes per month')
plt.show()

# TODO: these are non-negative integer counts; a Poisson (or negative-binomial)
# likelihood may suit them better than the Normal likelihoods used in the models.

    
    

In [None]:
# Divide each month's count by its number of days so that months of different
# lengths are comparable.
plt.plot(data['incidents_in_month'].div(data['days_in_month']))
plt.xlabel('Month Number')
plt.ylabel('Wildlife Strikes per day')
plt.show()

In [None]:

def quad(x, a, b, c):
    """Evaluate the quadratic ``a*x**2 + b*x + c``.

    Only arithmetic operators are used, so ``x`` may be a scalar or a NumPy
    array, and the coefficients may be plain numbers (as with ``curve_fit``)
    or PyMC random variables (as inside the models).
    """
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c



# `data` holds 12 consecutive rows per year, so every 12th row gives one
# (year, annual total) pair.
annual_years = data.year.values[::12]
annual_strikes = data.incidents_in_year.values[::12]

# Least-squares fit of a quadratic in "years since 1990".
popt, pcov = curve_fit(quad, annual_years - 1990, annual_strikes)

plt.plot(annual_years, annual_strikes)

t = np.linspace(0, 25, 100)

print(popt)
plt.plot(t + 1990, quad(t, *popt))
plt.xlabel('Year')
plt.ylabel('Wildlife Strikes per year')

plt.show()



In [None]:

# Model 1: annual strike totals = quadratic trend in years-since-1990 plus
# Gaussian noise.
with pm.Model() as bird_strike_model:
    a = pm.Normal('a', mu=10, sigma=5)
    b = pm.Normal('b', mu=200, sigma=100)
    c = pm.Normal('c', mu=2200, sigma=1000)

    # The likelihood's standard deviation must be positive. A Normal prior puts
    # mass on negative values, which PyMC can only reject with -inf log-prob;
    # HalfNormal keeps the prior on the valid support.
    eps = pm.HalfNormal('eps_hyper', sigma=1000)

    # `data` has 12 rows per year, so every 12th row is one year; derive the
    # time axis from the data instead of hard-coding its length.
    t = data.year.values[::12] - 1990
    strikes = quad(t, a, b, c)

    strikes_observed = pm.Normal(
        'strikes', mu=strikes, sigma=eps, observed=data.incidents_in_year.values[::12]
    )

    # Fixed seed so the posterior (and every plot built from it) is reproducible.
    trace = pm.sample(chains=2, draws=5_000, tune=2_000, target_accept=.90, random_seed=42)





In [None]:
# Trace and marginal-density plots for Model 1's parameters. `pm.traceplot`
# is a deprecated alias that forwards to ArviZ's `plot_trace`, so call it directly.
az.plot_trace(trace);

In [None]:

# Posterior check for Model 1: overlay sampled quadratic trends on the observed
# annual totals. `trace['a']` concatenates the chains, so its first 200 entries
# are consecutive (autocorrelated) draws from the first chain only; take draws
# spread evenly over the whole trace instead.
annual_years = data.year.values[::12]
annual_strikes = data.incidents_in_year.values[::12]
t = annual_years - 1990
draw_idx = np.linspace(0, len(trace['a']) - 1, 200).astype(int)

for i in draw_idx:
    strikes = quad(t, trace['a'][i], trace['b'][i], trace['c'][i])
    plt.plot(t + 1990, strikes, color='k', alpha=0.01)

plt.scatter(annual_years, annual_strikes, color='r')

plt.ylabel('Strikes')
plt.xlabel('Years')
plt.grid()
plt.show()




In [None]:
# Model 2: monthly counts = (quadratic annual trend / 12) scaled by a
# multiplicative factor per calendar month, plus Gaussian noise.
with pm.Model() as bird_strike_model:
    a = pm.Normal('a', mu=10, sigma=5)
    b = pm.Normal('b', mu=200, sigma=100)
    c = pm.Normal('c', mu=2200, sigma=1000)

    # Noise standard deviation must be positive; HalfNormal keeps the prior on
    # the valid support instead of proposing negative values PyMC then rejects.
    eps = pm.HalfNormal('eps', sigma=1000)

    # The trend is divided by 12, making it an average monthly rate, so each
    # seasonality factor is a multiplier around 1. A prior centred at 0 put
    # half its mass on negative strike rates.
    # NOTE(review): the likelihood only identifies trend x seasonality; the
    # priors decide how that product is split between the two.
    seasonality = pm.Normal('seasonality', mu=1, sigma=1, shape=12)

    t = ((data.year - 1990) + (data.month - 1) / 12.0).values  # fractional years since 1990
    month_idx = data.month.values - 1                           # 0 = January ... 11 = December
    strikes_trend = quad(t, a, b, c) / 12.0

    strikes = strikes_trend * seasonality[month_idx]

    strikes_observed_monthly = pm.Normal(
        'strikes', mu=strikes, sigma=eps, observed=data.incidents_in_month.values
    )

    # Fixed seed so the posterior (and every plot built from it) is reproducible.
    trace = pm.sample(chains=2, draws=5_000, tune=2_000, target_accept=.99, random_seed=42)

In [None]:
# Trace and marginal-density plots for Model 2's parameters (including the 12
# seasonality factors). `pm.traceplot` is a deprecated alias of ArviZ's
# `plot_trace`, so call it directly.
az.plot_trace(trace);

In [None]:
# Posterior check for Model 2: overlay sampled seasonal monthly curves on the
# observed monthly counts, both plotted against fractional calendar year so the
# 'Years' axis label is accurate (previously the x-axis was the row index).
# Draws are spread evenly over the concatenated trace rather than taken as the
# first 50 consecutive, autocorrelated draws of the first chain.
t = (data.year - 1990) + (data.month - 1) / 12.0   # fractional years since 1990
month_idx = data.month.values - 1                   # 0 = January ... 11 = December
draw_idx = np.linspace(0, len(trace['a']) - 1, 50).astype(int)

for i in draw_idx:
    strikes_trend = trace['a'][i] * t**2 + trace['b'][i] * t + trace['c'][i]
    strikes = (strikes_trend * trace['seasonality'][i][month_idx]) / 12.0
    plt.plot(t + 1990, strikes, color='k', alpha=0.01)

plt.plot(t + 1990, data.incidents_in_month, color='r')

plt.ylabel('Strikes')
plt.xlabel('Years')
plt.grid()
plt.show()




In [None]:
# Forest plot of the 12 monthly seasonality factors from Model 2.
# `pm.forestplot` is a deprecated alias of ArviZ's `plot_forest`; the trailing
# semicolon suppresses the printed array-of-axes repr.
az.plot_forest(trace, kind='forestplot', var_names=['seasonality']);