# Cluster model for world deaths 

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
data = pd.read_csv("../input/worldwide-deaths-by-risk-factors/number-of-deaths-by-risk-factor.csv")
data.head()

## Review the data types, shape, and missing values.

In [None]:
data.info()

## One feature (High total cholesterol) has a lot of missing data. I will drop it, since more than 70% of its values are missing.

In [None]:
data.drop('High total cholesterol', axis=1, inplace=True)
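The 70% rule can also be applied programmatically instead of naming the column by hand. A minimal sketch, assuming the rule is "drop any column with more than 70% missing values"; `drop_sparse_columns` is a hypothetical helper, and on the real data it would be called as `drop_sparse_columns(data)`:

```python
import numpy as np
import pandas as pd

def drop_sparse_columns(df: pd.DataFrame, max_missing: float = 0.7) -> pd.DataFrame:
    """Return df without the columns whose share of missing values exceeds max_missing."""
    missing_share = df.isna().mean()          # fraction of NaNs in each column
    too_sparse = missing_share[missing_share > max_missing].index
    return df.drop(columns=too_sparse)

# Hypothetical toy frame: column "b" is 75% missing, column "a" is complete.
toy = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0],
                    "b": [np.nan, np.nan, np.nan, 1.0]})
print(toy.isna().mean())
cleaned = drop_sparse_columns(toy)
```

Printing the missing shares first makes the threshold decision visible before anything is dropped.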

## Besides the individual countries listed in the 'Entity' column, some regional groupings are also provided. They are:

* North America,
* Latin America and Caribbean,
* Central Europe,
* Eastern Europe,
* Western Europe,
* North Africa and Middle East,
* Central Sub-Saharan Africa,  
* Eastern Sub-Saharan Africa, 
* Western Sub-Saharan Africa,
* Southern Sub-Saharan Africa, 
* Central Asia, 
* East Asia, 
* Southeast Asia,
* South Asia,
* Australasia,
* High-income, and
* High-income Asia Pacific

This should cover all the countries provided in the data.
    
There is also 'World' data provided.
    
Other than by countries, there are also data grouped by Socio-demographic Index (SDI).
SDI is a summary measure that identifies where countries or other geographic areas sit on the spectrum of development. Expressed on a scale of 0 to 1, SDI is a composite average of the rankings of the incomes per capita, average educational attainment, and fertility rates of all areas in the GBD study.
A list can be found here: http://ghdx.healthdata.org/record/ihme-data/gbd-2019-socio-demographic-index-sdi-1950-2019. There are also two other categories for high-income countries (High-income and High-income Asia Pacific).

## Filter the dataframe to keep individual countries only (excluding the regional, SDI, and World aggregates).

In [None]:
remove = ['North America', 'Latin America and Caribbean', 'Central Europe', 
          'Eastern Europe', 'Western Europe','North Africa and Middle East', 
          'Central Sub-Saharan Africa', 'Eastern Sub-Saharan Africa',
          'Western Sub-Saharan Africa', 'Southern Sub-Saharan Africa', 
          'Central Asia', 'East Asia','Southeast Asia', 'South Asia', 
          'Australasia', 'Central Europe, Eastern Europe, and Central Asia',
          'Sub-Saharan Africa', 'Southeast Asia, East Asia, and Oceania',
          'Southern Latin America','Central Latin America', 'Tropical Latin America',
          'High SDI', 'High-middle SDI', 'Middle SDI','Low-middle SDI', 
          'Low SDI', 'High-income', 'High-income Asia Pacific', 'World']

country_df = data[~data['Entity'].isin(remove)]
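`isin` matches names exactly, so a misspelled aggregate name in `remove` would silently leave that aggregate in `country_df`. A quick check, sketched here on hypothetical toy names, lists any entries of the removal list that never appear in the data; `unmatched_names` is a hypothetical helper, and on the real data the call would be `unmatched_names(data['Entity'], remove)`:

```python
import pandas as pd

def unmatched_names(entities: pd.Series, names: list) -> list:
    """Return the entries of `names` that never occur in `entities`."""
    present = set(entities.unique())
    return sorted(n for n in names if n not in present)

# Hypothetical toy values standing in for data['Entity'] and the removal list.
entities = pd.Series(["France", "Japan", "World", "High SDI"])
to_remove = ["World", "High SDI", "Hgh-income"]   # last entry deliberately misspelled

missing = unmatched_names(entities, to_remove)        # names that matched nothing
countries_only = entities[~entities.isin(to_remove)]  # same filter as in the cell above
```

An empty result from `unmatched_names` means every name in the list was found, though it cannot detect aggregates that were left off the list entirely.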

In [None]:
country_df.head()

### The data will be grouped by country ('Entity'), using the mean value across all years.

In [None]:
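# Average each country's deaths over all available years; numeric_only=True skips
# non-numeric columns (e.g. a country-code column), which newer pandas versions reject.
grouped_country_df = country_df.groupby('Entity').mean(numeric_only=True)

# Sum the per-country averages for each risk factor and sort from largest to smallest.
total_deaths = grouped_country_df.drop('Year', axis=1).sum().sort_values(ascending=False)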

plt.figure(figsize=(10,8))
sns.barplot(y=total_deaths.index, x=total_deaths.values, orient='h')

plt.xticks(rotation=90);
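Because each country is first averaged over the years, the bar chart above shows, for each risk factor, the sum across countries of their average annual deaths (a "typical year"), not deaths accumulated over the whole period. A toy example with hypothetical numbers makes the two steps explicit:

```python
import pandas as pd

# Hypothetical toy frame in the same long format as the source data:
# one row per (Entity, Year), one column per risk factor, plus a text code column.
toy = pd.DataFrame({
    "Entity": ["A", "A", "B", "B"],
    "Code": ["AAA", "AAA", "BBB", "BBB"],
    "Year": [2000, 2001, 2000, 2001],
    "Smoking": [10.0, 20.0, 100.0, 300.0],
    "Air pollution": [5.0, 15.0, 50.0, 70.0],
})

# Step 1: average annual deaths per country (the text column is skipped).
per_country = toy.groupby("Entity").mean(numeric_only=True)

# Step 2: add up the country averages for each risk factor.
totals = per_country.drop(columns="Year").sum().sort_values(ascending=False)
print(totals)
```

Country A averages 15 smoking deaths a year and country B 200, so the "Smoking" bar would be 215.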

### Let's visualize the correlation for just a few selected features.

In [None]:
sns.heatmap(grouped_country_df[['High fasting plasma glucose', 'High body-mass index', 
                                 'High systolic blood pressure', 'Air pollution', 'Smoking']].corr(), annot=True)

### The correlation values are high among all of the selected features.

### What about the effects of diet?

In [None]:
sns.pairplot(grouped_country_df[['Diet high in sodium', 'Diet low in whole grains', 
                                 'Diet low in nuts and seeds', 'Diet low in fruits', 
                                 'Diet low in vegetables']], kind='reg', diag_kind='kde')

#### Four of the features (Diet low in whole grains, Diet low in nuts and seeds, Diet low in fruits, and Diet low in vegetables) are positively correlated with one another, and the correlations appear strong.

#### For each pair of features with an absolute correlation above 0.95, one of the two will be dropped before cluster modelling.

*Note: I also tried clustering without removing any features to check whether this would change the outcome, and the results were the same. However, a large number of features reduces the interpretability of the model.*

In [None]:
corr_matrix = grouped_country_df.corr().abs()
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool_))
high_corr_col = [column for column in upper.columns if any(upper[column] > 0.95)]
country_feat = grouped_country_df.drop(high_corr_col, axis=1)

country_feat.head()
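The upper-triangle trick above keeps one member of each highly correlated pair: for every column, only its correlations with the columns before it are examined, so the earlier column of a pair survives and the later one is dropped. A small self-contained check on hypothetical toy columns, with the same logic wrapped in a hypothetical helper:

```python
import numpy as np
import pandas as pd

def drop_highly_correlated(df: pd.DataFrame, threshold: float = 0.95):
    """Drop the later column of every pair whose |correlation| exceeds threshold."""
    corr = df.corr().abs()
    # Strict upper triangle: each pair is checked once, never a column with itself.
    mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    upper = corr.where(mask)
    to_drop = [col for col in upper.columns if (upper[col] > threshold).any()]
    return df.drop(columns=to_drop), to_drop

toy = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10],   # exactly 2 * a, so |corr(a, b)| = 1
    "c": [5, 3, 4, 1, 2],    # |corr(a, c)| = |corr(b, c)| = 0.8
})
reduced, dropped = drop_highly_correlated(toy)
print(dropped)
```

Note that which member of a pair survives depends only on column order, not on which feature is more meaningful.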

#### The features are now reduced to 9 (not counting Year, which will be dropped) from the initial 28. K-Means clustering will be used, and the sum of squared distances (inertia) will be computed for k = 2 to 15 to choose the number of clusters.

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

X = country_feat.drop('Year', axis=1)

scaler = StandardScaler()
scaled_X = scaler.fit_transform(X)

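ssd = []
for k in range(2,16):
    # An explicit n_init and a fixed random_state make the elbow curve reproducible.
    model = KMeans(n_clusters=k, n_init=10, random_state=0)
    model.fit(scaled_X)
    
    ssd.append(model.inertia_)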
    
# The elbow method will be used to determine the k value.  
plt.figure(figsize=(10,6))
plt.plot(range(2,16), ssd, 'o--')
plt.xlabel('k values')
plt.ylabel('Sum of Squared Distances')
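Reading an elbow off the inertia curve is a judgment call. One common complement, not used in the original analysis, is the silhouette score from scikit-learn, which is higher when clusters are compact and well separated. A minimal sketch on synthetic blobs (an assumption for illustration, not the deaths data); on the real data the same loop would run over `scaled_X`:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic data with three well-separated groups (illustrative assumption).
X_demo, _ = make_blobs(n_samples=300, centers=[[0, 0], [10, 10], [-10, 10]],
                       cluster_std=0.5, random_state=0)

scores = {}
for k in range(2, 8):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X_demo)
    scores[k] = silhouette_score(X_demo, labels)   # mean silhouette over all points

best_k = max(scores, key=scores.get)
print(best_k, round(scores[best_k], 3))
```

On real data the two criteria do not always agree; when they differ, both curves are worth showing rather than picking one silently.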

In [None]:
# Selected n_clusters=6 from the elbow plot.
model = KMeans(n_clusters=6, n_init=10, random_state=0)
cluster_labels = model.fit_predict(scaled_X)

X['Cluster'] = cluster_labels
cluster_corr = X.corr()['Cluster'].sort_values()

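### Which features did K-Means treat as most important?

Cluster IDs are arbitrary labels, so the correlation between a feature and the cluster ID is only a rough heuristic: renumbering the clusters would change the values.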

In [None]:
plt.figure(figsize=(10,4))
# Exclude the last entry, which is Cluster's correlation with itself (1.0).
sns.barplot(x=cluster_corr.iloc[:-1].index, y=cluster_corr.iloc[:-1].values)
plt.title("Feature importance determined by K-Means", fontsize=16)
plt.xlabel("Risk factors")
plt.ylabel("Correlation")
plt.xticks(rotation=90);
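A label-invariant alternative, offered here as an assumption rather than the notebook's method, is to look at how far apart the standardized cluster centres are along each feature: a feature whose centroid values vary a lot across clusters is doing much of the separating. Sketched on toy data where only one feature carries cluster structure; on the real data the analogous inputs would be `model.cluster_centers_` and the column names of `X.drop(columns='Cluster')`:

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical toy data: "signal" splits into two groups, "noise" does not.
toy = pd.DataFrame({
    "signal": np.concatenate([rng.normal(-5, 1, 100), rng.normal(5, 1, 100)]),
    "noise": rng.normal(0, 1, 200),
})

scaled = StandardScaler().fit_transform(toy)
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(scaled)

# Spread of the standardized centres along each feature; unlike a correlation
# with the cluster ID, this does not depend on how the clusters are numbered.
spread = pd.Series(km.cluster_centers_.std(axis=0), index=toy.columns)
spread = spread.sort_values(ascending=False)
print(spread)
```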

### Let's see how K-Means clustered the countries using a choropleth map.

In [None]:
iso = pd.read_csv('../input/country-code/country_code.csv', index_col='Country')
iso_code = iso['3let'].to_dict()
X['ISO_Code'] = X.index.map(iso_code)
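Any country whose name in the deaths data differs from its name in `country_code.csv` gets a missing `ISO_Code` and silently disappears from the map. A sketch for listing such names, using a hypothetical two-entry lookup; on the real data the equivalent check would be `X.index[X['ISO_Code'].isna()]`:

```python
import pandas as pd

# Hypothetical lookup in the shape of iso['3let'].to_dict()
iso_code = {"France": "FRA", "Japan": "JPN"}
countries = pd.Index(["France", "Japan", "Cote d'Ivoire"], name="Entity")

codes = countries.map(iso_code)           # names missing from the lookup become NaN
unmatched = list(countries[codes.isna()])
print(unmatched)
```

Unmatched names can then be added to the lookup file or renamed before plotting.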

In [None]:
import plotly.express as px
import plotly.offline as pyo
pyo.init_notebook_mode()

fig = px.choropleth(X, locations="ISO_Code",
                    color="Cluster", 
                    hover_name=X.index, 
                    color_continuous_scale='Rainbow')
fig.show()

### Why did China and India/Myanmar end up in clusters of their own?
### And why was the US grouped with Russia, while Canada was not grouped with the US?

In [None]:
sel_countries = ['United States','Canada','China','India','Russia', 'Australia']

# country_feat already has one row per country, so select the rows, drop Year,
# and transpose so that each risk factor becomes a row.
filt_countries = country_feat.loc[sel_countries].drop('Year', axis=1).transpose()
filt_countries = filt_countries.reset_index()
filt_countries = pd.melt(filt_countries, 'index', var_name='country', value_name='value')
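The line chart needs the data in long format, with one row per (country, risk factor) pair. The same reshape can be written as a single `melt` chain without the transpose, sketched here on a hypothetical two-country table:

```python
import pandas as pd

# Hypothetical wide table: one row per country, one column per risk factor.
wide = pd.DataFrame(
    {"Smoking": [10.0, 20.0], "Air pollution": [3.0, 4.0]},
    index=pd.Index(["A", "B"], name="Entity"),
)

# Long format: one row per (country, risk factor), ready for hue='country'.
long = (wide.reset_index()
            .rename(columns={"Entity": "country"})
            .melt(id_vars="country", var_name="risk_factor", value_name="value"))
print(long)
```

With this layout the plot call would use `x='risk_factor'` instead of `x='index'`.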

In [None]:
plt.figure(figsize=(12,8))
sns.lineplot(x='index', y='value', hue='country', data=filt_countries, palette='Dark2')
plt.legend(loc=(1.05,0.5))
plt.title("Average Annual Deaths by Risk Factor for Selected Countries", fontsize=16)
plt.xlabel("Risk factors")
plt.ylabel("Average annual deaths")
plt.xticks(rotation=90);

#### To get some insight into how K-Means grouped the countries, the chart above plots the risk factors used for clustering for a few selected countries. It shows a clear pattern that explains why the US and Russia were clustered together, and why China and India are each in their own cluster. Canada and Australia also show a similar pattern, which is probably why they were clustered together.

### Let's explore further and look at the data grouped by Socio-demographic Index (SDI).



## Data grouped by Socio-demographic Index (SDI)

In [None]:
sdi_list = ['High SDI', 'High-middle SDI', 'Middle SDI', 'Low SDI', 'Low-middle SDI']

sdi = data[data['Entity'].isin(sdi_list)]

In [None]:
import matplotlib.lines as mlines

# Average annual deaths per SDI group in each window (bounds inclusive, matching
# the labels below), in millions. Year is dropped so it is not added to the totals.
sdi1 = (sdi[sdi['Year'].between(2001, 2010)].groupby('Entity').mean(numeric_only=True)
        .drop('Year', axis=1).sum(axis=1) / 1_000_000)
sdi2 = (sdi[sdi['Year'].between(2011, 2017)].groupby('Entity').mean(numeric_only=True)
        .drop('Year', axis=1).sum(axis=1) / 1_000_000)

def newline(p1, p2):
    """Draw a segment from p1 to p2: red if deaths increased, green if they decreased."""
    ax = plt.gca()
    l = mlines.Line2D([p1[0], p2[0]], [p1[1], p2[1]],
                      color='red' if p2[1] > p1[1] else 'green',
                      marker='o', markersize=6)
    ax.add_line(l)
    return l

fig, ax = plt.subplots(figsize=(14,14))

ax.vlines(x=1, ymin=8, ymax=16, color='black', alpha=0.7, linewidth=1, linestyles='dotted')
ax.vlines(x=3, ymin=8, ymax=16, color='black', alpha=0.7, linewidth=1, linestyles='dotted')

ax.scatter(y=sdi1.values, x=np.repeat(1, sdi1.shape[0]), s=10, color='black', alpha=0.7)
ax.scatter(y=sdi2.values, x=np.repeat(3, sdi2.shape[0]), s=10, color='black', alpha=0.7)

for p1, p2, c in zip(sdi1.values, sdi2.values, sdi2.index):
    newline([1,p1], [3,p2])
    ax.text(1-0.05, p1, c + ', ' + str(round(p1,2)) + 'mil', horizontalalignment='right', 
            verticalalignment='center', fontdict={'size':12})
    ax.text(3+0.05, p2, c + ', ' + str(round(p2,2)) + 'mil', horizontalalignment='left', 
            verticalalignment='center', fontdict={'size':12})
    
ax.text(1-0.05, 17, '2001-2010', horizontalalignment='right', verticalalignment='center', 
        fontdict={'size':18, 'weight':700})
ax.text(3+0.05, 17, '2011-2017', horizontalalignment='left', verticalalignment='center', 
        fontdict={'size':18, 'weight':700})

ax.set_title("Slopechart: Comparing Total Deaths by SDI", fontdict={'size':22})
ax.set(xlim=(0,4), ylim=(8, 18), ylabel='Average Total Deaths (million)')
ax.set_xticks([1,3])
ax.set_xticklabels(["", ""])
plt.yticks(np.arange(8, 18, 2), fontsize=12)

plt.gca().spines["top"].set_alpha(.0)
plt.gca().spines["bottom"].set_alpha(.0)
plt.gca().spines["right"].set_alpha(.0)
plt.gca().spines["left"].set_alpha(.0)
plt.show()
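Year filters written with strict inequalities, such as `(Year > 2001) & (Year < 2011)`, exclude both endpoint years, so a window labelled 2001-2010 would really cover 2002-2010. A minimal sketch of an inclusive alternative using `Series.between` (inclusive on both ends by default); `window_mean` is a hypothetical helper and the numbers are toy values:

```python
import pandas as pd

def window_mean(df: pd.DataFrame, start: int, end: int) -> pd.DataFrame:
    """Per-entity mean over the years start..end (inclusive), without the Year column."""
    in_window = df[df["Year"].between(start, end)]
    return in_window.groupby("Entity").mean(numeric_only=True).drop(columns="Year")

toy = pd.DataFrame({
    "Entity": ["Low SDI"] * 4,
    "Year": [2000, 2001, 2010, 2011],
    "Smoking": [1.0, 2.0, 4.0, 100.0],
})

early = window_mean(toy, 2001, 2010)                    # uses 2001 and 2010
strict = toy[(toy["Year"] > 2001) & (toy["Year"] < 2011)]  # keeps only 2010
print(early, strict["Year"].tolist())
```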

#### Low SDI countries saw a decrease in their death counts, while all other SDI groups saw an increase. Middle SDI countries had the largest increase. Let's explore the Low SDI and Middle SDI groups further.

In [None]:
low_sdi = sdi[sdi['Entity']=='Low SDI']

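# Inclusive windows matching the slopechart labels; numeric_only=True skips non-numeric columns.
low_2001_2010 = low_sdi[low_sdi['Year'].between(2001, 2010)].groupby('Entity').mean(numeric_only=True)
low_2011_2017 = low_sdi[low_sdi['Year'].between(2011, 2017)].groupby('Entity').mean(numeric_only=True)


plt.figure(figsize=(10,8))
# Change in average annual deaths per risk factor (positive = increase).
diff = (low_2011_2017 - low_2001_2010).drop('Year', axis=1).transpose()
sns.barplot(x=diff.index,y=diff['Low SDI'],data=diff)
plt.xlabel("Risk factors")
plt.ylabel("Change in average annual deaths")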
plt.xticks(rotation=90);

#### Improvements in sanitation and safer sexual practices seem to have improved survival in Low SDI countries.

In [None]:
mid_sdi = sdi[sdi['Entity']=='Middle SDI']

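# Inclusive windows matching the slopechart labels; numeric_only=True skips non-numeric columns.
mid_2001_2010 = mid_sdi[mid_sdi['Year'].between(2001, 2010)].groupby('Entity').mean(numeric_only=True)
mid_2011_2017 = mid_sdi[mid_sdi['Year'].between(2011, 2017)].groupby('Entity').mean(numeric_only=True)


plt.figure(figsize=(10,8))
# Change in average annual deaths per risk factor (positive = increase).
diff = (mid_2011_2017 - mid_2001_2010).drop('Year', axis=1).transpose()
sns.barplot(x=diff.index,y=diff['Middle SDI'],data=diff)
plt.xlabel("Risk factors")
plt.ylabel("Change in average annual deaths")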
plt.xticks(rotation=90);
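The Low SDI and Middle SDI cells repeat the same steps. A hypothetical helper, `risk_factor_change`, wraps them for any SDI group and sorts the result so the largest decreases come first. Shown with toy numbers and hypothetical column names; on the real data it would be called as `risk_factor_change(sdi, 'Middle SDI')`:

```python
import pandas as pd

def risk_factor_change(df, entity, before=(2001, 2010), after=(2011, 2017)):
    """Change in mean annual deaths per risk factor between two inclusive year windows."""
    rows = df[df["Entity"] == entity]

    def period_mean(start, end):
        in_window = rows[rows["Year"].between(start, end)]
        return in_window.drop(columns=["Entity", "Year"]).mean(numeric_only=True)

    return (period_mean(*after) - period_mean(*before)).sort_values()

# Hypothetical toy data: two SDI groups, one year in each window.
toy = pd.DataFrame({
    "Entity": ["Middle SDI", "Middle SDI", "Low SDI", "Low SDI"],
    "Year": [2005, 2015, 2005, 2015],
    "Factor A": [10.0, 30.0, 1.0, 1.0],
    "Factor B": [50.0, 20.0, 1.0, 1.0],
})
change = risk_factor_change(toy, "Middle SDI")
print(change)
```

The returned Series can be passed straight to `sns.barplot(x=change.index, y=change.values)`.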

#### Deaths linked to diet- and lifestyle-related risk factors increased significantly in Middle SDI countries, which may reflect higher purchasing power and greater access to modern conveniences.

#### Let's explore more from the data.

In [None]:
# Total deaths per risk factor for each SDI group, summed over all years;
# numeric_only=True keeps non-numeric columns out of the chart.
sdi_deaths = sdi.drop('Year', axis=1).groupby('Entity').sum(numeric_only=True).transpose()
sdi_deaths = sdi_deaths.reset_index()
sdi_melt = pd.melt(sdi_deaths, 'index', var_name='SDI', value_name='value')

plt.figure(figsize=(12,8))
sns.lineplot(x='index', y='value', hue='SDI', data=sdi_melt, palette='Dark2')
plt.legend(loc=(1.05,0.5))
plt.title("Total Deaths by Risk Factors for SDIs", fontsize=16)
plt.xlabel("Risk factors")
plt.ylabel("Total deaths")
plt.xticks(rotation=90);
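The SDI groups may contain very different population sizes, so raw totals mix population size with risk profile. One way to compare profiles, not used in the original notebook, is to divide each group's column by its total so that every SDI group sums to 1. A sketch with hypothetical numbers; on the real data the input would be `sdi.drop('Year', axis=1).groupby('Entity').sum(numeric_only=True).transpose()`:

```python
import pandas as pd

# Hypothetical totals: rows are risk factors, columns are SDI groups.
totals = pd.DataFrame(
    {"High SDI": [60.0, 40.0], "Low SDI": [10.0, 90.0]},
    index=["High systolic blood pressure", "Unsafe water source"],
)

# Divide each column by its own total, so values become shares within each SDI group.
shares = totals / totals.sum(axis=0)
print(shares)
```

Plotted the same way as the line chart above, the shares show which risk factors dominate within each group regardless of its size.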

#### High blood pressure is still the risk factor associated with the largest number of deaths. However, it affects High, High-middle and Middle SDI countries more than Low and Low-middle SDI countries, which instead have far higher death totals from child-related and sanitation-related risk factors. Sadly, this seems to make sense.

#### Cluster modeling will not be explored for SDIs.

#### There is more to explore in this dataset, as other groupings could be studied as well. However, I will leave that to others, as it is time for me to work on a 'happier' dataset.