In [1]:

from quasinet.qnet import load_qnet, save_qnet
from quasinet.qnet import qdistance
from quasinet.qsampling import qsample
from quasinet.qnet import membership_degree
import pandas as pd
import numpy as np
from tqdm import tqdm
from quasinet.qnet import Qnet
import dill as pickle  # Dill functions similarly to pickle
import gzip
import shap
from scipy.stats import t, lognorm
from truthfinder import reveal

# train the veritas model

#datapath=''
#index_present=False
#training_fraction=0.3
#target_label='PTSDDx'
#query_limit=None


class truthnet:
    def __init__(self, datapath,
                 target_label,
                 problem='',
                 index_present=True,
                 target_label_positive=1,
                 target_label_negative=0,
                 training_fraction=0.3,
                 threshold_alpha=0.1,
                 query_limit=None,
                 shap_index=None):
        self.datapath = datapath
        self.target_label = target_label
        self.index_present = index_present
        self.target_label_positive = target_label_positive
        self.target_label_negative = target_label_negative
        self.training_fraction = training_fraction
        self.query_limit = query_limit
        if query_limit is None:
            self.query_limit = -1
        self.shap_index = shap_index
        self.data = None  # Placeholder for the data once loaded
        self.veritas_model={}
        self.data = None
        self.problem = problem
        self.training_index = None
        self.threshold_alpha = threshold_alpha
        
    def fit(self,
            alpha=0.1,
            shap_index=None,
            shapnum=10,
            nullsteps=100000,
            veritas_version='0.0.1'):

        if self.index_present:
            self.data = pd.read_csv(self.datapath,index_col=0)
        else:
            self.data = pd.read_csv(self.datapath)


        num_training=np.rint(self.training_fraction*self.data.index.size).astype(int)
        training_index=np.random.choice(self.data.index.values,num_training, replace=False)

        self.training_index=training_index
        df_training=self.data.loc[training_index,:]
        df_test = self.data.loc[[x for x in self.data.index.values
                                 if x not in training_index],:]
        
        df_training_pos=df_training[df_training[self.target_label]==self.target_label_positive]
        df_training_neg=df_training[df_training[self.target_label]==self.target_label_negative]
        Xpos_training=df_training_pos.drop(self.target_label,
                                           axis=1)\
                                     .values.astype(str)
        Xneg_training=df_training_neg.drop(self.target_label,
                                           axis=1)\
                                     .values.astype(str)

        featurenames = df_training_pos.drop(self.target_label,
                                           axis=1).columns

        modelneg=Qnet(feature_names=featurenames,alpha=alpha)
        modelneg.fit(Xneg_training)
        modelpos=Qnet(feature_names=featurenames,alpha=alpha)
        modelpos.fit(Xpos_training)
        modelneg.training_index=training_index
        modelpos.training_index=training_index

        
        def funcw_(S):
            return np.array([membership_degree(s,modelneg)
                             /membership_degree(s,modelpos) for s in S])

        X=df_test.drop(self.target_label,
                                      axis=1).values.astype(str)
        
        NULLSTR=np.array(['']*len(modelneg.feature_names))
        s_background=qsample(NULLSTR,modelneg,steps=nullsteps)
        explainer = shap.KernelExplainer(funcw_,np.array([s_background]))
        shap_values = explainer.shap_values(X[:shapnum])

        self.shap_index=pd.DataFrame(shap_values.mean(axis=0),
                                     columns=['shap'])\
                          .sort_values('shap',
                                       ascending=False).index.values

        modelneg.shap_index=self.shap_index
        modelpos.shap_index=self.shap_index

        # save veritas model
        self.veritas_model['version']=veritas_version
        self.veritas_model['model']=modelpos
        self.veritas_model['model_neg']=modelneg
        self.veritas_model['problem']=self.problem
        
        return

    
    def calibrate(self,qsteps=1000,calibration_num=10000):

        featurenames=self.veritas_model['model'].feature_names
        NULLSTR=np.array(['']*len(featurenames))
        adict=[]
        for i in tqdm(range(calibration_num)):
            sq=qsample(NULLSTR,self.veritas_model['model'],steps=qsteps)
            ff=pd.DataFrame(sq.reshape(1, -1),
                            columns=featurenames)[featurenames[self.shap_index[:self.query_limit]]]
            adict=np.append(adict,{'xx'+str(i):ff.iloc[0].to_dict()})
        
        return adict


In [2]:
TR=truthnet(datapath='/home/ishanu/ZED/Research/truthnet/notebooks/data/ptsd/PTSD_cognet_test.csv',target_label='PTSDDx',query_limit=20)

