In [1]:
# stdlib
from typing import Any, List, Tuple, Union

# third party
import numpy as np
import math, sys, argparse
import pandas as pd
import torch
from torch import nn
from functools import partial
import time, os, json
from utils import NativeScaler, MAEDataset, adjust_learning_rate, get_dataset
import MAE
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import sys
import timm.optim.optim_factory as optim_factory
from utils import get_args_parser
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from math import sqrt
from sklearn.datasets import load_iris
from tqdm import tqdm
eps = 1e-8
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, confusion_matrix
from math import sqrt
import os 

# Define the ReMasker imputer

In [2]:
class ReMaskerStep:
    """Masked-autoencoder imputer for tabular data with missing values.

    Wraps ``MAE.MaskedAutoencoder`` (or a caller-supplied model) with per-column
    min-max normalisation, a training loop with periodic per-column validation
    and checkpointing, and an imputation step (``transform``).

    Args:
        dim: Number of columns fed to the model (``rec_len``); normalisation is
            applied to the first ``dim`` columns of every input frame.
        mask_ratio: Masking ratio passed to the model during training.
        max_epochs: Number of training epochs.
        warmup_epochs: Warm-up epochs forwarded to ``adjust_learning_rate``.
        save_path: Directory for checkpoints and ``validation_results.csv``;
            defaults to ``./checkpoints_{mask_ratio}``. Created if missing.
        model: Optional pre-built model. When ``None`` a ``MAE.MaskedAutoencoder``
            is built from the hyper-parameters below.
        device: Torch device; defaults to CUDA when available, else CPU.
        weigths: Optional path to a ``state_dict`` file loaded into the model
            (misspelt name kept for backward compatibility).
        eps: Epsilon for the LayerNorm layers and for the min-max denominator.
        normalize: Unused; kept for interface compatibility.
        nan: Value substituted for NaNs before they reach the model.
        batch_size, accum_iter, min_lr, weight_decay, lr, blr: Optimisation
            settings; when ``lr`` is ``None`` it is derived from ``blr`` in ``fit``.
        norm_field_loss, embed_dim, depth, decoder_depth, num_heads, mlp_ratio,
        encode_func: Forwarded to ``MAE.MaskedAutoencoder``.
        **kwargs: Ignored.
    """

    def __init__(self, dim=16, mask_ratio=0.5, max_epochs=300, warmup_epochs=20, save_path=None,
                 model=None, device=None, weigths=None, eps=1e-7, normalize=True, nan=-1,
                 batch_size=64, accum_iter=1, min_lr=1e-5, norm_field_loss=False,
                 weight_decay=0.05, lr=None, blr=1e-3, embed_dim=32, depth=6,
                 decoder_depth=4, num_heads=4, mlp_ratio=4.0, encode_func='linear', **kwargs):
        # Optimisation settings
        self.batch_size = batch_size
        self.accum_iter = accum_iter
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.lr = lr
        self.blr = blr
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.mask_ratio = mask_ratio

        # Architecture settings
        self.dim = dim
        self.embed_dim = embed_dim
        self.depth = depth
        self.decoder_depth = decoder_depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.encode_func = encode_func
        self.norm_field_loss = norm_field_loss

        # Honour the caller's eps / weigths instead of discarding them.
        self.eps = eps
        self.weigths = weigths
        self.nan = nan

        self.save_path = save_path if save_path else f'./checkpoints_{self.mask_ratio}'
        # Create the *resolved* directory; the raw ``save_path`` argument may be None.
        os.makedirs(self.save_path, exist_ok=True)

        if device:
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ``is None`` rather than truthiness: some nn.Module containers define __len__.
        if model is None:
            self.model = MAE.MaskedAutoencoder(
                rec_len=self.dim,
                embed_dim=self.embed_dim,
                depth=self.depth,
                num_heads=self.num_heads,
                decoder_embed_dim=self.embed_dim,
                decoder_depth=self.decoder_depth,
                decoder_num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                norm_layer=partial(nn.LayerNorm, eps=self.eps),
                norm_field_loss=self.norm_field_loss,
                encode_func=self.encode_func
            )
        else:
            self.model = model

        # Load a checkpoint if one was given. map_location lets checkpoints saved
        # on a GPU be restored on a CPU-only host; the model is moved below.
        if weigths:
            if os.path.exists(weigths):
                self.model.load_state_dict(torch.load(weigths, map_location="cpu"))
            else:
                print(f"Checkpoint '{weigths}' not found; using freshly initialised weights.")

        # NOTE(review): nn.DataParallel is deliberately not applied. The forward pass
        # returns a scalar loss (``loss.item()`` in ``fit``), which DataParallel would
        # gather into a per-replica vector, and wrapping would also prefix checkpoint
        # keys with "module.".
        self.model.to(self.device)

        # Per-column min/max; computed lazily by the first ``normalize`` call.
        self.norm_parameters = None

    def calculate_norm_parameters(self, X: pd.DataFrame):
        """Store the NaN-ignoring min and max of the first ``self.dim`` columns of ``X``."""
        values = X.iloc[:, :self.dim].to_numpy(dtype=float)
        self.norm_parameters = {
            "min": np.nanmin(values, axis=0),
            "max": np.nanmax(values, axis=0),
        }

    def normalize(self, X_raw: pd.DataFrame, return_format='torch'):
        """Min-max scale the first ``self.dim`` columns of a copy of ``X_raw``.

        Normalisation parameters are computed from ``X_raw`` on the first call and
        reused afterwards. NaNs are preserved.

        Args:
            X_raw: Input DataFrame (not modified).
            return_format: ``'torch'`` (float32 tensor), ``'numpy'`` (ndarray) or
                anything else for the scaled DataFrame.
        """
        X = X_raw.copy()

        if not self.norm_parameters:
            print('calculating norm parameters...')
            self.calculate_norm_parameters(X)

        min_val = self.norm_parameters["min"]
        max_val = self.norm_parameters["max"]

        for i in range(self.dim):
            # eps keeps constant columns from dividing by zero.
            X.iloc[:, i] = (X.iloc[:, i] - min_val[i]) / (max_val[i] - min_val[i] + self.eps)

        if return_format == 'numpy':
            return X.to_numpy()
        if return_format == 'torch':
            return torch.tensor(X.to_numpy(), dtype=torch.float32)
        return X

    def denormalize(self, imputed_data):
        """Invert ``normalize`` in place on the first ``self.dim`` columns and return the input."""
        min_val = self.norm_parameters["min"]
        max_val = self.norm_parameters["max"]

        for i in range(self.dim):
            imputed_data[:, i] = imputed_data[:, i] * (max_val[i] - min_val[i] + self.eps) + min_val[i]

        return imputed_data

    def fit(self, X_raw: pd.DataFrame, X_val=None, exclude_columns=None, val_interval=30):
        """Train the masked autoencoder on ``X_raw``.

        Args:
            X_raw: Training DataFrame; NaNs mark missing entries.
            X_val: Optional validation DataFrame. Every ``val_interval`` epochs, and
                always at the final epoch, each non-excluded column is masked in
                turn, imputed and scored. Scores are appended to
                ``<save_path>/validation_results.csv``. Intermediate epochs use the
                first 10,000 rows; the final epoch uses all rows.
            exclude_columns: Column positions forwarded to the model and skipped
                during validation. ``None`` means no exclusions.
            val_interval: Validation frequency in epochs.

        Returns:
            self
        """
        X = self.normalize(X_raw)

        # Observation mask: 1.0 where a value is present, 0.0 where it is NaN.
        M = (~torch.isnan(X)).float().to(self.device)
        X = torch.nan_to_num(X, nan=self.nan).to(self.device)

        # Derive the absolute lr from the base lr and effective batch size.
        eff_batch_size = self.batch_size * self.accum_iter
        if self.lr is None:
            self.lr = self.blr * eff_batch_size / 64

        # NOTE(review): self.weight_decay is not passed here, so AdamW uses torch's
        # default weight decay for every parameter.
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.95))
        loss_scaler = NativeScaler()

        dataset = MAEDataset(X, M)
        dataloader = DataLoader(
            dataset, sampler=RandomSampler(dataset),
            batch_size=self.batch_size,
        )

        results_csv_path = os.path.join(self.save_path, 'validation_results.csv')

        for epoch in range(self.max_epochs):
            self.model.train()
            print(epoch)
            optimizer.zero_grad()
            total_loss = 0
            step = 0

            for step, (samples, masks) in tqdm(enumerate(dataloader), total=len(dataloader)):
                # Per-iteration (rather than per-epoch) learning-rate schedule.
                if step % self.accum_iter == 0:
                    adjust_learning_rate(optimizer, step / len(dataloader) + epoch, self.lr, self.min_lr,
                                         self.max_epochs, self.warmup_epochs)

                # Add a channel dimension and move the batch to the device.
                samples = samples.unsqueeze(dim=1).to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)

                with torch.cuda.amp.autocast():
                    loss, _, _, _ = self.model(samples, masks, mask_ratio=self.mask_ratio,
                                               exclude_columns=exclude_columns)
                    loss_value = loss.item()
                    total_loss += loss_value

                if not math.isfinite(loss_value):
                    print("Loss is {}, stopping training".format(loss_value))
                    sys.exit(1)

                loss /= self.accum_iter

                # Gradients accumulate; parameters step every accum_iter batches.
                update = (step + 1) % self.accum_iter == 0
                loss_scaler(loss, optimizer, parameters=self.model.parameters(),
                            update_grad=update)
                if update:
                    optimizer.zero_grad()

            # Reported training loss: square root of the mean batch loss.
            total_loss = (total_loss / (step + 1)) ** 0.5

            self.model.eval()
            is_last_epoch = epoch == self.max_epochs - 1
            # The final epoch is always validated so the full-set evaluation is reachable.
            if (epoch % val_interval == 0 or is_last_epoch) and X_val is not None and not X_val.empty:
                X_test = X_val if is_last_epoch else X_val.iloc[:10000]
                self._validate_epoch(X_test, epoch, exclude_columns, results_csv_path)

            if (epoch + 1) % 10 == 0 or epoch == 0:
                print((epoch + 1), ',', total_loss)
                os.makedirs(self.save_path, exist_ok=True)
                torch.save(self.model.state_dict(),
                           os.path.join(self.save_path, f'epoch{epoch+1}_checkpoint'))

        return self

    def _validate_epoch(self, X_test, epoch, exclude_columns, results_csv_path):
        """Mask each non-excluded column of ``X_test`` in turn, impute it, and log RMSE/MAE/R2."""
        skip = set(exclude_columns or ())
        epoch_validation_results = []

        for column, column_name in enumerate(X_test.columns):
            if column in skip:
                continue

            # Only rows where this column is observed can be scored (positional lookup).
            X_test_real = X_test[X_test.iloc[:, column].notna()]
            if len(X_test_real) < 1:
                print(f'The sampling size of test with in column: {column}, is only {len(X_test_real)}')
                continue

            X_test_masked = X_test_real.copy()
            X_test_masked.iloc[:, column] = np.nan
            imputed = self.transform(X_test_masked).cpu().numpy()

            # Compare row-aligned arrays directly instead of relying on two dropna() calls.
            y_true = X_test_real.iloc[:, column].to_numpy()
            y_pred = imputed[:, column]
            rmse = sqrt(mean_squared_error(y_true, y_pred))
            mae = mean_absolute_error(y_true, y_pred)
            r2 = r2_score(y_true, y_pred)

            output_str = f"Epoch{epoch} Evaluation for {column_name}: RMSE = {rmse}, MAE = {mae}, R2 = {r2}\n"
            print(output_str)

            epoch_validation_results.append({
                'Epoch': epoch,
                'Column': column_name,
                'RMSE': rmse,
                'MAE': mae,
                'R2': r2
            })

        # Nothing scored: avoid creating a header-less/empty CSV.
        if not epoch_validation_results:
            return

        results_df = pd.DataFrame(epoch_validation_results)
        write_header = not os.path.exists(results_csv_path)
        results_df.to_csv(results_csv_path, mode='a', header=write_header, index=False)

    def transform(self, X_raw: pd.DataFrame, eval_batch_size=None):
        """Impute the missing values of ``X_raw``.

        Args:
            X_raw: DataFrame (or tensor) with NaNs marking missing entries. The
                normalisation parameters from the first ``normalize`` call are reused.
            eval_batch_size: Inference batch size; defaults to ``self.batch_size``.

        Returns:
            CPU tensor where observed entries are copied from ``X_raw`` and missing
            entries are the model's de-normalised predictions.

        Raises:
            RuntimeError: if every predicted value is NaN.
        """
        X = self.normalize(X_raw)

        M = (~torch.isnan(X)).float().to(self.device)
        X = torch.nan_to_num(X, nan=self.nan).to(self.device)

        dataset = MAEDataset(X, M)
        dataloader = DataLoader(
            dataset, sampler=SequentialSampler(dataset),
            batch_size=eval_batch_size if eval_batch_size else self.batch_size,
            drop_last=False
        )

        self.model.eval()

        imputed_data_list = []
        with torch.no_grad():
            for sample, mask in dataloader:
                # Assign the results: Tensor.to() is not in-place.
                sample = sample.unsqueeze(1).to(self.device)
                mask = mask.to(self.device)
                _, pred, _, _ = self.model(sample, mask)
                imputed_data_list.append(pred.squeeze(dim=2))

        imputed_data = self.denormalize(torch.cat(imputed_data_list, 0))

        # NOTE(review): only raises when *every* value is NaN, despite the wording.
        if np.all(np.isnan(imputed_data.detach().cpu().numpy())):
            err = "The imputed result contains nan. This is a bug. Please report it on the issue tracker."
            raise RuntimeError(err)

        M = M.cpu()
        imputed_data = imputed_data.detach().cpu()

        observed = X_raw.detach().cpu() if torch.is_tensor(X_raw) else torch.tensor(X_raw.values)

        # Keep observed values, fill missing ones with the model's predictions.
        return M * torch.nan_to_num(observed) + (1 - M) * imputed_data

    def fit_transform(self, X) -> np.ndarray:
        """Fit the imputer on ``X`` and return the imputed data.

        Args:
            X: DataFrame with NaNs marking missing values. Arrays and tensors are
               converted to a DataFrame first.

        Returns:
            ``np.ndarray`` with observed values kept and missing values imputed.
        """
        if isinstance(X, pd.DataFrame):
            X_df = X
        else:
            if torch.is_tensor(X):
                X = X.detach().cpu().numpy()
            X_df = pd.DataFrame(np.asarray(X, dtype=float))
        # normalize() needs a DataFrame (.copy / .iloc), so pass the frame, not a tensor.
        return self.fit(X_df).transform(X_df).detach().cpu().numpy()


