In [1]:
import numpy as np 
import sys
import os
sys.path.append("../")

In [2]:
from pyhealth.datasets import MIMIC3BaseDataset, MIMIC4BaseDataset, eICUBaseDataset, OMOPBaseDataset
base_dataset = MIMIC3BaseDataset(root="../../../../srv/local/data/physionet.org/files/mimiciii/1.4")
# base_dataset = eICUBaseDataset(root="/srv/local/data/physionet.org/files/eicu-crd/2.0")
# base_dataset = MIMIC4BaseDataset(root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp")
# base_dataset = OMOPBaseDataset(root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2")
base_dataset.info()

  from .autonotebook import tqdm as notebook_tqdm



        ----- Output Data Structure -----
        Dataset.patients: [
            {
                patient_id: patient_id, 
                visits: [
                    {
                        visit_id: visit_id, 
                        patient_id: patient_id, 
                        conditions: [List], 
                        procedures: [List],
                        drugs: [List],
                        visit_info: <dict>
                    }
                    ...
                ]                    
            } 
            ...
        ]
        


In [3]:
from pyhealth.tasks import DrugRecDataset
drug_rec_dataset = DrugRecDataset(base_dataset)
drug_rec_dataset.info()


        ----- Output Data Structure -----
        >> drug_rec_dataloader[0]
        >> {
            "conditions": List[tensor],
            "procedures": List[tensor],
            "drugs": List[tensor]
        }
        


In [5]:
import numpy as np
from torch.utils.data.dataset import Dataset
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch
from tqdm import tqdm

class CNNDrugRecDataSet(Dataset):
    def __init__(self, dataset, voc_size, max_visits):

        condition_voc, procedure_voc, drug_voc = voc_size[0], voc_size[1], voc_size[2]
        features = condition_voc + procedure_voc
        print('Features are in shapes of ', max_visits, '*', features)
        
        self.transform = transforms.Compose([
                            transforms.Resize((512, 512)),
                            transforms.Normalize(0.00030047886384355915,
                                                    0.017331721677190905)
                        ])

        input_data_list = []
        label_list = []
        for i in tqdm(range(len(dataset))):
            condition_procedure = np.zeros((max_visits, condition_voc+procedure_voc))
            drug_multi_hot = np.zeros(drug_voc)
            conditions_ = dataset[i]['conditions']
            procedures_ = dataset[i]['procedures']
            drugs_ = dataset[i]['drugs']

            # inputs
            for j in range(len(conditions_)):
                condition = conditions_[j]
                for k in range(len(condition)):
                    condition_procedure[j][condition[k]] = 1
            for m in range(len(procedures_)):
                procedure = procedures_[m]
                for n in range(len(procedure)):
                    condition_procedure[m][procedure[n]+condition_voc] = 1
            input_data_list.append(condition_procedure)

            # labels
            for p in range(len(drugs_)):
                drug = drugs_[p]
                for q in range(len(drug)):
                    drug_multi_hot[drug[q]] = 1
            label_list.append(drug_multi_hot)
            
        self.inputs = np.array(input_data_list, dtype=int)
        self.labels = np.array(label_list, dtype=float)

    def __getitem__(self, patient):
        x = self.inputs[patient]
        y = self.labels[patient]
        l = []
        for i in range(3):
            l.append(x)
        l = np.array(l)
        c = torch.from_numpy(l).float()
        x = self.transform(c)
        return x, y

    def __len__(self):
        return len(self.inputs)
        
        

In [6]:
class Resnext50(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        resnet = torchvision.models.resnext50_32x4d(pretrained=False)
        resnet.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=resnet.fc.in_features, out_features=n_classes)
        )
        self.base_model = resnet
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        return self.sigm(self.base_model(x))

In [7]:
from sklearn.metrics import precision_score, recall_score, f1_score
def calculate_metrics(pred, target, threshold=0.5):
    pred = np.array(pred > threshold, dtype=float)
    return {'micro/precision': precision_score(y_true=target, y_pred=pred, average='micro'),
            'micro/recall': recall_score(y_true=target, y_pred=pred, average='micro'),
            'micro/f1': f1_score(y_true=target, y_pred=pred, average='micro'),
            'macro/precision': precision_score(y_true=target, y_pred=pred, average='macro'),
            'macro/recall': recall_score(y_true=target, y_pred=pred, average='macro'),
            'macro/f1': f1_score(y_true=target, y_pred=pred, average='macro'),
            'samples/precision': precision_score(y_true=target, y_pred=pred, average='samples'),
            'samples/recall': recall_score(y_true=target, y_pred=pred, average='samples'),
            'samples/f1': f1_score(y_true=target, y_pred=pred, average='samples'),
            }