In [3]:
TR.fit(shapnum=3)

  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
rf=TR.calibrate(calibration_num=10)

100%|███████████████████████████████████████████| 10/10 [00:01<00:00,  5.74it/s]


In [5]:
from truthfinder import *

In [6]:
json_data=list(rf)
json_data

[{'xx0': {'ptsd89': '4',
   'ptsd11': '4',
   'ptsd142': '5',
   'ptsd108': '4',
   'ptsd95': '3',
   'ptsd146': '5',
   'ptsd72': '5',
   'ptsd40': '4',
   'ptsd34': '4',
   'ptsd180': '4',
   'ptsd169': '4',
   'ptsd194': '3',
   'ptsd106': '3',
   'ptsd140': '3',
   'ptsd191': '3',
   'ptsd16': '5',
   'ptsd84': '4',
   'ptsd122': '2',
   'ptsd193': '4',
   'ptsd139': '4'}},
 {'xx1': {'ptsd89': '2',
   'ptsd11': '3',
   'ptsd142': '4',
   'ptsd108': '3',
   'ptsd95': '4',
   'ptsd146': '3',
   'ptsd72': '5',
   'ptsd40': '4',
   'ptsd34': '4',
   'ptsd180': '3',
   'ptsd169': '3',
   'ptsd194': '5',
   'ptsd106': '4',
   'ptsd140': '5',
   'ptsd191': '2',
   'ptsd16': '1',
   'ptsd84': '4',
   'ptsd122': '2',
   'ptsd193': '5',
   'ptsd139': '4'}},
 {'xx2': {'ptsd89': '4',
   'ptsd11': '3',
   'ptsd142': '1',
   'ptsd108': '2',
   'ptsd95': '4',
   'ptsd146': '3',
   'ptsd72': '2',
   'ptsd40': '4',
   'ptsd34': '5',
   'ptsd180': '3',
   'ptsd169': '4',
   'ptsd194': '4',
   'ptsd1

In [12]:

def revealx(resp_json,
           veritas_model,
           perturb=3):
    list_response_dict = extract_ptsd_items(resp_json)
    
    message=[]
    for i in list_response_dict:
        subjectid=i['subject_id']
        resp = i['responses']
        s=pd.concat([pd.DataFrame(columns=veritas_model['model'].feature_names),
                   pd.DataFrame(resp,index=['response'])])\
                        .fillna('').values[0].astype(str)

        if perturb > 0:
            s=qsample(s,veritas_model['model'],steps=perturb)
            
        i['veritas'] = dissonance_distr_median(s,veritas_model['model'])
        #i['veritas_prob'] = veritas_model['t_distribution'].cdf(i['veritas'])
        i['score']=funcw(s,
                         veritas_model['model'],
                         veritas_model['model_neg'])
        i['lower_threshold']=funcm(s,
                         veritas_model['model'],
                         veritas_model['model_neg'])
        message = message + [interpret(i)]
    return list_response_dict, message

In [8]:
make_str_format(rf)

In [9]:
json_data=list(rf)
isinstance(json_data, list)

True

In [13]:
revealx(json_data,TR.veritas_model)

([{'subject_id': 'xx0',
   'responses': {'ptsd89': '4',
    'ptsd11': '4',
    'ptsd142': '5',
    'ptsd108': '4',
    'ptsd95': '3',
    'ptsd146': '5',
    'ptsd72': '5',
    'ptsd40': '4',
    'ptsd34': '4',
    'ptsd180': '4',
    'ptsd169': '4',
    'ptsd194': '3',
    'ptsd106': '3',
    'ptsd140': '3',
    'ptsd191': '3',
    'ptsd16': '5',
    'ptsd84': '4',
    'ptsd122': '2',
    'ptsd193': '4',
    'ptsd139': '4'},
   'veritas': 0.6799999999999999,
   'score': 1.7887511258067579,
   'lower_threshold': 0.9802158064993378},
  {'subject_id': 'xx1',
   'responses': {'ptsd89': '2',
    'ptsd11': '3',
    'ptsd142': '4',
    'ptsd108': '3',
    'ptsd95': '4',
    'ptsd146': '3',
    'ptsd72': '5',
    'ptsd40': '4',
    'ptsd34': '4',
    'ptsd180': '3',
    'ptsd169': '3',
    'ptsd194': '5',
    'ptsd106': '4',
    'ptsd140': '5',
    'ptsd191': '2',
    'ptsd16': '1',
    'ptsd84': '4',
    'ptsd122': '2',
    'ptsd193': '5',
    'ptsd139': '4'},
   'veritas': 0.681818181818181