In [1]:
import pandas as pd
import numpy as np
import joblib as jl
from tqdm import tqdm
import multiprocessing as mp

In [None]:
# Load the cached dispensing table. The path is handed to joblib so that it
# opens *and closes* the file itself (a bare open(...) leaks the handle).
# NOTE(review): joblib files are pickles -- only ever load trusted caches.
dd = jl.load('../tmp/auxfile_dd_.pkl')
dd.head()

In [None]:
# Item codes (ITM_CD) of metformin-only products and of metformin combination
# ("metformin + x") products, read in that order from the two item lists.
metonly, metx = (
    set(pd.read_csv(item_file, header=0)['ITM_CD'].tolist())
    for item_file in ('../mbspbs10pc/data/metformin_items.csv',
                      '../mbspbs10pc/data/metformin+x_items.csv')
)

In [None]:
# METONLY: patients whose dispensed items are all metformin-only items,
# together with the date of their first supply.
grouped = dd.groupby(by='PTNT_ID')
filtered = grouped.filter(
    lambda rows: set(rows['ITM_CD']).issubset(metonly)
).groupby(by='PTNT_ID')
idx, start_date = [], []
for patient_id, rows in filtered:
    idx.append(patient_id)
    first_supply = rows['SPPLY_DT'].min()
    start_date.append(first_supply.strftime('%Y-%m-%d'))
print(len(idx))

In [None]:
# MET+X: patients who start on metformin alone and later add another drug
# (or a metformin combination product).
min_metformin = 10  # number of initial supplies that must all be metformin-only


def condition(group):
    """Return True if one patient's supplies match the MET+X pattern.

    The patient needs more than `min_metformin` supplies. The first
    `min_metformin` of them (by SPPLY_DT) must all be metformin-only items, and
    the remaining ones must either mix metformin-only with other items or
    contain at least one metformin+x item.
    """
    if len(group) <= min_metformin:
        return False
    # Stable sort: supplies dispensed on the same day keep their original row
    # order, so the head/tail split does not depend on the sort algorithm.
    sorted_group = group.sort_values(by='SPPLY_DT', kind='mergesort')
    head = set(sorted_group['ITM_CD'].iloc[:min_metformin])  # these must all be metformin
    tail = set(sorted_group['ITM_CD'].iloc[min_metformin:])  # there should be both in here

    n_metonly_in_tail = len(tail & metonly)
    cond1 = head.issubset(metonly)
    cond2 = 0 < n_metonly_in_tail < len(tail)
    cond3 = len(tail & metx) > 0  # handling special case of met+x items introduced after 2014

    return cond1 and (cond2 or cond3)


grouped = dd.groupby(by='PTNT_ID')
filtered = grouped.filter(condition).groupby(by='PTNT_ID')

# Init return items
idx, start_date, end_date = list(), list(), list()

# Build output variables: start = first supply, end = first non-metformin-only supply
for name, group in tqdm(filtered, desc='Finalizing', leave=False):
    idx.append(name)
    # min() is the first date in SPPLY_DT order (replaces the old head(1).max())
    start_date.append(group['SPPLY_DT'].min().strftime('%Y-%m-%d'))
    non_metformin_dates = group.loc[~group['ITM_CD'].isin(metonly), 'SPPLY_DT']
    end_date.append(non_metformin_dates.min().strftime('%Y-%m-%d'))

In [None]:
# MET2X: patients who start on metformin alone and then switch completely to
# drugs that contain no metformin at all.
def condition(group):
    """Return True if one patient's supplies match the MET2X pattern.

    The patient needs more than `min_metformin` supplies (`min_metformin` is
    defined in the MET+X cell). The first `min_metformin` of them (by SPPLY_DT)
    must all be metformin-only items, and none of the remaining ones may be a
    metformin-only or a metformin+x item.
    """
    if len(group) <= min_metformin:
        return False
    # Stable sort, as in the MET+X cell, so same-day supplies split the same way.
    sorted_group = group.sort_values(by='SPPLY_DT', kind='mergesort')
    head = set(sorted_group['ITM_CD'].iloc[:min_metformin])  # these must all be metformin
    tail = set(sorted_group['ITM_CD'].iloc[min_metformin:])  # there should be NO metformin in here

    return head.issubset(metonly) and tail.isdisjoint(metonly) and tail.isdisjoint(metx)


grouped = dd.groupby(by='PTNT_ID')
filtered = grouped.filter(condition).groupby(by='PTNT_ID')

In [None]:
# Collect, per patient: the id, the first supply date and the first
# non-metformin-only supply date. min() gives the same dates as sorting and
# taking the first row, so no sort is needed here.
idx, start_date, end_date = [], [], []

