In [1]:
import numpy as np
import pandas as pd
import feather
import plotly
from tqdm import tqdm
import pickle

In [2]:
import plotly.plotly as py
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from plotly.grid_objs import Grid, Column
from plotly.graph_objs import *
from IPython.display import display, HTML
from datetime import datetime, timedelta

In [3]:
init_notebook_mode(connected=True)

# Get dataframe

In [4]:
antiflu = feather.read_dataframe('..\\..\\Data\\Datathon_2017\\antibiotic_sample_data')

In [5]:
postcode_populations_lat_lon = feather.read_dataframe('..\\..\\Data\\Datathon_2017\\postcode_populations_lat_lon')

In [6]:
antiflu.columns

Index(['Patient_ID', 'Store_ID', 'Prescriber_ID', 'Drug_ID',
       'SourceSystem_Code', 'Prescription_Week', 'Dispense_Week', 'Drug_Code',
       'NHS_Code', 'IsDeferredScript', 'Script_Qty', 'Dispensed_Qty',
       'MaxDispense_Qty', 'PatientPrice_Amt', 'WholeSalePrice_Amt',
       'GovernmentReclaim_Amt', 'RepeatsTotal_Qty', 'RepeatsLeft_Qty',
       'StreamlinedApproval_Code', 'MasterProductID',
       'MasterProductFullName.x', 'BrandName.x', 'FormCode.x',
       'StrengthCode.x', 'PackSizeNumber.x', 'GenericIngredientName.x',
       'EthicalSubCategoryName.x', 'EthicalCategoryName.x',
       'ManufacturerCode.x', 'ManufacturerName.x', 'ManufacturerGroupID.x',
       'ManufacturerGroupCode.x', 'ChemistListPrice.x', 'ATCLevel5Code.x',
       'ATCLevel4Code.x', 'ATCLevel3Code.x', 'ATCLevel2Code.x',
       'ATCLevel1Code.x', 'ATCLevel1Name', 'ATCLevel2Name', 'ATCLevel3Name',
       'ATCLevel4Name', 'ATCLevel5Name', 'gender', 'year_of_birth', 'postcode',
       'lat', 'long', 'MasterP

In [7]:
antiflu[['Patient_ID','lat','long','total_pop']].head(2)

Unnamed: 0,Patient_ID,lat,long,total_pop
0,25.0,-36.040463,146.932638,14377.0
1,25.0,-36.040463,146.932638,14377.0


In [8]:
len(antiflu)

484964

# init_params

In [9]:
grid_res = 50
lat_spacing = [ -42, -25 ]
lon_spacing = [ 135, 155 ]

In [10]:
lat_grid = np.linspace(lat_spacing[0],lat_spacing[1],grid_res)
lon_grid = np.linspace(lon_spacing[0],lon_spacing[1],grid_res)

In [11]:
def between(x,num1,num2):
    if (x < num2 and x >= num1):
        return(True)
    else:
        return(False)

In [12]:
def getRecordsBetweenDates(date_1,date_2):
    return antiflu[antiflu.apply( lambda x: between(x['Prescription_Week'],date_1,date_2),axis=1 ) ]

In [169]:
def addDaysToDate(date,days_num):
    return datetime.strftime((datetime.strptime(date,'%Y-%m-%d') + timedelta(days=days_num)),'%Y-%m-%d')

def addMonthsToDate(date,months_num):
    curr_date = datetime.strptime(date,'%Y-%m-%d')
    return datetime.strftime(datetime(curr_date.year, curr_date.month+1,curr_date.day),'%Y-%m-%d')

# 1 years worth of data here

In [170]:
date_list = [['2013-05-01','2013-06-01']]

for i in range(4):
    date_list.append([date_list[-1][1], addMonthsToDate(date_list[-1][1],1)])

In [171]:
date_list

[['2013-05-01', '2013-06-01'],
 ['2013-06-01', '2013-07-01'],
 ['2013-07-01', '2013-08-01'],
 ['2013-08-01', '2013-09-01'],
 ['2013-09-01', '2013-10-01']]

In [172]:
getRecordsBetweenDates('2013-05-02','2013-05-10')

Unnamed: 0,Patient_ID,Store_ID,Prescriber_ID,Drug_ID,SourceSystem_Code,Prescription_Week,Dispense_Week,Drug_Code,NHS_Code,IsDeferredScript,...,ATCLevel4Code.y,ATCLevel3Code.y,ATCLevel2Code.y,ATCLevel1Code.y,Prescription_Year,total_pop,working_age_pct,lat_index,lon_index,index_matrix_coords
362,1434.0,2414,39276,7439.0,False,2013-05-05,2013-07-14,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
363,1434.0,2414,39276,7439.0,False,2013-05-05,2013-08-11,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
364,1434.0,2414,39276,7439.0,False,2013-05-05,2013-09-15,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
1110,5241.0,555,326,2718.0,False,2013-05-05,2013-06-02,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22536.0,64.8,24,39,1974
1111,5241.0,555,326,2718.0,False,2013-05-05,2013-06-30,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22536.0,64.8,24,39,1974
1112,5241.0,555,326,2718.0,False,2013-05-05,2013-07-28,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22536.0,64.8,24,39,1974
1410,6298.0,131,9484,2718.0,False,2013-05-05,2013-05-26,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,30408.0,67.8,12,25,1262
1411,6298.0,131,9484,2718.0,False,2013-05-05,2013-06-09,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,30408.0,67.8,12,25,1262
1412,6298.0,131,9484,2718.0,False,2013-05-05,2013-06-23,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,30408.0,67.8,12,25,1262
1490,6568.0,1768,35320,2705.0,False,2013-05-05,2013-05-05,DOX7,2711Q,0,...,J01AA,J01A,J01,J,2013.0,28746.0,88.9,23,40,2023


# Generate Data

In [17]:
pop_matrix = np.zeros((grid_res,grid_res))

tmp_pop_data = postcode_populations_lat_lon[postcode_populations_lat_lon['YEAR'] == int(date_list[-1][0][0:4])]

In [18]:
    for i in tqdm(range(grid_res-1)):
        for j in range(grid_res-1):
            pop_matrix[i,j] = tmp_pop_data[tmp_pop_data.apply(lambda x: ( between(x['lat'], lat_grid[i], lat_grid[i+1]) and between(x['long'], lon_grid[j], lon_grid[j+1]) ) ,axis=1) ]['total_pop'].sum()        

100%|██████████████████████████████████████████████████████████████████████████████████| 49/49 [01:09<00:00,  1.39s/it]


In [19]:
index_matrix = np.array(range(50*50)).reshape(50,50)

In [66]:
def getFirstLatCoord(lat):
    try:
        return next(x[0] for x in enumerate(lat_grid) if x[1] > lat)
    except:
        return -1

def getFirstLonCoord(lon):
    try:
        return next(x[0] for x in enumerate(lon_grid) if x[1] > lon)
    except:
        return -1

In [67]:
antiflu['lat_index'] = antiflu.apply(lambda x: getFirstLatCoord(x['lat']), axis = 1)
antiflu['lon_index'] = antiflu.apply(lambda x: getFirstLonCoord(x['long']), axis = 1)

In [68]:
antiflu = antiflu[antiflu['lat_index'] != -1]
antiflu = antiflu[antiflu['lon_index'] != -1]

In [69]:
antiflu['index_matrix_coords'] = antiflu['lat_index'] + 50*antiflu['lon_index']

In [70]:
def GenerateCountMatrixEfficiently(df):

    count_matrix = []

    for x in index_matrix:
        for y in x:
            count_matrix.append(df[df['index_matrix_coords'] == y]['Patient_ID'].count())
        
    return np.array(count_matrix).reshape(50,50).T

In [71]:
#df_agg = pd.DataFrame(tmp_df.groupby('index_matrix_coords').size().rename('vol'))

In [72]:
def GenerateCountMatrixEfficientlyV2(df):

    count_matrix = []
    df_agg = pd.DataFrame(df.groupby('index_matrix_coords').size().rename('vol'))

    for x in index_matrix:
        for y in x:
            try:
                count_matrix.append(int(df_agg[df_agg.index == y]['vol']))
            except:
                count_matrix.append(0)
        
    return np.array(count_matrix).reshape(50,50)

In [73]:
def generateCountMatrix(tmp_df):
    
    count_matrix = np.zeros((grid_res,grid_res))
    for i in range(grid_res-1):
        for j in range(grid_res-1):
            count_matrix[i,j] = tmp_df[tmp_df.apply(lambda x: ( between(x['lat'], lat_grid[i], lat_grid[i+1]) and between(x['long'], lon_grid[j], lon_grid[j+1]) ) ,axis=1) ]['Patient_ID'].count()
    return count_matrix

# test

In [74]:
tmp_df = getRecordsBetweenDates(date_list[0][0],date_list[0][1])

In [75]:
len(tmp_df)

3776

In [76]:
count_matrix = GenerateCountMatrixEfficientlyV2(tmp_df)

In [77]:
np.sum(count_matrix)

3776

In [78]:
def div0( a, b ):
    """ ignore / 0, div0( [-1, 0, 1], 0 ) -> [0, 0, 0] """
    with np.errstate(divide='ignore', invalid='ignore'):
        c = np.true_divide( a, b )
        c[ ~ np.isfinite( c )] = 0  # -inf inf NaN
    return c.T

In [79]:
antiflu_ratio_matrix = div0(count_matrix, pop_matrix)

In [80]:
trace_heat = Heatmap(z = antiflu_ratio_matrix)

In [81]:
#iplot([trace_heat])

In [82]:
grid_res

50

In [83]:
#scale needs to be ebtween 0 and 1
colscl = [
        # Let first 10% (0.1) of the values have color rgb(0, 0, 0)
        [0, 'rgb(255, 255, 255)'],
        [0.2, 'rgb(255, 255, 255)'],

        # Let values between 10-20% of the min and max of z
        # have color rgb(20, 20, 20)
        [0.2, 'rgb(255, 255, 255)'],
        [0.4, 'rgb(255, 216, 32)'],

        # Values between 20-30% of the min and max of z
        # have color rgb(40, 40, 40)
        [0.4, 'rgb(255, 216, 32)'],
        [0.8, 'rgb(234, 35, 0)'],

        [0.8, 'rgb(234, 35, 0)'],
        [1, 'rgb(0, 0, 0)']
    ]

In [84]:
tmp_df

Unnamed: 0,Patient_ID,Store_ID,Prescriber_ID,Drug_ID,SourceSystem_Code,Prescription_Week,Dispense_Week,Drug_Code,NHS_Code,IsDeferredScript,...,ATCLevel4Code.y,ATCLevel3Code.y,ATCLevel2Code.y,ATCLevel1Code.y,Prescription_Year,total_pop,working_age_pct,lat_index,lon_index,index_matrix_coords
362,1434.0,2414,39276,7439.0,False,2013-05-05,2013-07-14,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
363,1434.0,2414,39276,7439.0,False,2013-05-05,2013-08-11,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
364,1434.0,2414,39276,7439.0,False,2013-05-05,2013-09-15,RES5,2951H,0,...,J01EE,J01E,J01,J,2013.0,27535.0,68.9,27,41,2077
591,2344.0,1706,8541,223.0,False,2013-05-12,2013-08-11,AKAM1,1616C,0,...,J01AA,J01A,J01,J,2013.0,17657.0,66.7,24,39,1974
592,2344.0,1706,8541,223.0,False,2013-05-12,2013-09-08,AKAM1,1616C,0,...,J01AA,J01A,J01,J,2013.0,17657.0,66.7,24,39,1974
593,2344.0,1706,8541,223.0,False,2013-05-12,2013-10-13,AKAM1,1616C,0,...,J01AA,J01A,J01,J,2013.0,17657.0,66.7,24,39,1974
775,3396.0,1592,18944,2718.0,False,2013-05-12,2013-07-28,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22739.0,64.8,12,25,1262
776,3396.0,1592,18944,2718.0,False,2013-05-12,2013-08-11,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22739.0,64.8,12,25,1262
1110,5241.0,555,326,2718.0,False,2013-05-05,2013-06-02,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22536.0,64.8,24,39,1974
1111,5241.0,555,326,2718.0,False,2013-05-05,2013-06-30,DOXY11,2711Q,0,...,J01AA,J01A,J01,J,2013.0,22536.0,64.8,24,39,1974


In [181]:

data_pts = []
slider_steps = []
original_dfs = []

for date_item in tqdm(date_list):

    tmp_df = getRecordsBetweenDates(date_item[0],date_item[1])

    x = np.transpose(np.array(tmp_df[['long']]))[0]
    y = np.transpose(np.array(tmp_df[['lat']]))[0]

    xv, yv = np.meshgrid(lat_grid[:(grid_res)], lon_grid[:(grid_res)])

    #H, xedges, yedges = np.histogram2d(x,y,bins=(lon_grid,lat_grid))
    count_matrix = GenerateCountMatrixEfficiently(tmp_df)

    H = div0(count_matrix, pop_matrix)*1000

    H = H.reshape((grid_res)*(grid_res),1)
    xv = xv.reshape((grid_res)*(grid_res),1)
    yv = yv.reshape((grid_res)*(grid_res),1)
    
    perc = np.percentile(H,99.5)
    #H = np.array([np.min([i[0],perc]) for i in H])
    #H = H.reshape((grid_res)*(grid_res),1)
    
    df = pd.DataFrame(np.concatenate((H,xv,yv), axis = 1), columns=['density','lat','lon'])
    df_reduced = df[df['density'] > 0]
    df_reduced = df_reduced[df_reduced['density'] < 0.95*perc]
    #df_reduced = df
    
    
    data_pt = [ dict(
        lat = df_reduced['lat'],
        lon = df_reduced['lon'],
        text = df_reduced['density'].astype(str),
        marker = dict(
            symbol = "square-dot",
            color = df_reduced['density'],
            colorscale= colscl,
            reversescale = False,
            opacity = 0.7,
            size = 8,
            colorbar = dict(
                thickness = 10,
                titleside = "right",
                outlinecolor = "rgb(212,212,212)",
                ticks = "outside",
                ticklen = 3,
                showticksuffix = "last",
                dtick = 0.5
            ),
        ),
        type = 'scattergeo'
    ) ]
    
    slider_step = {
        'args': [
            [date_item[0]],
        {'frame': {'duration': 300, 'redraw': False},
         'mode': 'immediate',
       'transition': {'duration': 300}}
                        ],
                        'label': date_item[0],
                        'method': 'animate'
                    }
    
    if (np.sum(H) > 0):
        original_dfs.append(df_reduced)
    
    if (np.sum(H) > 0):
        data_pts.append({'data' : data_pt, 'name' : date_item[0] })
        slider_steps.append(slider_step)

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:49<00:00,  9.94s/it]


# total histogram

In [182]:
data = [ dict(
    lat = df_reduced['lat'],
    lon = df_reduced['lon'],
    text = df_reduced['density'].astype(str),
    marker = dict(
        color = df_reduced['density'],
        colorscale= colscl,
        reversescale = False,
        opacity = 0.7,
        size = 10,
        colorbar = dict(
            thickness = 10,
            titleside = "right",
            outlinecolor = "rgb(212,212,212)",
            ticks = "outside",
            ticklen = 3,
            showticksuffix = "last",
            dtick = 0.1
        ),
    ),
    type = 'scattergeo'
) ]

figure = {'data': data_pts[0]['data'],
         'layout': dict(
                            geo = dict(
                            scope = 'world',
                            showland = True,
                            landcolor = "rgb(212, 212, 212)",
                            subunitcolor = "rgb(255, 255, 255)",
                            countrycolor = "rgb(255, 255, 255)",
                            showlakes = True,
                            lakecolor = "rgb(255, 255, 255)",
                            showsubunits = True,
                            showcountries = True,
                            resolution = grid_res,
                            lonaxis = dict(
                                showgrid = True,
                                gridwidth = 0.5,
                                range= lon_spacing,
                                dtick = (lon_spacing[1]-lon_spacing[0])/grid_res*10
                            ),
                            lataxis = dict (
                                showgrid = True,
                                gridwidth = 0.5,
                                range= lat_spacing,
                                dtick = (lat_spacing[1]-lat_spacing[0])/grid_res*10
                            )
                            ),
        updatemenus = [{'type': 'buttons',
                                  'buttons': [{'label': 'Play',
                                               'method': 'animate',
                                               'args': [None, {'frame': {'duration': 500, 'redraw': False},
                                                 'fromcurrent': True, 'transition': {'duration': 300, 'easing': 'quadratic-in-out'}}]},
                                              {
                                                'args': [[None], {'frame': {'duration': 0, 'redraw': False}, 'mode': 'immediate',
                                                'transition': {'duration': 0}}],
                                                'label': 'Pause',
                                                'method': 'animate'
                                            }
                                             ],
                        'direction': 'left',
                        'pad': {'r': 10, 't': 87},
                        'showactive': False,
                        'type': 'buttons',
                        'x': 0.1,
                        'xanchor': 'right',
                        'y': 0,
                        'yanchor': 'top'
                       }],
        sliders = [{
                'active': 0,
                'yanchor': 'top',
                'xanchor': 'left',
                'currentvalue': {
                    'font': {'size': 20},
                    'prefix': 'density as at: ',
                    'visible': True,
                    'xanchor': 'right'
                },
                'transition': {'duration': 300, 'easing': 'cubic-in-out'},
                'pad': {'b': 10, 't': 50},
                'len': 0.9,
                'x': 0.1,
                'y': 0,
                'steps': slider_steps
            }],
        
        title = 'East Australia antibiotic scatter map'
        ),
          
        'frames' : data_pts
         }


iplot(figure)

In [118]:
for i in range(len(slider_steps)):
    slider_steps[i]['args'] = [None]

In [119]:
slider_steps[0]

{'args': [None], 'label': 'label-for-frame', 'method': 'animate'}

In [125]:
len(data_pts)

17

In [98]:
pickle.dump( data_pts, open( '..\\..\\Data\\Datathon_2017\\density_grid.pkl', 'wb' ) )

'2017-06-27'