### Read dataset

In [3]:
# Location of the lab-values CSV (no header row). Set the LAB_VALUES_CSV
# environment variable to point at a different copy instead of editing the notebook.
DATA_PATH = os.environ.get('LAB_VALUES_CSV', '/scratch/liyues_root/liyues/chenweiw/lab_values/10_labs.csv')
df = pd.read_csv(DATA_PATH, header=None)
print(df.shape)
df.head()

(1582939, 40)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,30,31,32,33,34,35,36,37,38,39
0,3.5,18.0,137.0,18.0,103.0,18.0,26.5,18.0,1.1,18.0,...,,,,,,,,,,
1,3.7,18.0,136.0,18.0,100.0,18.0,28.1,18.0,1.2,18.0,...,24.0,42.0,14.0,42.0,19.0,42.0,169.0,42.0,183.0,42.0
2,3.9,17.0,138.0,17.0,102.0,17.0,24.4,14.0,0.9,17.0,...,28.0,42.0,12.0,42.0,17.0,42.0,190.0,42.0,195.0,42.0
3,5.2,17.0,140.0,17.0,104.0,17.0,32.1,17.0,1.1,17.0,...,,,,,,,,,,
4,4.8,17.0,141.0,17.0,106.0,17.0,30.3,17.0,2.3,17.0,...,,,,,,,,,,