for name, group in tqdm(filtered, desc='Finalizing', leave=False):
    is_metformin = group['ITM_CD'].isin(metonly)
    idx.append(name)
    start_date.append(group['SPPLY_DT'].min().strftime('%Y-%m-%d'))
    end_date.append(group.loc[~is_metformin, 'SPPLY_DT'].min().strftime('%Y-%m-%d'))

In [None]:
labels = pd.read_csv('../tmp/labels.csv', header=0, index_col=0)
# Number of label rows with at least one missing field
int(labels.isna().any(axis=1).sum())

In [None]:
ids = dd['PTNT_ID']
idx012 = labels.dropna().index  # patients whose label row is complete
idx_other = set(ids).difference(idx012)  # every other patient found in dd

In [None]:
# Restrict the dispensing table to the idx_other patients (vectorised mask
# instead of a Python call per group), then group it per patient.
grouped = dd.groupby(by='PTNT_ID')
filtered = dd.loc[dd['PTNT_ID'].isin(idx_other)].groupby(by='PTNT_ID')

In [None]:
# For each idx_other patient: id, first and last supply date. The id list is
# rebuilt here so it stays aligned with the dates (previously `idx` was left
# over from the MET2X cell and did not match these lists).
idx, start_date, end_date = list(), list(), list()

# Build output variables
for name, group in tqdm(filtered, desc='Finalizing', leave=False):
    idx.append(name)
    start_date.append(group['SPPLY_DT'].min().strftime('%Y-%m-%d'))
    end_date.append(group['SPPLY_DT'].max().strftime('%Y-%m-%d'))

In [None]:
pd.read_csv('../mbspbs10pc/data/pregnancy_items.csv', usecols=['ITEM'])['ITEM'].pipe(set)

In [None]:
np.array_split(labels['START_DATE'].to_numpy(), 3)  # three roughly equal chunks

# raw-data-extr

In [2]:
import os
import datetime

# --- configuration ---------------------------------------------------------
source = '../tmp/labels.csv'  # labelled cohort, indexed by PTNT_ID
home = ['../mbspbs10pc']
exclude_pregnancy = True
data_dir = os.path.join('..', '..', '..', 'data')
sample_pin_lookout = os.path.join(data_dir, 'SAMPLE_PIN_LOOKUP.csv')
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort
# them so the MBS files -- and therefore the row order of mbs_df, which decides
# how same-day exams are ordered in the sequences -- are reproducible.
mbs_files = [os.path.join(data_dir, x) for x in sorted(os.listdir(data_dir)) if 'MBS' in x]

In [3]:
raw_data = {}

# Step 0: load the labelled cohort (source) and the BTOS-4D item lookup
dfs = pd.read_csv(source, header=0, index_col=0)
dfs['PTNT_ID'] = dfs.index  # FIXME: this is LEGACY CODE
btos4d = pd.read_csv(
    os.path.join(home[0], 'data', 'btos4d.csv'),
    header=0,
    usecols=['ITEM', 'BTOS-4D'],
)

# Check whether or not pregnant subjects should be excluded
if exclude_pregnancy:
    pregnancy_items = set(
        pd.read_csv(os.path.join(home[0], 'data', 'pregnancy_items.csv'),
                    header=0, usecols=['ITEM'])['ITEM']
    )

# Step 1: attach sex, year of birth and age to each patient
df_pin_lookout = pd.read_csv(sample_pin_lookout, header=0)
# NOTE: AGE is relative to the year this cell is run, not to the treatment dates
df_pin_lookout['AGE'] = datetime.datetime.now().year - df_pin_lookout['YOB']
# START_DATE / END_DATE come from the labels file; PIN becomes the index so
# per-patient rows can be fetched with dfs.loc[pin] below.
dfs = (
    pd.merge(dfs, df_pin_lookout, how='left', left_on='PTNT_ID', right_on='PIN')
    [['PIN', 'SEX', 'AGE', 'START_DATE', 'END_DATE', 'YOB']]
    .set_index('PIN')
)

In [4]:
# NOTE(review): this shows one individual's sex/age/YOB -- clear before sharing.
dfs.loc[1866585844, :]

SEX                    M
AGE                   75
START_DATE    2008-01-19
END_DATE      2014-09-11
YOB                 1943
Name: 1866585844, dtype: object

