In [1]:
from logging import getLogger
from recbole.utils import init_logger, init_seed
from recbole.trainer import Trainer
from custom.light_gcn import LightGCN
from custom.custom_model_gcn import Doctr
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
import pickle
from recbole.utils.case_study import full_sort_scores, full_sort_topk
import torch

from utils import com_sim, max_sim
from genetic import genetic_algorithm, calc_metrics
import numpy as np
from tqdm import tqdm

In [None]:
config = Config(model=Doctr, dataset='trial_zero', config_file_list=["./atomic/doctr.yaml"])
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config)
logger = getLogger()

logger.info(config)

# dataset filtering
dataset = create_dataset(config)
#dataset = pickle.load(open('./dataset.pkl', 'rb'))
logger.info(dataset)
print('-------------')

# dataset splitting
train_data, valid_data, test_data = data_preparation(config, dataset)

In [3]:
model = Doctr(config, train_data.dataset).to(config['device'])
trainer = Trainer(config, model)
trainer.eval_collector.data_collect(train_data)
trainer.eval_collector.model_collect(model)

In [None]:
test_result = trainer.evaluate(test_data, model_file='./models/doctr.pth')

14 May 18:24    INFO  Loading model structure and parameters from ../models/doctr.pth


In [None]:
test_trials = pickle.load(open('./data/zero_shot/test_list.pkl', 'rb'))
trial2idx = pickle.load(open('./data/trial2idx.pkl', 'rb'))
npi2idx = pickle.load(open('./data/npi2idx.pkl', 'rb'))
idx2npi = {v: k for k, v in npi2idx.items()}
trial2npi = pickle.load(open('./data/trial2npi.pkl', 'rb'))
npi_info = pickle.load(open('./data/npi_info_dict.pkl', 'rb'))
trial2category = pickle.load(open('./data/trial2category.pkl', 'rb'))
trial2phase = pickle.load(open('./data/trial_phase.pkl', 'rb'))

zip2fips = pickle.load(open('./data/zip2fips.pkl', 'rb'))
fips2demo = pickle.load(open('./data/fips2demo.pkl', 'rb'))
fips2state = pickle.load(open('./data/fips2state.pkl', 'rb'))

competing_dict = pickle.load(open('./data/zero_shot/competing_dict.pkl', 'rb'))

In [11]:
def get_embd(trial_name):
    trial_idx = [str(trial2idx[trial_name])]
    trial_series = dataset.token2id(dataset.uid_field, trial_idx)
    scores = full_sort_scores(trial_series, model, test_data, device=config['device'])
    topk_scores, topk_indices = full_sort_topk(trial_series, model, test_data, k=100, device=config['device'])
    topk_scores = topk_scores.cpu().numpy()
    topk_indices = topk_indices.cpu().numpy()
    y_rec_embd = model.item_embedding_global(torch.tensor(topk_indices).to(config['device']))
    
    gt_npi = trial2npi[trial_name]
    if len(gt_npi) == 0:
        return None, None
    gt_npi_idx = []
    for npi in gt_npi:
        gt_npi_idx.append(str(npi2idx[npi]))
    gt_npi_idx = dataset.token2id(dataset.iid_field, gt_npi_idx)
    gt_npi_embd = model.item_embedding_global(torch.tensor(gt_npi_idx).to(config['device']))
    
    return y_rec_embd.cpu().numpy(), gt_npi_embd.unsqueeze(0).cpu().numpy(), topk_scores, topk_indices

In [34]:
sim_gt = []
sim_5 = []
sim_10 = []
sim_20 = []
score = []
k_indexs = []
k_scores = []
trial_idx = []
for i in tqdm(range(len(test_trials))):
    res = get_embd(test_trials[i])
    if res[0] is None:
        continue
    y_pred, y_true, k_score, k_index = res
    
    score.append(max_sim(y_true, y_pred).squeeze())
    sim_gt.append(com_sim(y_true, y_pred, -1))
    sim_5.append(com_sim(y_true, y_pred, 5))
    sim_10.append(com_sim(y_true, y_pred, 10))
    sim_20.append(com_sim(y_true, y_pred, 20))
    k_indexs.append(k_index.squeeze())
    k_scores.append(k_score.squeeze())
    trial_idx.append(test_trials[i])

# y_pred, y_true = get_embd(test_trials[3][0], 10)
# com_sim(y_true, y_pred)
print(np.mean(sim_gt))
print(np.mean(sim_5))
print(np.mean(sim_10))
print(np.mean(sim_20))
k_indexs = np.array(k_indexs)
k_scores = np.array(k_scores)
score = np.array(score)

100%|██████████| 426/426 [00:03<00:00, 120.34it/s]


