
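# PTSD Model Inference with IRT Features

This notebook loads a set of trained forest models ("forms"), picks one at random, and simulates adaptive questionnaire sessions. Each tree is walked from its root, and an item is asked only when some tree needs it to choose a branch. The collected answers are then scored with `predict_proba`. Answers are drawn at random here, so the printed probabilities illustrate the pipeline rather than real respondents.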

## [Center for Health Statistics](http://www.healthstats.org)

## [The Zero Knowledge Discovery Lab](http://zed.uchicago.edu)
---



In [1]:
import ccx as cx
import pylab as plt
plt.style.use('ggplot')
import pickle
import pandas as pd
import numpy as np
%matplotlib inline

In [2]:
# datafile = '../data/CAD-PTSDData.csv'   # not read anywhere below

# Length of the dense feature vector that runCAD builds and passes to
# model.predict_proba; presumably must equal the model's training width — confirm.
nfeatures = 211

# getNodeid captures DEBUG as a default argument when its cell runs, so changing
# this value later has no effect until that cell is re-executed.
DEBUG = False

In [3]:
# NOTE(review): a .pkl file is presumably unpickled here, which can execute
# arbitrary code — load only model files from a trusted source.
modelset = cx.load('../model/model_3_2.pkl')

In [4]:
def chooseForm(modelset,index=None):
    """Return one model ("form") from ``modelset``.

    Parameters
    ----------
    modelset : sequence
        Candidate models, indexable by position. Must be non-empty when
        ``index`` is None.
    index : int, optional
        Position of the model to return. When omitted, a position is drawn
        uniformly from ``0 .. len(modelset)-1`` (the last one included).

    Returns
    -------
    The selected element of ``modelset``.

    Raises
    ------
    ValueError
        If ``index`` is None and ``modelset`` is empty.
    """
    if index is None:
        if len(modelset) == 0:
            raise ValueError('modelset is empty; cannot choose a form')
        # numpy's randint excludes ``high``, so this covers every valid position.
        # The old ``randint(0, len-1)`` form skipped the last model under numpy
        # semantics and raised for a single-element modelset.
        index=np.random.randint(len(modelset))
    return modelset[index]

In [5]:
# Pick one form at random from the loaded set.
model = chooseForm(modelset)

In [6]:
# Render the trees of the chosen model; PREF is presumably an output file
# prefix under ../model/ — confirm against ccx.drawTrees.
cx.drawTrees(model, PREF='../model/this')

In [7]:
def ask(item,B=0,E=5):
    """Simulate a respondent answering item ``item`` with a random score.

    The score comes from ``cx.random.randint(B, E)``. It is returned as the
    single-entry dict ``{item: score}`` so callers can merge it with
    ``dict.update``.

    NOTE(review): whether ``E`` itself can be returned depends on what
    ``cx.random`` is. Python's ``random`` includes it; ``numpy.random``
    excludes it. Confirm.
    """
    return {item: cx.random.randint(B, E)}

In [8]:
def getNodeid(model,nodeid=None,responses=None,ask=ask,DEBUG=DEBUG):
    """Advance every tree of a fitted forest by one level, asking for answers on demand.

    Parameters
    ----------
    model : fitted forest
        Must expose ``n_estimators`` and ``estimators_[i].tree_``, with
        scikit-learn's ``feature``, ``threshold``, ``children_left`` and
        ``children_right`` arrays.
    nodeid : array of int, optional
        Current node of each tree. ``-1`` marks a tree that has already
        finished. Defaults to the root (0) of every tree. Updated in place
        and also returned.
    responses : dict, optional
        Maps feature index -> answer collected so far. Updated in place and
        also returned. Defaults to a new empty dict.
    ask : callable
        ``ask(feature) -> {feature: value}``. It is called only when a tree
        needs a feature that is not yet in ``responses``. The default is bound
        to the module-level ``ask`` when this function is defined.
    DEBUG : bool
        If true, print the node ids after the step. The default is bound when
        this function is defined.

    Returns
    -------
    (nodeid, responses)
    """
    TREE_LEAF = -1  # scikit-learn's child id for "no child" (sklearn.tree._tree.TREE_LEAF)

    if responses is None:
        responses={}
    if nodeid is None:
        nodeid=np.zeros(model.n_estimators).astype(int)

    for i in range(model.n_estimators):
        node = nodeid[i]
        if node < 0:
            continue  # this tree already finished
        tree = model.estimators_[i].tree_
        if tree.children_left[node] == tree.children_right[node]:
            # Leaf node: there is nothing to ask (its feature is the -2
            # TREE_UNDEFINED sentinel), so just mark the tree as finished.
            nodeid[i] = TREE_LEAF
            continue
        item = tree.feature[node]
        if item not in responses:
            responses.update(ask(item))
        # scikit-learn sends a sample left when X[feature] <= threshold.
        if responses[item] <= tree.threshold[node]:
            nodeid[i] = tree.children_left[node]
        else:
            nodeid[i] = tree.children_right[node]

    # Defensive: -2 is TREE_UNDEFINED, never a real feature index.
    responses.pop(-2,None)
    if DEBUG:
        print(nodeid)
    return nodeid,responses

In [9]:
def runCAD(model,nfeatures):
    """Run one complete simulated session and score it.

    Calls ``getNodeid`` (with its default ``ask``) repeatedly until every
    tree in ``model`` is finished, i.e. all node ids are negative. The
    collected answers are then packed into a dense vector of length
    ``nfeatures``, with features that were never asked left at 0. That
    vector is scored with ``model.predict_proba``.

    Returns
    -------
    (responses, prd)
        ``responses`` maps feature index -> answer. ``prd`` is the
        ``predict_proba`` result for the single row.
    """
    answers = {}
    nodes = np.zeros(model.n_estimators).astype(int)
    while not (nodes < 0).all():
        nodes, answers = getNodeid(model, nodeid=nodes, responses=answers)

    features = np.zeros(nfeatures)
    for feature_index, value in answers.items():
        features[feature_index] = value
    prd = model.predict_proba(features[np.newaxis, :])
    return answers, prd

In [10]:
# Simulate 40 independent sessions with random answers and print each
# predicted class-probability row.
for i in range(40):
    r, p = runCAD(model, nfeatures)
    print(p)

[[0.53 0.47]]
[[0.75 0.25]]
[[0.44481605 0.55518395]]
[[0.875 0.125]]
[[0.46474359 0.53525641]]
[[0.44481605 0.55518395]]
[[0.75 0.25]]
[[0.78 0.22]]
[[0.94270833 0.05729167]]
[[0.94270833 0.05729167]]
[[1. 0.]]
[[0.30361757 0.69638243]]
[[0.57729469 0.42270531]]
[[0.5625 0.4375]]
[[0.83333333 0.16666667]]
[[0.546875 0.453125]]
[[0.83333333 0.16666667]]
[[0.75 0.25]]
[[0.58139535 0.41860465]]
[[0.57411859 0.42588141]]
[[0.8134058 0.1865942]]
[[0.625 0.375]]
[[0.70833333 0.29166667]]
[[0.546875 0.453125]]
[[0.30361757 0.69638243]]
[[0.5625 0.4375]]
[[0.46474359 0.53525641]]
[[1. 0.]]
[[1. 0.]]
[[0.46474359 0.53525641]]
[[0.83333333 0.16666667]]
[[0.60507246 0.39492754]]
[[0.41757246 0.58242754]]
[[0.75 0.25]]
[[0.50833333 0.49166667]]
[[0.655 0.345]]
[[0.46474359 0.53525641]]
[[0.984375 0.015625]]
[[0.14389535 0.85610465]]
[[0.95833333 0.04166667]]


In [None]:
# Display the configured feature-vector length (this cell shows "In [None]",
# i.e. it was not executed in the saved run).
nfeatures