In [1]:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import wandb
from sklearn.metrics import precision_score
from accelerate import Accelerator
from accelerate import DistributedType
import os
from utils.utils import seed_everything
from transformers import LongformerTokenizer
from datasets import EHR_Longformer_Dataset
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
from models.longformernormal import LongformerPretrainNormal
from torch.optim.lr_scheduler import LinearLR, SequentialLR, ExponentialLR, LambdaLR, CosineAnnealingWarmRestarts
from pretrain_train import train
import logging
import sys
from torch.utils.data.distributed import DistributedSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def configure_optimizers(model, args, n_steps):
    """Build an AdamW optimizer with a linear-warmup / linear-decay schedule.

    The learning rate ramps linearly from 1% of ``args.learning_rate`` to the
    full value over the first 10% of ``n_steps`` and then decays linearly back
    to 1% over the remaining steps.

    Args:
        model: Module whose ``parameters()`` are optimized.
        args: Namespace-like object providing ``learning_rate``.
        n_steps: Total number of scheduler steps planned for the run.

    Returns:
        Tuple ``(optimizer, scheduler_config)`` where ``scheduler_config`` is
        ``{"scheduler": <SequentialLR>, "interval": "step"}``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Both phases need at least one iteration. With n_steps < 10 the warmup
    # length used to truncate to 0, which put the SequentialLR milestone at
    # step 0; that milestone is never reached by ``step()``, so the schedule
    # stayed stuck around the 1% warmup rate instead of reaching the full
    # learning rate. For n_steps >= 10 the clamps are no-ops.
    n_warmup_steps = max(1, int(n_steps * 0.1))
    n_decay_steps = max(1, n_steps - n_warmup_steps)

    warmup = LinearLR(
        optimizer,
        start_factor=0.01,
        end_factor=1.0,
        total_iters=n_warmup_steps,
    )
    decay = LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.01,
        total_iters=n_decay_steps,
    )

    # Hand over from warmup to decay once the warmup steps are used up.
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup, decay],
        milestones=[n_warmup_steps],
    )

    return optimizer, {"scheduler": scheduler, "interval": "step"}

In [3]:
def _str2bool(value):
    """Convert a command-line token such as ``true`` / ``False`` / ``0`` to bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--resume False`` would
    silently switch the option on.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Experiment bookkeeping (every option has a default; none is required)
parser.add_argument("--exp_name", type=str, default="pretrain")
parser.add_argument("--save_path", type=str, default="./results")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")

# Model, data-loading and training parameters
parser.add_argument("--vocab_size", type=int, default=50265)
parser.add_argument("--itemid_size", type=int, default=4016)
parser.add_argument("--unit_size", type=int, default=60)
parser.add_argument("--gender_size", type=int, default=2)
parser.add_argument("--continuous_size", type=int, default=3)
parser.add_argument("--task_size", type=int, default=4)
parser.add_argument("--max_position_embeddings", type=int, default=5000)
parser.add_argument("--max_age", type=int, default=100)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--resume", type=_str2bool, default=False)
parser.add_argument("--pin_memory", type=_str2bool, default=True)
parser.add_argument("--nodes", type=int, default=1)
parser.add_argument("--gpus", type=int, default=2)
parser.add_argument("--start_epoch", type=int, default=0)
parser.add_argument("--epochs", type=int, default=200)
parser.add_argument("--log_every_n_steps", type=int, default=100)
parser.add_argument("--acc", type=int, default=1)
parser.add_argument("--resume_checkpoint", type=str, default=None)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--embedding_size", type=int, default=768)
parser.add_argument("--num_hidden_layers", type=int, default=12)
parser.add_argument("--num_attention_heads", type=int, default=6)
parser.add_argument("--intermediate_size", type=int, default=1536)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--dropout_prob", type=float, default=0.1)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--gpu_mixed_precision", type=_str2bool, default=True)
parser.add_argument("--patience", type=int, default=5)

# Parse an explicit empty argv so that only the defaults above are used,
# regardless of what the hosting process has in sys.argv.
args = parser.parse_args([])

# One attention-window entry (512) per hidden layer.
args.attention_window = [512] * args.num_hidden_layers