In [8]:
# Initialize the training parameters.
num_workers = 0 # Number of CPU processes for data preprocessing
lr = 1e-4 # Learning rate
batch_size = 32
save_freq = 5 # Save checkpoint frequency (epochs)
test_freq = 200 # Test model frequency (iterations)
max_epoch_number = 35 # Number of epochs for training 
# Note: on the small subset of data overfitting happens after 30-35 epochs

device = torch.device('cuda')
# Save path for checkpoints
save_path = 'cnn_ckpt/'
# Save path for logs
logdir = 'logs/cnn_logs/'

In [9]:
def checkpoint_save(model, save_path, epoch):
    f = os.path.join(save_path, 'checkpoint-{:06d}.pth'.format(epoch))
    if 'module' in dir(model):
        torch.save(model.module.state_dict(), f)
    else:
        torch.save(model.state_dict(), f)
    print('saved checkpoint:', f)

In [10]:
split_ratio = 0.9
idx = int(len(drug_rec_dataset)*0.9)
train_dataset = []
test_dataset = []
for i in range(len(drug_rec_dataset)):
    item = drug_rec_dataset.__getitem__(i)
    if i <= idx:
        train_dataset.append(item)
    else:
        test_dataset.append(item)
voc_size = drug_rec_dataset.voc_size
max_visits = 0
for patient in range(len(drug_rec_dataset)):
    length = len(drug_rec_dataset[patient]['conditions'])
    if length > max_visits:
        max_visits = length


In [11]:
cnn_train = CNNDrugRecDataSet(train_dataset, voc_size, max_visits)
cnn_test = CNNDrugRecDataSet(test_dataset, voc_size, max_visits)

Features are in shapes of  29 * 5907


100%|█████████████████████████████████████████████████████████████████████████████| 4905/4905 [00:03<00:00, 1513.04it/s]


Features are in shapes of  29 * 5907


100%|███████████████████████████████████████████████████████████████████████████████| 544/544 [00:00<00:00, 1036.31it/s]


In [12]:
from torch.utils.data.dataloader import DataLoader

train_dataloader = DataLoader(cnn_train, batch_size=batch_size, shuffle=False, drop_last=True)
test_dataloader = DataLoader(cnn_test, batch_size=batch_size, drop_last=True)

In [13]:
num_train_batches = int(np.ceil(len(train_dataset) / batch_size))

In [14]:
len(cnn_train[0][0][0][0])

512

In [15]:
len(cnn_train[0][1])

3687

In [16]:
model = Resnext50(len(cnn_train[0][1]))



In [17]:
model.train()
model.to(device)

Resnext50(
  (base_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): 

In [18]:
os.makedirs(save_path, exist_ok=True)

In [19]:
from torch.utils.tensorboard import SummaryWriter
logger = SummaryWriter(logdir)

In [20]:
def prepare_gpu(n_gpu_use):
    n_gpu = torch.cuda.device_count()
    print('Num of available GPUs: ', n_gpu)
    if n_gpu_use > 0 and n_gpu == 0:
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        n_gpu_use = n_gpu
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids

In [21]:
# Run training
import warnings
warnings.filterwarnings('ignore')

epoch = 0
iteration = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

criterion = nn.BCELoss()

device, device_ids = prepare_gpu(n_gpu_use=8)
if len(device_ids) > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids)
    model.cuda()
model.to(f'cuda:{model.device_ids[0]}')


while True:
    batch_losses = []
    for inputs, targets in tqdm(train_dataloader):
        inputs, targets = inputs.to(f'cuda:{model.device_ids[0]}'), targets.to(f'cuda:{model.device_ids[0]}')

        optimizer.zero_grad()

        model_result = model(inputs)
        loss = criterion(model_result, targets.type(torch.float))

        batch_loss_value = loss.item()
        loss.backward()
        optimizer.step()

        logger.add_scalar('train_loss', batch_loss_value, iteration)
        batch_losses.append(batch_loss_value)
        with torch.no_grad():
            result = calculate_metrics(model_result.cpu().numpy(), targets.cpu().numpy())
            for metric in result:
                logger.add_scalar('train/' + metric, result[metric], iteration)

        if iteration % test_freq == 0:
            model.eval()
            with torch.no_grad():
                model_result = []
                targets = []
                for inputs, batch_targets in test_dataloader:
                    inputs = inputs.to(device)
                    model_batch_result = model(inputs)
                    model_result.extend(model_batch_result.cpu().numpy())
                    targets.extend(batch_targets.cpu().numpy())

            result = calculate_metrics(np.array(model_result), np.array(targets))
            for metric in result:
                logger.add_scalar('test/' + metric, result[metric], iteration)
            print("epoch:{:2d} iter:{:3d} test: "
                  "micro f1: {:.3f} "
                  "macro f1: {:.3f} "
                  "samples f1: {:.3f}".format(epoch, iteration,
                                              result['micro/f1'],
                                              result['macro/f1'],
                                              result['samples/f1']))

            model.train()
        iteration += 1

    loss_value = np.mean(batch_losses)
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(epoch, iteration, loss_value))
    if epoch % save_freq == 0:
        checkpoint_save(model, save_path, epoch)
    epoch += 1
    if max_epoch_number < epoch:
        break