In [5]:
# Step 2: follow each patient in the mbs files
# at first create a very large dictionary with all the MBS files
# (keeping only the relevant columns)
# It is possible here to exclude pregnant subjects
# Step 2: follow each patient in the MBS files.
# Load every MBS file (keeping only the relevant columns), optionally drop the
# pregnancy-related items, keep only the patients in `dfs`, attach BTOS-4D.
mbs_columns = ['PIN', 'ITEM', 'DOS', 'PINSTATE']
mbs_chunks = []
for mbs in tqdm(mbs_files, desc='MBS files loading', leave=False):
    dd = pd.read_csv(mbs, header=0, usecols=mbs_columns, engine='c')
    if exclude_pregnancy:
        dd = dd.loc[~dd['ITEM'].isin(pregnancy_items), :]
    dd = dd.loc[dd['PIN'].isin(dfs.index), :]  # keep only the relevant samples
    dd = pd.merge(dd, btos4d, how='left', on='ITEM')  # get the BTOS-4D
    mbs_chunks.append(dd)

# Concatenate once: growing a frame inside the loop is quadratic, and seeding
# it with an empty frame relies on deprecated empty-entry dtype handling.
if mbs_chunks:
    mbs_df = pd.concat(mbs_chunks, ignore_index=True)
else:  # no MBS file found: keep an empty frame with the expected columns
    mbs_df = pd.DataFrame(columns=mbs_columns + ['BTOS-4D'])

# Assign the whole column: on pandas >= 2, `.loc[:, 'DOS'] = ...` sets values in
# place and leaves an object column, which breaks the later date comparisons.
mbs_df['DOS'] = pd.to_datetime(mbs_df['DOS'], format='%d%b%Y')

                                                                

In [46]:
mbs_df.iloc[:3]  # first three MBS records

Unnamed: 0,BTOS-4D,DOS,ITEM,PIN,PINSTATE
0,G,2010-09-08,23,1866585844,2
1,G,2010-10-31,52,5065251748,4
2,T,2010-07-14,10990,7806690482,1


In [47]:
# Keep only the patients with more than one MBS record, grouped by PIN
grouped = mbs_df.loc[mbs_df['PIN'].duplicated(keep=False)].groupby('PIN')

In [48]:
group = grouped.get_group(1866585844)  # a single patient's MBS records
group.iloc[:3]

Unnamed: 0,BTOS-4D,DOS,ITEM,PIN,PINSTATE
0,G,2010-09-08,23,1866585844,2
11,G,2010-10-03,23,1866585844,2
22,G,2010-04-28,23,1866585844,2


In [49]:
MIN_SEQ_LENGTH = 10


def extract_sequence(group, dfs=None):
    """Encode one patient's MBS history as an exam / idle-time string.

    The patient's exams are restricted to [START_DATE, END_DATE] (inclusive)
    and sorted by DOS. They are then written as
    'exam', gap, 'exam', gap, ..., 'exam', where each exam is its BTOS-4D code
    and each gap is the `timespan_encoding` of the days between consecutive
    exams (e.g. 'G0G1H...').

    Parameters
    ----------
    group : pandas.DataFrame
        MBS rows of a single patient with columns 'PIN', 'DOS', 'BTOS-4D'
        and 'PINSTATE'.
    dfs : pandas.DataFrame, optional
        Table indexed by PIN with 'START_DATE' and 'END_DATE' columns. When
        omitted, the notebook-level ``dfs`` is used, looked up at call time.
        Passing it explicitly allows e.g. ``partial(extract_sequence, dfs=dfs)``.

    Returns
    -------
    pandas.Series
        Keys 'seq', 'avg_year' (mean calendar year of the selected exams) and
        'last_pinstate'. All values are NaN when no more than MIN_SEQ_LENGTH
        exams fall in the selected timespan.
    """
    if dfs is None:
        dfs = globals()['dfs']  # the parameter shadows the global name
    pin = group['PIN'].values[0]
    tmp = group.sort_values(by='DOS')  # sort by DOS
    start_date = dfs.loc[pin]['START_DATE']
    end_date = dfs.loc[pin]['END_DATE']
    # select sequence timespan
    tmp = tmp.loc[(tmp['DOS'] >= start_date) & (tmp['DOS'] <= end_date), :]

    if tmp.shape[0] <= MIN_SEQ_LENGTH:  # discard trivial sequences
        return pd.Series({'seq': np.nan, 'avg_year': np.nan, 'last_pinstate': np.nan})

    exams = tmp['BTOS-4D'].values.ravel()
    dos = tmp['DOS'].values
    # first-order difference of the dates, in whole DAYS, then encoded
    gaps = [timespan_encoding(pd.Timedelta(delta).days) for delta in dos[1:] - dos[:-1]]
    # interleave exams and gaps; zip stops one short, so append the last exam
    tokens = [token for pair in zip(exams, gaps) for token in pair]
    tokens.append(exams[-1])
    seq = ''.join(map(str, tokens))
    # average year during the treatment (proxy for the average age)
    avg_year = np.mean(pd.DatetimeIndex(dos.ravel()).year)
    last_pinstate = tmp['PINSTATE'].values.ravel()[-1]

    return pd.Series({'seq': seq, 'avg_year': avg_year, 'last_pinstate': last_pinstate})

