In [17]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

In [18]:
from dataset.world_bank import CountyDataSet, Metadata
# World Bank 3-letter country code of the country to analyse (same code family as the
# `Country Code` column of `metadata.data`); 'UKR' is Ukraine. Change it here to re-target the notebook.
COUNTRY_CODE = 'UKR'
dataset = CountyDataSet(country_code=COUNTRY_CODE)
# Country metadata table (displayed in the next cell).
metadata = Metadata()

In [19]:
# Country-level metadata: one row per country with its name, World Bank country code,
# income group and region.
metadata.data

Unnamed: 0,Country,Country Code,IncomeGroup,Region
0,Aruba,ABW,High income,Latin America & Caribbean
1,Afghanistan,AFG,Low income,South Asia
2,Angola,AGO,Lower middle income,Sub-Saharan Africa
3,Albania,ALB,Upper middle income,Europe & Central Asia
4,Andorra,AND,High income,Europe & Central Asia
...,...,...,...,...
212,Kosovo,XKX,Upper middle income,Europe & Central Asia
213,"Yemen, Rep.",YEM,Low income,Middle East & North Africa
214,South Africa,ZAF,Upper middle income,Sub-Saharan Africa
215,Zambia,ZMB,Lower middle income,Sub-Saharan Africa


In [20]:
def calc_correlation_matrix(
    df: pd.DataFrame,
    threshold: float = 0.6,
    method: str = 'pearson',
    min_periods: int = 1,
) -> pd.DataFrame:
    """Return the column-wise correlation matrix of ``df`` with weak entries zeroed.

    Every entry whose absolute value is below ``threshold`` is replaced by 0, as is
    every NaN entry (e.g. a constant column, or a pair of columns without enough
    overlapping observations). Entries with ``|r| >= threshold`` are kept unchanged.

    Parameters
    ----------
    df : pd.DataFrame
        Data whose columns are correlated against each other.
    threshold : float, default 0.6
        Minimum absolute correlation to keep.
    method : str or callable, default 'pearson'
        Correlation method, forwarded to ``DataFrame.corr``
        ('pearson', 'kendall', 'spearman' or a callable).
    min_periods : int, default 1
        Minimum number of overlapping non-NaN observations required per column pair,
        forwarded to ``DataFrame.corr``. With the default of 1, a pair that shares only
        two observations gets a Pearson correlation of exactly +/-1; raise this value to
        turn such spurious correlations into NaN (and therefore 0).

    Returns
    -------
    pd.DataFrame
        Square matrix indexed by ``df``'s columns on both axes.
    """
    mat = df.corr(method=method, min_periods=min_periods)
    # abs(NaN) >= threshold is False, so NaN entries are zeroed together with weak ones.
    return mat.where(mat.abs() >= threshold, 0.0)

In [21]:
### calculating the correlation between different indicators ###
# `by_indicator_codes` is transposed so that the indicator codes become columns
# (DataFrame.corr correlates columns), giving a square indicator-by-indicator matrix.
CORR_THRESHOLD = 0.85  # keep only |r| >= 0.85; weaker (and NaN) correlations are set to 0
corr_matrix = calc_correlation_matrix(dataset.by_indicator_codes.T, CORR_THRESHOLD)
corr_matrix

Unnamed: 0,VC.IHR.PSRC.FE.P5,VC.BTL.DETH,VA.PER.RNK,TX.VAL.TRAN.ZS.WT,TX.VAL.OTHR.ZS.WT,TX.VAL.MRCH.RS.ZS,TX.VAL.MRCH.R3.ZS,TX.VAL.MRCH.HI.ZS,TX.VAL.MANF.ZS.UN,TX.VAL.FOOD.ZS.UN,...,SE.PRM.ENRL,SE.PRM.CUAT.FE.ZS,SE.PRM.AGES,SE.PRE.ENRR.MA,SE.PRE.DURS,SE.LPV.PRIM.MA,SE.LPV.PRIM.FE,SE.ENR.PRSC.FM.ZS,SE.ADT.LITR.MA.ZS,SE.ADT.1524.LT.FM.ZS
VC.IHR.PSRC.FE.P5,1.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.000000,0.000000
VC.BTL.DETH,0.0,1.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.892783,0.0,0.0,0.0,0.0,0.000000,0.000000
VA.PER.RNK,0.0,0.0,1.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,1.000000,1.000000
TX.VAL.TRAN.ZS.WT,0.0,0.0,0.0,1.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.873272,0.000000,0.0,0.0,0.0,0.0,-0.932342,0.000000
TX.VAL.OTHR.ZS.WT,0.0,0.0,0.0,0.000000,1.000000,0.0,0.0,0.0,-0.947331,0.92973,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.000000,0.907717
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SE.LPV.PRIM.MA,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.000000,0.000000
SE.LPV.PRIM.FE,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.000000,0.000000
SE.ENR.PRSC.FM.ZS,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.00000,...,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,1.0,-1.000000,1.000000
SE.ADT.LITR.MA.ZS,0.0,0.0,1.0,-0.932342,0.000000,-1.0,1.0,-1.0,0.000000,0.88353,...,-0.954626,0.0,-0.999786,1.000000,0.0,0.0,0.0,-1.0,1.000000,0.000000


In [22]:
def show_correlation_heatmap(
    mat: pd.DataFrame,
    title='Correlation Heatmap',
    cell_px: int = 1,
    min_px: int = 400,
):
    """Render ``mat`` as an interactive plotly heatmap and display it.

    Parameters
    ----------
    mat : pd.DataFrame
        Correlation matrix (e.g. from ``calc_correlation_matrix``); its columns and
        index label the x and y axes.
    title : str, default 'Correlation Heatmap'
        Figure title.
    cell_px : int, default 1
        Pixels allotted per column/row of ``mat`` when sizing the figure.
    min_px : int, default 400
        Lower bound on the figure width and height. Without it a small matrix would
        request a figure only a few pixels wide (plotly rejects sizes below 10 px).

    Returns
    -------
    None
        The figure is displayed via ``fig.show()``; nothing is returned, so a bare
        call at the end of a cell does not render the figure a second time.
    """
    width = max(min_px, cell_px * len(mat.columns))
    height = max(min_px, cell_px * len(mat.index))
    fig = px.imshow(
        mat,
        x=mat.columns,
        y=mat.index,
        color_continuous_scale='RdYlGn',
        # Pin the colour range to the full correlation range so that 0 always maps to the
        # midpoint of the diverging scale instead of being stretched to the data's min/max.
        zmin=-1,
        zmax=1,
    )
    fig.update_layout(title=title, width=width, height=height, autosize=False)
    fig.show()

In [None]:
# WARNING: rendering this interactive heatmap embeds roughly 30-60 MB of figure data in the notebook file.
show_correlation_heatmap(mat=corr_matrix)

Heatmap of the pairwise correlations between indicators, computed with a threshold of 0.85 (correlations with |r| < 0.85 are set to zero):

![heatmap image](./heatmap_RdYlGn_threshold85.png)