### Split the DataFrame into training and validation sets

In [4]:
# Hold out 20% of the rows (fixed seed) as the validation set.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

print("Train shape:", train_df.shape)
print("Test shape:", test_df.shape)

# Odd column positions (1, 3, 5, ...) are the time columns; they are passed to
# fit() as exclude_columns.
time_column_indexes = list(range(1, df.shape[1], 2))
print("Odd column indexes:", time_column_indexes)

Train shape: (1266351, 40)
Test shape: (316588, 40)
Odd column indexes: [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39]


### Model

In [5]:
# Training configuration
save_path = '10_Labs_Train'   # checkpoints and validation_results.csv go here
max_epochs = 300
mask_ratio = 0.5
columns = len(df.columns)     # model input width = number of CSV columns

In [6]:
imputer = ReMaskerStep(
    save_path=save_path,
    dim=columns,
    max_epochs=max_epochs,
    mask_ratio=mask_ratio,
)

### Train the model

In [None]:
imputer.fit(
    train_df,
    X_val=test_df,
    exclude_columns=time_column_indexes,
)

calculating norm parameters...
0


100%|██████████| 19787/19787 [06:35<00:00, 50.05it/s]


Epoch0 Evaluation for 0: RMSE = 0.83354032564747, MAE = 0.6551566178101119, R2 = -1.3645392779018364

