## Project: Estimation of accuracy of MOI for MVCs 
Created by: Thomas Hartka, MD, MSDS  
Date created: 12/2/21 
  
This notebook performs imputation on missing data and creates five new data sets.

In [1]:
import numpy as np
import pandas as pd
from itertools import combinations
from sklearn.experimental import enable_iterative_imputer  
from sklearn.impute import IterativeImputer
import math

## Parameters

In [2]:
# number of imputated data sets
num_imp = 5 

In [3]:
unfiltered_data_file = "../Data/NASS_CISS-2000_2019-unfiltered.csv"

In [4]:
predictors = ['age','prop_restraint','any_restraint','abdeply','dvtotal',
              'splimit','multicoll','pdof_nearside','rolled', 'roll_turns',
              'int18','int12occ','ejection','other_death','entrapment']

responses = ['iss16','mais','mais3']

variables = predictors + responses

## Read in data

In [5]:
mvcs = pd.read_csv(unfiltered_data_file)

## Clean up imputated variables

In [6]:
def clean_vars(df, df_imp):
    # loop through all variables
    for var in list(df_imp.columns):
        # if continuous
        if not all([((i in [0,1])  | (math.isnan(i))) for i in df[var].unique()]):

            # all continuous variables should be greater than zero
            df_imp[var] = df_imp.apply(lambda x: x[var] if x[var] >= 0  else 0, axis=1)
            #print(var, " is continuous")

        # else binary
        else:

            # convert binary variables to 0/1 at 0.5 cut off
            df_imp[var] = df_imp.apply(lambda x: 1 if x[var] >= 0.5  else 0, axis=1)
            #print(var, " is binary")
            
    return df_imp

## Function to impute data

In [7]:
def impute_cols(data, imp_variables, seed):
    
    # set up imputer
    imp_mod = IterativeImputer(max_iter=10, random_state=seed)
    imp_mod.fit(mvcs[imp_variables])
    
    # imputate data
    imputed = imp_mod.transform(data[imp_variables])
    
    # convert to pandas df
    data_imp = pd.DataFrame(imputed, columns=imp_variables)
    
    # clean up imputed variables
    data_imp = clean_vars(data, data_imp)
    
    # find columns not in imputed data
    unimp_cols = list(set(data.columns)-set(data_imp.columns))
        
    # add back unimputed columns
    data_imp[unimp_cols] = data[unimp_cols]
    
    return data_imp

In [8]:
%%time
# get first imputated data set
mvcs_imp = impute_cols(mvcs, variables, 42)

# add subsequent data sets
for i in range(1, num_imp):
    print(i)
    mvcs_imp = mvcs_imp.append(impute_cols(mvcs, variables, i))

1
2
3
4
CPU times: user 27min 15s, sys: 18min 40s, total: 45min 55s
Wall time: 6min 48s


In [9]:
len(mvcs)

150683

In [10]:
len(mvcs_imp)

753415

In [12]:
sum(mvcs[mvcs.year<2009].casewgt)

40872468.21332148

In [10]:
len(mvcs[mvcs.year>=2010])

48764

## Output data

In [11]:
mvcs_imp.columns

Index(['age', 'prop_restraint', 'any_restraint', 'abdeply', 'dvtotal',
       'splimit', 'multicoll', 'pdof_nearside', 'rolled', 'roll_turns',
       'int18', 'int12occ', 'ejection', 'other_death', 'entrapment', 'iss16',
       'mais', 'mais3', 'dataset', 'year', 'casewgt', 'died'],
      dtype='object')

In [12]:
mvcs_imp.head(10)

Unnamed: 0,age,prop_restraint,any_restraint,abdeply,dvtotal,splimit,multicoll,pdof_nearside,rolled,roll_turns,...,ejection,other_death,entrapment,iss16,mais,mais3,dataset,year,casewgt,died
0,55.0,0,0,1,25.0,89.0,1,0,0,0.0,...,1,0,0,1,4.0,1,NASS,2000,106.932133,0
1,22.0,1,1,1,39.0,89.0,1,0,0,0.0,...,0,0,1,0,3.0,0,NASS,2000,106.932133,0
2,25.0,0,0,1,39.0,89.0,1,0,0,0.0,...,0,0,0,0,2.0,0,NASS,2000,106.932133,0
3,37.0,1,1,1,26.003873,72.0,1,0,0,0.0,...,0,0,0,0,2.0,0,NASS,2000,3171.822421,0
4,6.0,0,1,1,25.243459,72.0,1,0,0,0.0,...,0,0,0,0,0.0,0,NASS,2000,3171.822421,0
5,20.0,1,1,0,22.520525,113.0,0,0,0,0.0,...,0,0,0,0,0.0,0,NASS,2000,777.280666,0
6,21.0,1,1,0,22.483876,113.0,0,0,0,0.0,...,0,0,0,0,0.0,0,NASS,2000,777.280666,0
7,19.0,1,1,0,23.868089,113.0,0,0,0,0.0,...,0,0,0,0,1.0,0,NASS,2000,777.280666,0
8,20.0,1,1,0,23.448925,113.0,1,0,0,0.0,...,0,0,0,0,1.0,0,NASS,2000,777.280666,0
9,15.0,1,1,0,16.192142,89.0,0,0,1,2.0,...,0,0,0,0,0.0,0,NASS,2000,1256.154247,0


In [13]:
mvcs_imp.to_csv("../Data/NASS_CISS-2000_2019-imputated.csv", index=False)