In [4]:
from pytrial.tasks.trial_simulation.sequence import PromptEHR
from pytrial.data.demo_data import load_mimic_ehr_sequence
from pytrial.tasks.trial_simulation.data import SequencePatient

# Checkpoint directories for the two locally fine-tuned PromptEHR models.
# Both are written without a trailing slash so the loader logs print them consistently.
full_model_path = './model_50_epochs_30k_samples'
partial_model_path = './model_20_epochs_15k_samples'


def load_model(path=None) -> PromptEHR:
    """Create a PromptEHR instance and load weights into it via ``from_pretrained``.

    Args:
        path: Checkpoint directory passed to ``PromptEHR.from_pretrained``.
            When ``None`` (the default), ``from_pretrained()`` is called with no
            argument, so PromptEHR loads its own default pretrained checkpoint.

    Returns:
        The loaded ``PromptEHR`` model.
    """
    model = PromptEHR()
    if path is None:
        # No explicit checkpoint: defer to PromptEHR's default location.
        model.from_pretrained()
    else:
        model.from_pretrained(path)
    return model


def create_seq_pat(data, max_visit=20) -> SequencePatient:
    """Wrap a raw EHR-sequence dict (as returned by pytrial's demo loaders) in a SequencePatient.

    Args:
        data: Mapping providing the keys ``'visit'``, ``'y'``, ``'feature'``,
            ``'voc'``, ``'n_num_feature'`` and ``'cat_cardinalities'``.
        max_visit: Value stored as ``metadata['max_visit']``; presumably the cap on
            visits per patient used by SequencePatient (TODO confirm in pytrial docs).
            Defaults to 20, the value previously hard-coded here.

    Returns:
        A ``SequencePatient`` whose visit mode is ``'dense'`` and label mode is ``'tensor'``.
    """
    return SequencePatient(
        data={
            'v': data['visit'],
            'y': data['y'],
            'x': data['feature'],
        },
        metadata={
            'visit': {'mode': 'dense'},
            'label': {'mode': 'tensor'},
            'voc': data['voc'],
            'max_visit': max_visit,
            'n_num_feature': data['n_num_feature'],
            'cat_cardinalities': data['cat_cardinalities'],
        }
    )


def load_seq_pat(path):
    """Read MIMIC EHR sequences from the directory ``path`` and return them as a SequencePatient.

    Composes ``load_mimic_ehr_sequence`` (reading) with ``create_seq_pat`` (wrapping).
    """
    return create_seq_pat(load_mimic_ehr_sequence(input_dir=path))

In [5]:
# Demo EHR sequences bundled with pytrial. The wrapped dataset (synth_data) is
# what the model cells evaluate on and pass to predict().
from pytrial.data.demo_data import load_synthetic_ehr_sequence

raw_synth_data = load_synthetic_ehr_sequence()  # raw dict of visits/labels/features/vocab
synth_data = create_seq_pat(raw_synth_data)

In [11]:
# Evaluate the different models on the synthetic data.
# A path of None selects PromptEHR's default pretrained checkpoint.
pretrained_model_path = None
models_to_evaluate = {
    'full_model': full_model_path,
    'partial_model': partial_model_path,
    'pretrained_model': pretrained_model_path,
}
for model_name, path in models_to_evaluate.items():
    # Print a header so each block of perplexity output can be attributed to a model.
    print(f'===== {model_name} =====')
    if path is not None:
        model = load_model(path)
    else:
        model = PromptEHR()
        model.from_pretrained()
    model.evaluate(synth_data)
    del model  # drop the reference before loading the next model
    

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./model_50_epochs_30k_samples
Load the pre-trained model from: ./model_50_epochs_30k_samples
evaluation for code diag.
evaluation for tpl perplexity.
code: diag, ppl_type: tpl, value: 2612.161376953125
evaluation for code diag.
evaluation for spl perplexity.
code: diag, ppl_type: spl, value: 3272.134521484375
evaluation for code prod.
evaluation for tpl perplexity.
code: prod, ppl_type: tpl, value: 1671.0849609375
evaluation for code prod.
evaluation for spl perplexity.
code: prod, ppl_type: spl, value: 2932.37158203125
evaluation for code med.
evaluation for tpl perplexity.
code: med, ppl_type: tpl, value: 358.09295654296875
evaluation for code med.
evaluation for spl perplexity.
code: med, ppl_type: spl, value: 302.63525390625


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./model_20_epochs_15k_samples/
Load the pre-trained model from: ./model_20_epochs_15k_samples/
evaluation for code diag.
evaluation for tpl perplexity.
code: diag, ppl_type: tpl, value: 2881.69677734375
evaluation for code diag.
evaluation for spl perplexity.
code: diag, ppl_type: spl, value: 4073.4453125
evaluation for code prod.
evaluation for tpl perplexity.
code: prod, ppl_type: tpl, value: 1916.8770751953125
evaluation for code prod.
evaluation for spl perplexity.
code: prod, ppl_type: spl, value: 2720.56689453125
evaluation for code med.
evaluation for tpl perplexity.
code: med, ppl_type: tpl, value: 178.76751708984375
evaluation for code med.
evaluation for spl perplexity.
code: med, ppl_type: spl, value: 212.9291534423828


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./simulation/pretrained_promptEHR
Load the pre-trained model from: ./simulation/pretrained_promptEHR
evaluation for code diag.
evaluation for tpl perplexity.
code: diag, ppl_type: tpl, value: 323.42999267578125
evaluation for code diag.
evaluation for spl perplexity.
code: diag, ppl_type: spl, value: 501.9122314453125
evaluation for code prod.
evaluation for tpl perplexity.
code: prod, ppl_type: tpl, value: 229.15621948242188
evaluation for code prod.
evaluation for spl perplexity.
code: prod, ppl_type: spl, value: 165.3402557373047
evaluation for code med.
evaluation for tpl perplexity.
code: med, ppl_type: tpl, value: 89.43730163574219
evaluation for code med.
evaluation for spl perplexity.
code: med, ppl_type: spl, value: 64.30888366699219