In [4]:
def compute_mlm_loss(predictions, labels):
    """Mean cross-entropy over labelled tokens, skipping positions 0-2.

    Args:
        predictions: Logits; the last dimension indexes the classes and the
            leading dimensions line up with ``labels``.
        labels: Target class ids with at least two dimensions (indexed as
            ``[:, :3]``); entries equal to ``-100`` are ignored.

    Returns:
        Scalar tensor holding the mean loss over the tokens that are scored.
        If no token is scored the mean is taken over nothing (NaN in PyTorch).
    """
    # The first three positions of every row never contribute to the loss.
    keep = torch.ones_like(labels, dtype=torch.bool)
    keep[:, :3] = False
    keep = keep.reshape(-1)

    flat_logits = predictions.reshape(-1, predictions.size(-1))
    flat_labels = labels.reshape(-1)

    return F.cross_entropy(flat_logits[keep], flat_labels[keep], ignore_index=-100)

def calculate_mlm_precision(predictions, labels):
    """Fraction of scored tokens whose arg-max prediction equals the label.

    Despite the name, this is token-level accuracy over the scored positions.
    A position is scored when its label is not ``-100`` and it is not one of
    the first three positions of its row.

    Args:
        predictions: Logits; the arg-max over the last dimension is compared
            with ``labels``.
        labels: Target class ids, indexed as ``[:, :3]``.

    Returns:
        Python float in ``[0, 1]``; ``0.0`` when no position is scored.
    """
    scored = labels != -100
    scored[:, :3] = False  # positions 0-2 are never counted

    hits = (predictions.argmax(dim=-1) == labels) & scored

    n_scored = scored.sum().item()
    if n_scored == 0:
        return 0.0
    return hits.sum().item() / n_scored

In [6]:
# --- Reproducibility, tokenizer and lookup tables ----------------------------
seed_everything(args.seed)

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

# Pickled lookup tables, passed to the dataset constructor below.
itemid2idx = pd.read_pickle("datasets/entire_itemid2idx.pkl")
unit2idx = pd.read_pickle("datasets/unit2idx.pkl")

# --- Device ------------------------------------------------------------------
if args.gpu_mixed_precision:
    mixed_precision = "fp16"
else:
    mixed_precision = "no"
accelerator = Accelerator(mixed_precision=mixed_precision)
print(f"Distributed Type: {accelerator.distributed_type}")
device = accelerator.device

# --- Model -------------------------------------------------------------------
# Every constructor argument is read from the attribute of the same name on
# ``args``.
model_arg_names = (
    "vocab_size",
    "itemid_size",
    "max_position_embeddings",
    "unit_size",
    "continuous_size",
    "task_size",
    "max_age",
    "gender_size",
    "embedding_size",
    "num_hidden_layers",
    "num_attention_heads",
    "intermediate_size",
    "learning_rate",
    "dropout_prob",
    "gpu_mixed_precision",
)
model_kwargs = {name: getattr(args, name) for name in model_arg_names}
pretrained_model = LongformerPretrainNormal(**model_kwargs).to(device)

# --- Checkpoint --------------------------------------------------------------
model_path = "./results/best_pretrain_model.pth"
# NOTE(review): the recorded cell output shows a warning raised at this call;
# if the checkpoint holds only tensors/containers, passing weights_only=True
# is presumably the safer choice - confirm against the checkpoint contents.
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# Strip a leading 'module.module.' or 'module.' from parameter names
# (presumably left behind by a parallel wrapper at training time) so the keys
# match the bare model. The longer prefix is tested first.
new_state_dict = {}
for k, v in state_dict.items():
    for prefix in ('module.module.', 'module.'):
        if k.startswith(prefix):
            new_state_dict[k[len(prefix):]] = v
            break
    else:
        new_state_dict[k] = v

pretrained_model.load_state_dict(new_state_dict)
print("Pre-trained model loaded successfully.")

# --- Data --------------------------------------------------------------------
dataset_root = Path("./datasets")
valid_dataset = EHR_Longformer_Dataset(dataset_root, "valid", tokenizer, itemid2idx, unit2idx, use_itemid=True)
test_dataset = EHR_Longformer_Dataset(dataset_root, "test", tokenizer, itemid2idx, unit2idx, use_itemid=True)