Num of available GPUs:  8


  1%|▌                                                                                | 1/153 [00:39<1:40:52, 39.82s/it]

epoch: 0 iter:  0 test: micro f1: 0.045 macro f1: 0.017 samples f1: 0.044


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [02:28<00:00,  1.03it/s]


epoch: 0 iter:153 train: loss:0.105
saved checkpoint: cnn_ckpt/checkpoint-000000.pth


 31%|█████████████████████████▋                                                        | 48/153 [00:40<05:57,  3.40s/it]

epoch: 1 iter:200 test: micro f1: 0.208 macro f1: 0.004 samples f1: 0.209


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:51<00:00,  1.38it/s]


epoch: 1 iter:306 train: loss:0.072


 62%|██████████████████████████████████████████████████▉                               | 95/153 [01:09<03:15,  3.37s/it]

epoch: 2 iter:400 test: micro f1: 0.304 macro f1: 0.009 samples f1: 0.291


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:53<00:00,  1.35it/s]


epoch: 2 iter:459 train: loss:0.071


 93%|███████████████████████████████████████████████████████████████████████████▏     | 142/153 [01:41<00:29,  2.65s/it]

epoch: 3 iter:600 test: micro f1: 0.294 macro f1: 0.006 samples f1: 0.296


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:47<00:00,  1.42it/s]


epoch: 3 iter:612 train: loss:0.071


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:47<00:00,  1.42it/s]


epoch: 4 iter:765 train: loss:0.071


 24%|███████████████████▎                                                              | 36/153 [00:31<05:44,  2.95s/it]

epoch: 5 iter:800 test: micro f1: 0.253 macro f1: 0.007 samples f1: 0.242


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:44<00:00,  1.47it/s]


epoch: 5 iter:918 train: loss:0.070
saved checkpoint: cnn_ckpt/checkpoint-000005.pth


 54%|████████████████████████████████████████████▍                                     | 83/153 [01:08<03:32,  3.03s/it]

epoch: 6 iter:1000 test: micro f1: 0.349 macro f1: 0.010 samples f1: 0.345


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:49<00:00,  1.40it/s]


epoch: 6 iter:1071 train: loss:0.070


 85%|████████████████████████████████████████████████████████████████████▊            | 130/153 [01:36<01:11,  3.12s/it]

epoch: 7 iter:1200 test: micro f1: 0.390 macro f1: 0.014 samples f1: 0.379


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:51<00:00,  1.37it/s]


epoch: 7 iter:1224 train: loss:0.069


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:29<00:00,  1.70it/s]


epoch: 8 iter:1377 train: loss:0.069


 16%|████████████▊                                                                     | 24/153 [00:25<07:06,  3.31s/it]

epoch: 9 iter:1400 test: micro f1: 0.211 macro f1: 0.007 samples f1: 0.203


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:51<00:00,  1.37it/s]


epoch: 9 iter:1530 train: loss:0.069


 46%|██████████████████████████████████████                                            | 71/153 [00:48<03:49,  2.80s/it]

epoch:10 iter:1600 test: micro f1: 0.352 macro f1: 0.013 samples f1: 0.337


100%|█████████████████████████████████████████████████████████████████████████████████| 153/153 [01:44<00:00,  1.46it/s]


epoch:10 iter:1683 train: loss:0.068
saved checkpoint: cnn_ckpt/checkpoint-000010.pth


  5%|███▊                                                                               | 7/153 [00:04<01:23,  1.74it/s]


KeyboardInterrupt: 

In [26]:
# Run inference on the test data
model.eval()
sample = 0
for inputs, targets in train_dataloader:
    sample += 1
    with torch.no_grad():
        raw_pred = model(inputs).cpu().numpy()[0]
        raw_pred = np.array(raw_pred > 0.5, dtype=float)
        print(raw_pred)
    if sample > 5:
        break

[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
