In [25]:
task = "drugrec"

In [26]:
import pickle

with open(f'/data/pj20/exp_data/ccscm_ccsproc/sample_dataset_mimic3_{task}_th015.pkl', 'rb') as f:
    sample_dataset = pickle.load(f)

In [27]:
from pyhealth.datasets import split_by_patient, get_dataloader

train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1], train_ratio=1.0, seed=528)

In [28]:
val_dataset = list(val_dataset)

In [29]:
import random
from collections import defaultdict

def nested_dict():
    return defaultdict(list)

def drop_ratio(d, ratio):
    new_d = defaultdict(list)
    for key, value in d.items():
        if key == "conditions" or key == "procedures" or key == "drugs":
            if isinstance(value[0], list):
                for i in range(len(value)):
                        n = int(len(value[i]) * (ratio))  # calculate the number of items to keep
                        sampled = random.sample(value[i], n)  # randomly select items to keep
                        if len(sampled) > 0:
                            new_d[key].append(sampled)
                        else:
                            new_d[key].append(random.sample(value[i], 1))
            else:
                n = int(len(value) * (ratio))
                sampled = random.sample(value, n)
                if len(new_d[key]) == 0:
                    new_d[key] = value
                else:
                    new_d[key] = random.sample(value, 1)
        elif key == "label":
            new_d[key] = value
    return new_d

In [30]:
ratios = [
    0.05,
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.8,
    0.9,
]

In [31]:
from copy import deepcopy
from data_prepare import *
from tqdm import tqdm

for ratio in ratios:
    print("ratio: ", ratio)
    val_instance = deepcopy(val_dataset)

    for i in range(len(val_instance)):
        val_instance[i] = drop_ratio(val_instance[i], ratio)

    print("Loading embeddings...")
    ent2id, rel2id, ent_emb, rel_emb = load_embeddings(task)
    map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_inv_rel = clustering(task, ent_emb, rel_emb, threshold=0.15, load_cluster=True, save_cluster=False)

    print("Processing graph...")
    G = process_graph("mimic3", task, val_instance, ent2id, rel2id, map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_inv_rel, save_graph=False)
    G_tg = from_networkx(G)
    
    print("Processing dataset...")
    valid_dataset = process_sample_dataset("mimic3", task, val_instance, G_tg, ent2id, rel2id, map_cluster, map_cluster_inv, map_cluster_rel, map_cluster_inv_rel, save_dataset=False)

    with open(f'/data/pj20/exp_data/ccscm_ccsproc/val_dataset_mimic3_{task}_th015_{ratio}.pkl', 'wb') as f:
        pickle.dump(valid_dataset, f)

    print(f"Done ratio: {ratio}")

ratio:  0.05
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:02<00:00, 2076.43it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:02<00:00, 1488.55it/s]


Done ratio: 0.05
ratio:  0.1
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:02<00:00, 1725.56it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:04<00:00, 971.98it/s] 


Done ratio: 0.1
ratio:  0.2
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:03<00:00, 1215.36it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:04<00:00, 1016.33it/s]


Done ratio: 0.2
ratio:  0.3
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:05<00:00, 801.39it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:05<00:00, 741.75it/s]


Done ratio: 0.3
ratio:  0.4
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:07<00:00, 577.82it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:07<00:00, 545.65it/s]


Done ratio: 0.4
ratio:  0.5
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:08<00:00, 507.21it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:09<00:00, 453.82it/s]


Done ratio: 0.5
ratio:  0.6
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:10<00:00, 401.57it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:11<00:00, 374.78it/s]


Done ratio: 0.6
ratio:  0.7
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:11<00:00, 385.35it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:12<00:00, 355.61it/s]


Done ratio: 0.7
ratio:  0.8
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:12<00:00, 338.88it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:13<00:00, 315.94it/s]


Done ratio: 0.8
ratio:  0.9
Loading embeddings...
Processing graph...


100%|██████████| 4346/4346 [00:15<00:00, 272.80it/s]


Processing dataset...


100%|██████████| 4346/4346 [00:15<00:00, 279.52it/s]


Done ratio: 0.9


In [32]:
l = get_dataloader(val_instance,batch_size=1, shuffle=False)

for data in l:
    print(data)
    break

{'conditions': [[['3', '55', '80', '131', '157', '135', '159']]], 'procedures': [[['216']]], 'drugs': [['N01A', 'C07A', 'N06B', 'C08D', 'J01D', 'V04C', 'N02B', 'N05A', 'B05X', 'N05B', 'A02B', 'C10A', 'M03B', 'A12B', 'B05B', 'A12C', 'N02A', 'A03F', 'C09A', 'B01A', 'G04B', 'N03A', 'A02A', 'A06A', 'J01C', 'P01A', 'M01A', 'A07A']], 'node_set': [[1, 1028, 1033, 1035, 12, 1038, 2066, 22, 23, 27, 2076, 29, 1053, 2082, 1059, 38, 1062, 1064, 41, 42, 1070, 1071, 2101, 2102, 55, 56, 2105, 58, 59, 62, 64, 65, 2114, 69, 72, 78, 81, 85, 2134, 2135, 91, 92, 2140, 1116, 95, 96, 98, 99, 2148, 104, 106, 1131, 113, 1137, 1140, 2167, 1149, 2174, 128, 1155, 134, 135, 2185, 138, 139, 141, 1166, 1167, 145, 1170, 2195, 1176, 1178, 1182, 161, 162, 165, 170, 1194, 172, 176, 177, 1208, 1209, 189, 191, 195, 2244, 197, 1221, 2249, 201, 1225, 1229, 206, 1230, 2258, 2265, 1241, 2267, 220, 221, 1246, 223, 233, 234, 1258, 237, 1265, 256, 1284, 1286, 1288, 266, 269, 270, 271, 272, 2321, 277, 281, 1306, 2334, 287, 2336,