In [12]:
# Generate synthetic data for each model and evaluate it on privacy and utility.
from pytrial.tasks.trial_simulation.sequence.evaluation import RNNPrivacyDetection, RNNUtilityDetection

# Number of synthetic patients requested from predict(), and the device the
# RNN detectors are trained on.
# NOTE(review): with 20 samples the utility detector's synthetic-data model was
# trained on only 14 examples (vs. 700 for its real-data model) in the recorded
# run, so 'syn-data-model-auc' is very noisy. Consider raising this value.
N_SYNTH_SAMPLES = 20
EVAL_DEVICE = 'cuda:0'

# Full model
full_model = load_model(full_model_path)
synthetic_seqdata = full_model.predict(synth_data, n=N_SYNTH_SAMPLES, n_per_sample=1, verbose=False)
del full_model  # drop the generator reference before the detectors are trained
# Evaluate privacy
print(RNNPrivacyDetection.compute(synth_data, synthetic_seqdata, device=EVAL_DEVICE))
# Evaluate utility
print(RNNUtilityDetection.compute(synth_data, synthetic_seqdata, device=EVAL_DEVICE))
del synthetic_seqdata

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./model_50_epochs_30k_samples
Load the pre-trained model from: ./model_50_epochs_30k_samples
522 reach model max length 512, do cut.
512 reach model max length 512, do cut.
523 reach model max length 512, do cut.
529 reach model max length 512, do cut.
516 reach model max length 512, do cut.
523 reach model max length 512, do cut.
529 reach model max length 512, do cut.
519 reach model max length 512, do cut.
522 reach model max length 512, do cut.
519 reach model max length 512, do cut.
514 reach model max length 512, do cut.
519 reach model max length 512, do cut.
528 reach model max length 512, do cut.
{'lr': 0.0001, 'weight_decay': 0.0001}
***** Running training *****
  Num examples = 816
  Num Epochs = 10
  Total optimization steps = 130


Iteration: 100%|██████████| 13/13 [00:00<00:00, 135.39it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 133.60it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 136.81it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 141.85it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 141.27it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 144.41it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 135.38it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 81.74it/s]
Training Epoch:  80%|████████  | 8/10 [00:00<00:00,  9.28it/s]


######### Train Loss 100 #########
0 0.5647 


######### Eval 100 #########
auc: 0.6169


Iteration: 100%|██████████| 13/13 [00:00<00:00, 136.81it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 139.75it/s]
Training Epoch: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s]


Load best ckpt from `./checkpoints/best`.
Training completes.
0.6169154228855721
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 700
  Num Epochs = 50
  Total optimization steps = 300


Iteration: 100%|██████████| 6/6 [00:00<00:00, 48.62it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]7it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.81it/s]5it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.80it/s]4it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]7it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]3it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.56it/s]5it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.28it/s]5it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]0it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.41it/s]7it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.54it/s]83it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]93it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.23it/s]93it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.53it/s]96it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.68it/s]95it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.62it/s]


######### Train Loss 100 #########
0 0.5160 


######### Eval 100 #########
auc: 0.5785


Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.75it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]75it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.59it/s]27it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]56it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.53it/s]85it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]88it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.56it/s]94it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.25it/s]91it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.71it/s]02it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]85it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.13it/s]90it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.06it/s]05it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.24it/s]11it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.53it/s]09it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]04it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 5


######### Train Loss 200 #########
0 0.2047 