Epoch0 Evaluation for 2: RMSE = 11.066361753364026, MAE = 8.973921206874094, R2 = -5.813309195725753

Epoch0 Evaluation for 4: RMSE = 9.015458321504338, MAE = 7.060514823205823, R2 = -1.835572227726325

Epoch0 Evaluation for 6: RMSE = 10.037122861733545, MAE = 8.239876821650078, R2 = -1.567106171220817

Epoch0 Evaluation for 8: RMSE = 1.5252524592524044, MAE = 0.7960259334486625, R2 = -0.0559097376409341

Epoch0 Evaluation for 10: RMSE = 5.044840142172688, MAE = 3.8950325778029473, R2 = -0.37941840295138984

Epoch0 Evaluation for 12: RMSE = 3.856253420083287, MAE = 2.951699946240759, R2 = -0.26584314717289304

Epoch0 Evaluation for 14: RMSE = 19.784864575566676, MAE = 12.96078433981272, R2 = -0.02099831796584728

Epoch0 Evaluation for 16: RMSE = 61.76746030012727, MAE = 40.27490646904438, R2 = -0.3062642741601356

Epoch0 Evaluation for 18: RMSE = 145.81379241573043, MAE = 114.104917200

100%|██████████| 19787/19787 [06:20<00:00, 51.98it/s]