0.6010147
0.60231626
0.6037619
0.5981742


In [35]:
phase_res = {}
for phase in range(1, 5):
    phase_res[phase] = []
    for i in range(len(trial_idx)):
        if trial_idx[i] not in trial2phase or trial2phase[trial_idx[i]] == None:
            continue
        if ('Phase ' + str(phase)) in trial2phase[trial_idx[i]]:
            phase_res[phase].append(sim_gt[i])
    print(np.mean(phase_res[phase]))
print(len(phase_res[1]), len(phase_res[2]), len(phase_res[3]), len(phase_res[4]))

0.48865393
0.6300379
0.74652463
0.53898644
98 146 103 29


In [41]:
category_res = {}
for category in ['Oncology', 'Cardiology', 'Neurology', 'Endocrinology', 'Infectious Disease']:
    category_res[category] = []
    for i in range(len(trial_idx)):
        if trial_idx[i] not in trial2category or trial2category[trial_idx[i]] == None:
            continue
        if category in trial2category[trial_idx[i]]:
            category_res[category].append(sim_gt[i])
    print(np.mean(category_res[category]))
print(len(category_res['Oncology']), len(category_res['Cardiology']), len(category_res['Neurology']), len(category_res['Endocrinology']), len(category_res['Infectious Disease']))

0.424258
0.51669663
0.7597169
0.7697163
0.780276
54 35 37 31 30


In [None]:
# Get all the NPIs in the top predictions on the test set
npi_ids = dataset.id2token(dataset.iid_field, k_indexs)
npi_tokens = []
for each_t in npi_ids:
    cur_tokens = []
    for npi_id in each_t:
        cur_tokens.append(idx2npi[int(npi_id)])
    npi_tokens.append(cur_tokens)


In [None]:
genetic_batch = []
for idx, each_trial in enumerate(npi_tokens):
    # Get top 10 NPIs
    # Construct data struct
    # [score, [gender ratio], [race ratio], [ethnicity ratio], [fips code]]
    cur_batch = []
    normed_k_scores = (k_scores[idx] - np.min(k_scores[idx])) / (np.max(k_scores[idx]) - np.min(k_scores[idx]))
    for idx2, each_npi in enumerate(each_trial):
        cur_zip = npi_info[each_npi]['Zip_Code']
        if '-'  in cur_zip:
            cur_zip = cur_zip.split('-')[0]
        cur_zip = int(cur_zip)
        cur_fips = zip2fips.get(cur_zip)
        if cur_fips is not None and cur_fips in fips2demo:
            cur_score = score[idx, idx2]
            cur_gender = [fips2demo[cur_fips]['male'], fips2demo[cur_fips]['female']]
            cur_race = [fips2demo[cur_fips]['white'], fips2demo[cur_fips]['black'], fips2demo[cur_fips]['indian'], fips2demo[cur_fips]['asian'], fips2demo[cur_fips]['native']]
            cur_ethnicity = [fips2demo[cur_fips]['nonhis'], fips2demo[cur_fips]['his']]
            cur_compete = competing_dict[test_trials[i]][each_npi]
            cur_batch.append([cur_score, cur_gender, cur_race, cur_ethnicity, fips2state[cur_fips], cur_compete, idx2])
        else:
            cur_compete = competing_dict[test_trials[i]][each_npi]
            cur_batch.append([normed_k_scores[idx2], fips2demo['national_average']['gender'], fips2demo['national_average']['race'], fips2demo['national_average']['ethnicity'], cur_compete, idx2])
    genetic_batch.append(cur_batch)
        #cur_batch.append([score[idx, idx2], npi_info

In [None]:
total_res = []
for idx in tqdm(range(len(genetic_batch))):
    res = genetic_algorithm(genetic_batch[idx], 10, 50, 10)
    total_res.append(calc_metrics(res, score[idx]))
print(np.mean(total_res, axis=0))
# CS@GT, Gender, Race, Ethnicity, Geo, Competing

100%|██████████| 425/425 [02:53<00:00,  2.45it/s]


[0.59525104 0.99940004 0.47320528 0.74243828 0.66698868 0.03388235]


In [None]:
base_res = []
for idx, each_batch in tqdm(enumerate(genetic_batch)):
    # Random choice 10 NPIs
    random_choice = np.random.choice(len(each_batch), 10, replace=False)
    chosen = [each_batch[i] for i in random_choice]
    base_res.append(calc_metrics(chosen, score[idx]))
print(np.mean(base_res, axis=0))
# CS@GT, Gender, Race, Ethnicity, Geo, Competing

425it [00:00, 16991.02it/s]


[0.56622701 0.99943507 0.42647336 0.59356979 0.67120661 4.68941176]