######### Eval 200 #########
auc: 0.6478


Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.86it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.06it/s]62it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.22it/s]05it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.56it/s]32it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]47it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.53it/s]83it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]88it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]90it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.59it/s]15it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.56it/s]21it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.06it/s]10it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]12it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.75it/s]08it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]12it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.71it/s]04it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 2


######### Train Loss 300 #########
0 0.0172 


######### Eval 300 #########
auc: 0.6531
Load best ckpt from `./checkpoints/best`.
Training completes.
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 14
  Num Epochs = 50
  Total optimization steps = 50


Iteration: 100%|██████████| 1/1 [00:00<00:00, 49.99it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 20.83it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 29.41it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 31.24it/s]0it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 33.33it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 49.99it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 15.38it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.31it/s]4it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 19.75it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 124.97it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 199.98it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 166.61it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 166.63it/s]5it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 142.83it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 142.83it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 142.81it/s]
Iterat

Training completes.
{'real-data-model-auc': 0.6531165311653117, 'syn-data-model-auc': 0.5301866907557965}


In [13]:
# Partial model: sample synthetic patients from it, then score the samples with
# the privacy detector followed by the utility detector.
partial_model = load_model(partial_model_path)
synthetic_seqdata = partial_model.predict(
    synth_data,
    n=20,
    n_per_sample=1,
    verbose=False,
)
del partial_model  # drop the generator reference before detector training

for detector in (RNNPrivacyDetection, RNNUtilityDetection):
    print(detector.compute(synth_data, synthetic_seqdata, device='cuda:0'))
del detector, synthetic_seqdata

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./model_20_epochs_15k_samples/
Load the pre-trained model from: ./model_20_epochs_15k_samples/
522 reach model max length 512, do cut.
521 reach model max length 512, do cut.
531 reach model max length 512, do cut.
514 reach model max length 512, do cut.
514 reach model max length 512, do cut.
514 reach model max length 512, do cut.
512 reach model max length 512, do cut.
532 reach model max length 512, do cut.
512 reach model max length 512, do cut.
520 reach model max length 512, do cut.
516 reach model max length 512, do cut.
519 reach model max length 512, do cut.
523 reach model max length 512, do cut.
513 reach model max length 512, do cut.
515 reach model max length 512, do cut.
{'lr': 0.0001, 'weight_decay': 0.0001}
***** Running training *****
  Num examples = 816
  Num Epochs = 10
  Total optimization steps = 130


Iteration: 100%|██████████| 13/13 [00:00<00:00, 142.82it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 141.27it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 144.41it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 146.03it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 142.82it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 133.99it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 138.27it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 83.31it/s]
Training Epoch:  80%|████████  | 8/10 [00:00<00:00,  9.41it/s]


######### Train Loss 100 #########
0 0.6480 


######### Eval 100 #########
auc: 0.4387


Iteration: 100%|██████████| 13/13 [00:00<00:00, 142.82it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 131.06it/s]
Training Epoch: 100%|██████████| 10/10 [00:01<00:00,  9.99it/s]


Load best ckpt from `./checkpoints/best`.
Training completes.
0.43875
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 700
  Num Epochs = 50
  Total optimization steps = 300


Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.08it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.24it/s]7it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]8it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.16it/s]0it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 48.77it/s]9it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]1it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.71it/s]5it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.34it/s]1it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 53.08it/s]7it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.61it/s]1it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 48.77it/s]78it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.53it/s]50it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 49.99it/s]65it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.68it/s]53it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.16it/s]78it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.13it/s]


######### Train Loss 100 #########
0 0.4961 


######### Eval 100 #########
auc: 0.6567


Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.13it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.41it/s]76it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]92it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 49.58it/s]20it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.16it/s]18it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 48.38it/s]30it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 49.99it/s]19it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.16it/s]21it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 49.99it/s]33it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.16it/s]31it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 55.03it/s]38it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.27it/s]58it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.41it/s]55it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 52.62it/s]46it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 46.50it/s]53it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 4


######### Train Loss 200 #########
0 0.1859 


######### Eval 200 #########
auc: 0.6557



Iteration: 100%|██████████| 6/6 [00:00<00:00, 44.43it/s]94it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 43.79it/s]29it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 45.44it/s]52it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 47.61it/s]79it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.84it/s]05it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.27it/s]41it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 47.39it/s]70it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 45.79it/s]72it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 48.53it/s]66it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.41it/s]75it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 47.99it/s]91it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 50.58it/s]92it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 47.23it/s]03it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 49.17it/s]96it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 51.71it/s]99it/s]
Iteration: 100%|██████████| 6/6 [00:00<


######### Train Loss 300 #########
0 0.0149 


######### Eval 300 #########
auc: 0.6364
Load best ckpt from `./checkpoints/best`.
Training completes.
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 14
  Num Epochs = 50
  Total optimization steps = 50


Iteration: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 55.54it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.91it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 62.49it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 58.81it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 55.54it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]2it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 35.71it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 47.61it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 71.41it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 38.45it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 37.03it/s]67it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.32it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.31it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.91it/s]
Iteration: 100%|██