# Both evaluation loaders share the same settings; evaluation data is read in
# its stored order (no shuffling, no sampler).
loader_options = dict(
    batch_size=args.batch_size,
    shuffle=False,
    pin_memory=args.pin_memory,
    num_workers=args.num_workers,
)
valid_loader = DataLoader(valid_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

# Not referenced by the evaluation code in this cell; it has the shape of the
# ``n_steps`` argument expected by ``configure_optimizers``.
n_steps = (len(valid_dataset) // args.batch_size) * args.epochs



Seed set to 42
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Distributed Type: NO


  checkpoint = torch.load(model_path, map_location=device)


Pre-trained model loaded successfully.


In [7]:
# Per-batch metrics on the validation split; each list gets one entry per batch
# and is averaged (unweighted over batches) after the loop.
val_loss = []
val_precision = []

pretrained_model.eval()
with torch.no_grad():
    for step, batch in tqdm(enumerate(valid_loader), desc="Validation", total=len(valid_loader)):
        # Move tensors to the target device; non-tensor items pass through.
        batch = tuple(
            item.to(device) if isinstance(item, torch.Tensor) else item
            for item in batch
        )
        (
            input_ids, attention_mask, age_ids, gender_ids,
            value_ids, unit_ids, time_ids, continuous_ids,
            position_ids, token_type_ids, task_token, labels,
        ) = batch

        # Prepend three placeholder labels (value 1) to every row, so the
        # labels handed to the model are three positions longer than the
        # dataset labels - presumably matching extra tokens the model adds
        # in front of the sequence (TODO confirm against the model).
        batch_size = labels.size(0)
        additional_tokens = torch.ones((batch_size, 3), dtype=torch.long, device=device)
        labels = torch.cat([additional_tokens, labels], dim=1)

        # The forward pass runs without autocast.
        model_inputs = dict(
            input_ids=input_ids,
            value_ids=value_ids,
            unit_ids=unit_ids,
            time_ids=time_ids,
            continuous_ids=continuous_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            age_ids=age_ids,
            gender_ids=gender_ids,
            task_token=task_token,
            attention_mask=attention_mask,
            global_attention_mask=None,
            labels=labels,
            return_dict=True,
        )
        outputs = pretrained_model(**model_inputs)

        # Cut the three leading positions off again before scoring.
        # NOTE(review): compute_mlm_loss and calculate_mlm_precision each
        # skip the first three positions of what they receive, so three
        # further positions are left out of the metrics - confirm intended.
        prediction_scores = outputs.logits[:, 3:, :]
        labels = labels[:, 3:]

        loss = compute_mlm_loss(prediction_scores, labels)
        precision = calculate_mlm_precision(prediction_scores, labels)
        val_loss.append(loss.item())
        val_precision.append(precision)

print(f"Validation Loss: {np.mean(val_loss)}")
print(f"Validation Precision: {np.mean(val_precision)}")

Validation: 100%|██████████| 673/673 [13:52<00:00,  1.24s/it]

Validation Loss: 0.6248296253039968
Validation Precision: 0.7954707614620776





In [8]:
# Per-batch metrics on the test split; each list gets one entry per batch and
# is averaged (unweighted over batches) after the loop.
test_loss = []
test_precision = []

pretrained_model.eval()
with torch.no_grad():
    for step, batch in tqdm(enumerate(test_loader), desc="Test", total=len(test_loader)):
        # Move tensors to the target device; non-tensor items pass through.
        batch = tuple(
            item.to(device) if isinstance(item, torch.Tensor) else item
            for item in batch
        )
        (
            input_ids, attention_mask, age_ids, gender_ids,
            value_ids, unit_ids, time_ids, continuous_ids,
            position_ids, token_type_ids, task_token, labels,
        ) = batch

        # Prepend three placeholder labels (value 1) to every row, so the
        # labels handed to the model are three positions longer than the
        # dataset labels - presumably matching extra tokens the model adds
        # in front of the sequence (TODO confirm against the model).
        batch_size = labels.size(0)
        additional_tokens = torch.ones((batch_size, 3), dtype=torch.long, device=device)
        labels = torch.cat([additional_tokens, labels], dim=1)

        # The forward pass runs without autocast.
        model_inputs = dict(
            input_ids=input_ids,
            value_ids=value_ids,
            unit_ids=unit_ids,
            time_ids=time_ids,
            continuous_ids=continuous_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            age_ids=age_ids,
            gender_ids=gender_ids,
            task_token=task_token,
            attention_mask=attention_mask,
            global_attention_mask=None,
            labels=labels,
            return_dict=True,
        )
        outputs = pretrained_model(**model_inputs)

        # Cut the three leading positions off again before scoring.
        # NOTE(review): compute_mlm_loss and calculate_mlm_precision each
        # skip the first three positions of what they receive, so three
        # further positions are left out of the metrics - confirm intended.
        prediction_scores = outputs.logits[:, 3:, :]
        labels = labels[:, 3:]

        loss = compute_mlm_loss(prediction_scores, labels)
        precision = calculate_mlm_precision(prediction_scores, labels)
        test_loss.append(loss.item())
        test_precision.append(precision)

print(f"Test Loss: {np.mean(test_loss)}")
print(f"Test Precision: {np.mean(test_precision)}")

Test: 100%|██████████| 673/673 [13:54<00:00,  1.24s/it]

Test Loss: 0.6274875504743823
Test Precision: 0.7938126986898193