2


100%|██████████| 19787/19787 [06:21<00:00, 51.83it/s]


3


100%|██████████| 19787/19787 [06:20<00:00, 51.95it/s]


4


100%|██████████| 19787/19787 [06:14<00:00, 52.82it/s]


5


100%|██████████| 19787/19787 [06:17<00:00, 52.42it/s]


6


100%|██████████| 19787/19787 [06:21<00:00, 51.86it/s]


7


100%|██████████| 19787/19787 [06:02<00:00, 54.60it/s]


8


100%|██████████| 19787/19787 [06:02<00:00, 54.56it/s]


9


100%|██████████| 19787/19787 [06:02<00:00, 54.58it/s]


10 , 0.4672753982352495
10


100%|██████████| 19787/19787 [06:02<00:00, 54.57it/s]


11


100%|██████████| 19787/19787 [06:02<00:00, 54.58it/s]


12


100%|██████████| 19787/19787 [06:02<00:00, 54.55it/s]


13


100%|██████████| 19787/19787 [06:02<00:00, 54.55it/s]


14


100%|██████████| 19787/19787 [06:02<00:00, 54.56it/s]


15


100%|██████████| 19787/19787 [06:02<00:00, 54.53it/s]


16


100%|██████████| 19787/19787 [06:02<00:00, 54.54it/s]


17


100%|██████████| 19787/19787 [06:02<00:00, 54.53it/s]


18


100%|██████████| 19787/19787 [06:02<00:00, 54.54it/s]


19


100%|██████████| 19787/19787 [06:02<00:00, 54.54it/s]


20 , 0.46725921683585364
20


100%|██████████| 19787/19787 [06:02<00:00, 54.56it/s]


21


100%|██████████| 19787/19787 [06:02<00:00, 54.57it/s]


22


100%|██████████| 19787/19787 [06:02<00:00, 54.58it/s]


23


100%|██████████| 19787/19787 [06:02<00:00, 54.52it/s]


24


100%|██████████| 19787/19787 [06:02<00:00, 54.53it/s]


25


100%|██████████| 19787/19787 [06:02<00:00, 54.53it/s]


26


100%|██████████| 19787/19787 [06:02<00:00, 54.54it/s]


27


 24%|██▎       | 4666/19787 [01:25<04:36, 54.60it/s]