Training completes.
{'real-data-model-auc': 0.6567497966928707, 'syn-data-model-auc': 0.40329357549471406}


In [14]:
# Pretrained model: PromptEHR's default checkpoint (from_pretrained() with no path).
# Sample synthetic patients, then run the privacy and utility detectors in turn.
pretrained = PromptEHR()
pretrained.from_pretrained()
synthetic_seqdata = pretrained.predict(
    synth_data,
    n=20,
    n_per_sample=1,
    verbose=False,
)
del pretrained  # drop the generator reference before detector training

for detector in (RNNPrivacyDetection, RNNUtilityDetection):
    print(detector.compute(synth_data, synthetic_seqdata, device='cuda:0'))
del detector, synthetic_seqdata

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.


Load pretrained PromptEHR model from ./simulation/pretrained_promptEHR
Load the pre-trained model from: ./simulation/pretrained_promptEHR
514 reach model max length 512, do cut.
516 reach model max length 512, do cut.
514 reach model max length 512, do cut.
519 reach model max length 512, do cut.
515 reach model max length 512, do cut.
519 reach model max length 512, do cut.
517 reach model max length 512, do cut.
514 reach model max length 512, do cut.
512 reach model max length 512, do cut.
518 reach model max length 512, do cut.
520 reach model max length 512, do cut.
519 reach model max length 512, do cut.
513 reach model max length 512, do cut.
516 reach model max length 512, do cut.
{'lr': 0.0001, 'weight_decay': 0.0001}
***** Running training *****
  Num examples = 816
  Num Epochs = 10
  Total optimization steps = 130


Iteration: 100%|██████████| 13/13 [00:00<00:00, 151.13it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 154.73it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 146.03it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 152.91it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 154.72it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 152.91it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 156.59it/s]/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 81.74it/s]
Training Epoch:  80%|████████  | 8/10 [00:00<00:00,  9.93it/s]


######### Train Loss 100 #########
0 0.6069 


######### Eval 100 #########
auc: 0.5337


Iteration: 100%|██████████| 13/13 [00:00<00:00, 154.73it/s]
Iteration: 100%|██████████| 13/13 [00:00<00:00, 144.41it/s]
Training Epoch: 100%|██████████| 10/10 [00:00<00:00, 10.60it/s]


Load best ckpt from `./checkpoints/best`.
Training completes.
0.5336700336700337
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 700
  Num Epochs = 50
  Total optimization steps = 300


Iteration: 100%|██████████| 6/6 [00:00<00:00, 60.23it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]4it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 54.04it/s]1it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 61.21it/s]5it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 60.59it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.13it/s]6it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 60.59it/s]9it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.68it/s]8it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]3it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.13it/s]95it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]82it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 62.75it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 65.20it/s]02it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 61.84it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 62.48it/s]16it/s]
Iteration: 100%|██


######### Train Loss 100 #########
0 0.5025 


######### Eval 100 #########
auc: 0.5895


Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.68it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]89it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.80it/s]05it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 56.59it/s]19it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 66.93it/s]20it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 64.50it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 59.39it/s]75it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.68it/s]74it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]69it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 62.49it/s]67it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 57.34it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 61.21it/s]73it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 60.59it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 60.59it/s]83it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 68.17it/s]87it/s]
Iteration: 


######### Train Loss 200 #########
0 0.1957 


######### Eval 200 #########
auc: 0.6139


Iteration: 100%|██████████| 6/6 [00:00<00:00, 65.49it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]11it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 61.21it/s]21it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.81it/s]37it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 61.21it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]70it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 65.92it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 65.20it/s]99it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 63.14it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.81it/s]17it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 64.50it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 66.64it/s]13it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 58.24it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 66.21it/s]15it/s]
Iteration: 100%|██████████| 6/6 [00:00<00:00, 29.26it/s]
Training Epoch: 100%|██████████|


######### Train Loss 300 #########
0 0.0170 


######### Eval 300 #########
auc: 0.6070
Load best ckpt from `./checkpoints/best`.
Training completes.
{'lr': 0.0005, 'weight_decay': 0}
***** Running training *****
  Num examples = 14
  Num Epochs = 50
  Total optimization steps = 50


Iteration: 100%|██████████| 1/1 [00:00<00:00, 55.54it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.91it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.32it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.91it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 62.49it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.31it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]5it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 58.81it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 83.31it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 66.65it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]69it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 76.90it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00, 90.89it/s]
Iteration: 100%|██

Training completes.
{'real-data-model-auc': 0.6138519924098672, 'syn-data-model-auc': 0.41257793439956625}