In [50]:
def timespan_encoding(days, thresholds=(14, 30, 90, 360)):
    """Convert the input days in the desired timespan encoding.

    With the default thresholds this function follows this encoding:
    --------------------------------
    Time duration        | Encoding
    --------------------------------
    [same day - 2 weeks] | 0
    (2 weeks  - 1 month] | 1
    (1 month  - 3 months]| 2
    (3 months - 1 year]  | 3
    more than 1 year     | 4
    --------------------------------
    Months are counted as 30 days and the year as 360 days ("economic"
    durations), i.e. the default bin edges are 14, 30, 90 and 360 days.

    Parameters:
    --------------
    days: int or float
        The number of days between any two examinations.
    thresholds: sequence of numbers, optional
        Ascending, inclusive upper bounds of the bins; a value above the
        last bound gets code ``len(thresholds)``.

    Returns:
    --------------
    enc: string
        The corresponding encoding (the bin index as a string).

    Raises:
    --------------
    ValueError
        If `days` is negative or NaN/NaT (NaN used to be silently encoded
        as "more than 1 year").
    """
    if pd.isna(days):
        raise ValueError('Unsupported undefined (NaN) timespans')
    if days < 0:
        raise ValueError('Unsupported negative timespans')
    for code, upper in enumerate(thresholds):
        if days <= upper:
            return str(code)
    return str(len(thresholds))

def flatten(x):
    """Recursively flatten nested lists / numpy arrays into one flat list.

    Only objects whose exact type is ``list`` or ``numpy.ndarray`` are
    expanded; anything else (tuples, strings, subclasses, scalars) is a leaf
    and is returned wrapped in a one-element list.
    """
    if type(x) not in (list, np.ndarray):
        return [x]
    flat = []
    for item in x:
        flat.extend(flatten(item))
    return flat

In [51]:
# from joblib import Parallel, delayed
# from functools import partial

# def applyParallel(grouped, func):
#     return pd.concat(Parallel(n_jobs=3)(delayed(func)(group) for name, group in grouped))

# out = applyParallel(grouped, partial(extract_sequence, dfs=dfs))

In [52]:
# One encoded sequence per patient; patients with trivial sequences (NaN) are dropped
out = (
    grouped
    .apply(extract_sequence)
    .dropna()
)
out.to_csv('../tmp/raw_sequences.csv')

In [53]:
out.iloc[:5]  # first five encoded sequences

Unnamed: 0_level_0,avg_year,last_pinstate,seq
PIN,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
26225,2011.575758,2.0,G1T0G0P0H0P0P0T0G1E3T0G1T0G1H0P0P0P0P0P1E0G0T2...
568483,2008.687117,1.0,T0G1G0T0T0T0G0G1L0L0L0G0G0T0T1L0L0L0L0L0G0T0L0...
911858,2012.685535,5.0,G0T0O1T0G1O2T0G0P0P0P0P0P0P0E0T0G0I0I0T0G2T0G0...
923748,2011.246649,2.0,P0P0P0L0E0I2T0G1L2T0T0G0S0H0P0P0P0T0G0S0T1L2D1...
950965,2012.078014,2.0,T0G1G0T3G0T1P0H0P0P0P0T0G1T0G1T0G1G0T2G0T1T0G0...


In [31]:
# small_idx = set([1866585844, 5065251748, 7806690482])
# small_mbs_df = mbs_df.loc[mbs_df['PIN'].isin(small_idx), :]
# small_mbs_df.shape

(1041, 5)

In [34]:
# small_grouped = small_mbs_df.groupby('PIN').filter(lambda x: len(x)>1).groupby('PIN')

In [45]:
# small_out = small_grouped.apply(extract_sequence).dropna()
# small_out

Unnamed: 0_level_0,avg_year,last_pinstate,seq
PIN,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1866585844,2011.7109,2,P0G0P0H0P1P0P0E3D2P0G0P0H0P3P0H0P0P0G2G2G2P0P0...
5065251748,2012.861538,4,T0G1G0T0T0G0L0T0G2G0T2P0P0P0P0P0P0H0P0P0P0E0G0...
7806690482,2011.618497,1,P0P0P0P0G0T1E0E1T0I0E1G0T0P0P0P0P0T0G0E3P0T0G0...
