# Introduction

We analyze data on railway accidents in Europe.
We also load the ISO country codes as auxiliary data.
Sankey diagrams let us show, in a single graph, the distribution of accidents by type and by country.

# Analysis preparation

## Load packages

In [None]:
import pandas as pd
import os

## Load data

In [None]:
# Input locations on Kaggle: the accidents data and the auxiliary ISO code table.
accidents_path = "/kaggle/input/railways-accidents-in-europe/railways_accidents_eu.csv"
iso_codes_path = "/kaggle/input/iso-country-codes-global/wikipedia-iso-country-codes.csv"

data_df = pd.read_csv(accidents_path)
country_codes_df = pd.read_csv(iso_codes_path)

# Data exploration

## Glimpse the data

In [None]:
# Print a concise summary of the accidents DataFrame (pandas DataFrame.info).
data_df.info()

In [None]:
# Print a concise summary of the ISO country-code DataFrame (pandas DataFrame.info).
country_codes_df.info()

In [None]:
# Preview the first rows of the accidents data (DataFrame.head, 5 rows by default).
data_df.head()

In [None]:
# Preview the first rows of the ISO country-code table.
country_codes_df.head()

The country codes used in the railway accidents dataset correspond to the Alpha-2 codes in the ISO country code data. We merge the two datasets on this code to also obtain the English short name of each country.

## Merge accidents data and country data

In [None]:
# Keep only the ISO columns we need, renamed so that 'geography' (the Alpha-2
# code) matches the code column of the accidents data.
cc_df = country_codes_df[['English short name lower case', 'Alpha-2 code', 'Alpha-3 code']].rename(
    columns={
        'English short name lower case': 'geography_name',
        'Alpha-2 code': 'geography',
        'Alpha-3 code': 'geography_3',
    }
)
# Join explicitly on the code column: a bare merge() joins on every column
# name the two frames happen to share.
data_c_df = data_df.merge(cc_df, on='geography', how='left')
# A left join keeps the row count unless the ISO table repeats a code.
print(data_df.shape, data_c_df.shape)
# Codes without an ISO match get a NaN name and are silently dropped by the
# groupby(..., 'geography_name') aggregations below, so list them here.
# NOTE(review): non-ISO or aggregate codes in the accidents data would show up
# here — confirm against the dataset documentation.
unmatched_codes = sorted(
    data_c_df.loc[data_c_df['geography_name'].isna(), 'geography'].dropna().unique()
)
print(f"Codes without an ISO match: {unmatched_codes}")
data_c_df.head()

## Top 10 countries by number of accidents

In [None]:
# Total number of accidents per country, largest first.
# NOTE(review): rows with accident == "TOTAL" appear to aggregate the other
# accident types (the Sankey cells below drop them for that reason). They are
# excluded here so that each accident is counted once. Confirm against the
# dataset documentation.
agg_df = (
    data_c_df.loc[data_c_df['accident'] != "TOTAL"]
    .groupby(['geography', 'geography_name'])['value']
    .sum()
    .reset_index()
    .sort_values("value", ascending=False)
)
agg_df.head(10)

In [None]:
# Names of the 10 countries at the top of the ranking computed above.
top_10_accidents = agg_df['geography_name'].head(10).values

## Sankey diagram


### Visualization function using Sankey diagram

In [None]:
import plotly.graph_objs as go
import plotly.figure_factory as ff
from plotly import tools
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

def genSankey(df, cat_cols=None, value_cols='', title='Sankey Diagram', param=None):
    """Build a plotly Sankey figure (as a plain dict) from a tidy DataFrame.

    Each consecutive pair of columns in ``cat_cols`` becomes one layer of
    links (``cat_cols[0] -> cat_cols[1] -> ...``). The link width is the sum of
    ``value_cols`` over rows that share the same (source, target) pair.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing every column in ``cat_cols`` and ``value_cols``.
    cat_cols : list of str
        Categorical columns, in left-to-right order. At least two are required.
    value_cols : str
        Name of the numeric column used as link weight.
    title : str
        Figure title.
    param : dict, optional
        Layout options. Only ``"height"`` is read (default 1000).

    Returns
    -------
    dict
        ``{"data": [sankey_trace], "layout": layout}``, as passed to ``iplot``.

    Raises
    ------
    ValueError
        If fewer than two ``cat_cols`` are given.
    """
    # None sentinels replace the mutable defaults ([] / {}) shared across calls.
    cat_cols = [] if cat_cols is None else list(cat_cols)
    param = {} if param is None else param
    if len(cat_cols) < 2:
        raise ValueError("genSankey needs at least two columns in cat_cols")

    # One colour per category level; wraps around if there are more than 6 levels.
    colorPalette = ['#4B8BBE', '#AF2346', '#32CD32', '#8B008B', '#FFD43B', '#646464']

    # Node labels in order of first appearance (deterministic, unlike set()).
    # A label present in several columns becomes one node and takes the colour
    # of the first column it appears in, so labels and colours stay aligned.
    labelColor = {}
    for level, catCol in enumerate(cat_cols):
        levelColor = colorPalette[level % len(colorPalette)]
        for label in pd.unique(df[catCol]):
            labelColor.setdefault(label, levelColor)
    labelList = list(labelColor)
    colorList = list(labelColor.values())
    # O(1) label -> node index lookup (list.index would be O(n) per row).
    labelIndex = {label: idx for idx, label in enumerate(labelList)}

    # Stack every consecutive column pair into one (source, target, count)
    # table, then sum the weights of repeated pairs.
    pairFrames = []
    for src, tgt in zip(cat_cols, cat_cols[1:]):
        pair = df[[src, tgt, value_cols]].copy()
        pair.columns = ['source', 'target', 'count']
        pairFrames.append(pair)
    sourceTargetDf = (
        pd.concat(pairFrames)
        .groupby(['source', 'target'])
        .agg({'count': 'sum'})
        .reset_index()
    )

    # Plotly links refer to nodes by their position in labelList.
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelIndex[x])
    sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelIndex[x])

    data = dict(
        type='sankey',
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.25),
            label=labelList,
            color=colorList,
        ),
        link=dict(
            source=sourceTargetDf['sourceID'],
            target=sourceTargetDf['targetID'],
            value=sourceTargetDf['count'],
        ),
    )

    layout = dict(
        title=title,
        font=dict(size=10),
        height=param.get("height", 1000),
    )

    return dict(data=[data], layout=layout)

### Aggregate accidents by type and country

In [None]:
# Total accidents for every (accident type, country) pair, largest first.
agg_df = (
    data_c_df.groupby(['accident', 'geography_name'])['value']
    .sum()
    .reset_index()
    .rename(columns={'geography_name': 'country', 'value': 'total'})
    .sort_values("total", ascending=False)
)
print(f"All combinations: {len(agg_df)}")
agg_df.head(10)

### Sankey diagram with type of accidents and country

In [None]:
# Sankey of all (country -> accident type) totals, including the TOTAL category.
data_agg = agg_df
fig = genSankey(
    data_agg,
    cat_cols=['country', 'accident'],
    value_cols='total',
    title='Sankey Diagram for railways accidents: {country -> accident type}',
)
iplot(fig, validate=False)

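Let's remove the `TOTAL` accident type, which aggregates the individual types and would otherwise dominate the diagram.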

In [None]:
# Same diagram, with the aggregate TOTAL accident type left out.
is_single_type = agg_df['accident'] != "TOTAL"
data_agg = agg_df[is_single_type]
fig = genSankey(
    data_agg,
    cat_cols=['country', 'accident'],
    value_cols='total',
    title='Sankey Diagram for railways accidents: {country -> accident type}',
)
iplot(fig, validate=False)

## Sankey Diagram for country, accident type and year

Let's now also show the year.
As before, we remove the `TOTAL` accident type.

In [None]:
# Total accidents for every (accident type, year, country) triple, largest first.
agg_df = (
    data_c_df.groupby(['accident', 'date', 'geography_name'])['value']
    .sum()
    .reset_index()
    .rename(columns={'date': 'year', 'geography_name': 'country', 'value': 'total'})
    .sort_values("total", ascending=False)
)
print(f"All combinations: {len(agg_df)}")
agg_df.head(10)

In [None]:
# Three-level Sankey (country -> year -> accident type), without TOTAL.
is_single_type = agg_df['accident'] != "TOTAL"
data_agg = agg_df[is_single_type]
fig = genSankey(
    data_agg,
    cat_cols=['country', 'year', 'accident'],
    value_cols='total',
    title='Sankey Diagram for railways accidents: {country -> year -> accident type}',
)
iplot(fig, validate=False)

## Time evolution of railway accidents

We now look at the trends in railway accidents over time.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns 
def plot_time_variation(df, c='geography_name', y='value', is_log=False, title="", x='date'):
    """Plot one line per group in ``df[c]``, showing ``y`` over ``x``.

    Each line is labelled with its group name at its right-most point.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with columns ``c``, ``x`` and ``y``.
    c : str
        Column that identifies each line (default: country name).
    y : str
        Column with the plotted values.
    is_log : bool
        If True, use a log-scaled y axis. The values are then shifted by +1
        so that zero counts remain plottable.
    title : str
        Inserted into the plot title ``"Total {title}, grouped by country/year"``.
    x : str
        Column used for the horizontal axis (default ``'date'``).
    """
    _, ax = plt.subplots(1, 1, figsize=(16, 12))
    for group in df[c].unique():
        # .copy(): the shift below must not write into a slice of ``df``.
        group_df = df[df[c] == group].copy()
        if group_df.empty:  # e.g. a NaN group, which never compares equal
            continue
        if is_log:
            # log(0) is undefined; shifting by one keeps zero counts on the axis.
            group_df[y] = group_df[y] + 1
        sns.lineplot(x=x, y=y, data=group_df, label=group, ax=ax)
        # Label the line at its last x value; .iloc[0] gives ax.text a scalar y.
        last_x = group_df[x].max()
        last_y = group_df.loc[group_df[x] == last_x, y].iloc[0]
        ax.text(last_x, last_y, str(group))
    plt.xticks(rotation=90)
    plt.title(f'Total {title}, grouped by country/year')
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    if is_log:
        ax.set(yscale="log")
    ax.grid(color='black', linestyle='dotted', linewidth=0.75)
    plt.show()

In [None]:
# Yearly accident totals per country.
# NOTE(review): rows with accident == "TOTAL" appear to aggregate the other
# accident types (the Sankey cells above drop them for that reason). They are
# excluded so that each accident is counted once. Confirm against the dataset
# documentation.
filter_df = data_c_df.loc[data_c_df['accident'] != "TOTAL"]
filter_df = filter_df.groupby(['geography_name', 'date'])['value'].sum().reset_index()
plot_time_variation(filter_df, is_log=True, title="railways accidents (per country)")