In [None]:
from glob import glob
import torch
from torch import nn
import time
import wandb
import pandas as pd
import numpy as np
from tqdm import tqdm
from IPython.display import display
from torch.utils.data import Dataset,DataLoader
from torch.cuda.amp import autocast, GradScaler
import warnings
from sklearn.metrics import roc_auc_score, log_loss
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import warnings; warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, confusion_matrix, roc_auc_score, precision_score, recall_score

from pylab import rcParams
# Notebook-wide matplotlib default: every figure is created 20 wide by 5 high
# (matplotlib measures figsize in inches) unless a plot overrides it.
rcParams['figure.figsize'] = 20,5

In [None]:
# Folder holding the per-split metadata and signal CSV exports. Change this one
# constant to point the notebook at another machine or data location.
DATA_DIR = 'C:/Users/sande/OneDrive/Documents/SKOLE/master/pre'

# For each split, *_meta.csv is read with its first column as the index and
# *_signal.csv is read as-is (its first column is skipped later by the
# preprocessor's transform step).
train_df = pd.read_csv(f'{DATA_DIR}/train2_meta.csv', index_col=0)
train_signal = pd.read_csv(f'{DATA_DIR}/train2_signal.csv')
valid_df = pd.read_csv(f'{DATA_DIR}/valid2_meta.csv', index_col=0)
valid_signal = pd.read_csv(f'{DATA_DIR}/valid2_signal.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test2_meta.csv', index_col=0)
test_signal = pd.read_csv(f'{DATA_DIR}/test2_signal.csv')

In [None]:
class PTBXLDatasetPreprocesser():
    """Fit-once / apply-many preprocessing for the signal and metadata tables.

    ``fit`` learns imputation means, a min-max scaler and label encoders from
    a metadata frame. ``transform`` applies them and reshapes the flat signal
    table into a 3-D array. ``save`` / ``load`` persist the fitted state so a
    preprocessor can be reused without refitting.
    """

    # Shape of one record after the reshape in ``transform``.
    # NOTE(review): presumably 1000 time steps x 12 channels per recording --
    # confirm against the code that produced the signal CSVs.
    SIGNAL_LEN = 1000
    N_CHANNELS = 12

    # Every attribute that ``fit`` sets and ``transform`` reads. ``save`` and
    # ``load`` both iterate over this tuple so they can never drift apart.
    _STATE_KEYS = (
        'class_cols',
        'meta_num_cols',
        'meta_num_means',
        'min_max_scaler',
        'meta_cat_cols',
        'cat_lablers',
    )

    def __init__(self):
        pass

    def save(self, filename):
        """Pickle the fitted state (see ``_STATE_KEYS``) to ``filename``.

        Raises ``AttributeError`` if called before ``fit`` or ``load``.
        """
        data = {key: getattr(self, key) for key in self._STATE_KEYS}
        pd.to_pickle(data, filename)

    def load(self, filename):
        """Restore the full fitted state written by ``save``.

        All saved fields are restored, so ``transform`` works on a freshly
        loaded instance without calling ``fit`` first.

        SECURITY: this unpickles ``filename``; only load files you created
        yourself, since unpickling untrusted data can execute arbitrary code.
        """
        data = pd.read_pickle(filename)
        for key in self._STATE_KEYS:
            setattr(self, key, data[key])

    def fit(self, x, y):
        """Learn preprocessing statistics from the metadata frame ``y``.

        Parameters
        ----------
        x : signal table. Accepted for symmetry with ``transform`` but not
            read here.
        y : metadata DataFrame containing at least the numeric columns
            ``age``, ``height``, ``weight`` and the categorical column ``sex``.

        Returns
        -------
        self, so the call can be chained.
        """
        self.class_cols = ['NORM', 'MI', 'STTC', 'CD', 'HYP']

        self.meta_num_cols = ['age', 'height', 'weight']
        # Work on a copy of just the numeric columns so the caller's frame is
        # left untouched.
        num = y[self.meta_num_cols].copy()
        self.meta_num_means = []
        for col in self.meta_num_cols:
            mean = num[col].mean()
            print(col, mean)
            num[col] = num[col].fillna(mean)
            # Keep the exact value used for imputation so ``transform`` fills
            # missing entries in other splits the same way.
            self.meta_num_means += [mean]

        self.min_max_scaler = MinMaxScaler().fit(num)

        self.meta_cat_cols = ['sex'] #, 'nurse', 'device']
        # Missing categories are encoded as the explicit label 'none'.
        self.cat_lablers = [LabelEncoder().fit(y[col].fillna('none').astype(str)) for col in self.meta_cat_cols]
        return self

    def transform(self, x, y):
        """Apply the fitted preprocessing to one split.

        Parameters
        ----------
        x : signal DataFrame; the first column is skipped and the remaining
            columns are reshaped to ``(-1, SIGNAL_LEN, N_CHANNELS)``.
        y : metadata DataFrame for the same split.

        Returns
        -------
        list
            ``[signal_array, numeric_meta_frame, categorical_meta_frame]``,
            followed by an integer class-target frame when ``y`` contains
            every column in ``class_cols``.
        """
        channel_cols = x.columns.tolist()[1:]

        signal = x[channel_cols].values.reshape(-1, self.SIGNAL_LEN, self.N_CHANNELS)
        print(signal.shape)
        ret = [signal] # signal

        y_ = y.copy()

        for col, mean in zip(self.meta_num_cols, self.meta_num_means):
            y_[col] = y_[col].fillna(mean)
        y_[self.meta_num_cols] = self.min_max_scaler.transform(y_[self.meta_num_cols])
        y_[self.meta_num_cols] = y_[self.meta_num_cols].clip(0., 1.) # prevent extreme value far from train set

        ret += [y_[self.meta_num_cols]] # meta num features

        # NOTE: scikit-learn's LabelEncoder.transform raises ValueError for a
        # label that was not present when the encoder was fitted.
        for col, labeler in zip(self.meta_cat_cols, self.cat_lablers):
            y_[col] = labeler.transform(y_[col].fillna('none').astype(str))

        ret += [y_[self.meta_cat_cols]] # meta cat features

        # Targets are appended only when every class column is available.
        if all(col in y.columns for col in self.class_cols):
            ret += [y[self.class_cols].fillna(0).astype(int)] # class targets

        return ret

In [None]:
# Learn the imputation / scaling / encoding statistics from the training
# metadata only; fit() returns the preprocessor itself.
data_preprocessor = PTBXLDatasetPreprocesser().fit(train_signal, train_df)

# Apply the same fitted preprocessor to every split. Each result is unpacked as
# (signal array, numeric meta, categorical meta, class targets).
# NOTE: the *_signal names are rebound from DataFrames to the reshaped arrays,
# so this cell cannot be re-run without reloading the CSVs first.
(train_signal,
 train_meta_num_feats,
 train_meta_cat_feats,
 train_class) = data_preprocessor.transform(train_signal, train_df)

(valid_signal,
 valid_meta_num_feats,
 valid_meta_cat_feats,
 valid_class) = data_preprocessor.transform(valid_signal, valid_df)

(test_signal,
 test_meta_num_feats,
 test_meta_cat_feats,
 test_class) = data_preprocessor.transform(test_signal, test_df)

In [None]:
class ECGDataset(Dataset):
    """Map-style dataset yielding one record's signal, metadata and labels.

    Parameters
    ----------
    signals : array indexed as ``signals[idx, :]``; its first axis is the
        record axis and defines ``len(dataset)``.
    num_metas, cat_metas : 2-D numeric / categorical metadata, one row per
        record. DataFrames and plain arrays are both accepted.
    class_labels : optional 2-D label table, one row per record. When omitted
        the items contain no label entry.
    """

    def __init__(self, signals, num_metas, cat_metas, class_labels=None):
        self.signals = signals
        self.num_metas = num_metas
        self.cat_metas = cat_metas
        self.class_labels = class_labels

        # Convert the tables to arrays once. Calling ``.values`` on a frame
        # inside __getitem__ may rebuild the whole array for every sample.
        self._num_values = self._to_array(num_metas)
        self._cat_values = self._to_array(cat_metas)
        self._label_values = None if class_labels is None else self._to_array(class_labels)

    @staticmethod
    def _to_array(table):
        """Return the ndarray behind a DataFrame, or ``np.asarray`` of anything else."""
        return table.values if hasattr(table, 'values') else np.asarray(table)

    def __len__(self):
        return self.signals.shape[0]

    def __getitem__(self, idx):
        """Return ``[signal, num_meta, cat_meta]`` plus ``labels`` when available."""
        ret = [
            self.signals[idx, :],
            self._num_values[idx, :],
            self._cat_values[idx, :],
        ]

        if self._label_values is not None:
            ret += [self._label_values[idx, :]]

        return ret
    
class LSTM(nn.Module):
    """Recurrent classifier over a signal sequence plus tabular metadata.

    Despite the class name, the recurrent layer in use is a bidirectional
    vanilla ``nn.RNN`` (stored as ``rnn2``). Its per-step outputs are mean- and
    max-pooled over the sequence axis, concatenated with one embedding per
    categorical column and the numeric metadata, then passed through a
    two-layer MLP that emits ``n_outs`` raw scores (no final activation).

    Parameters
    ----------
    signal_channel_size : feature size of each time step fed to the RNN.
    lstm_hidden_size : hidden size of the RNN, per direction.
    per_cat_nunique : number of distinct values for each categorical column.
    embed_size : embedding width used for every categorical column.
    num_size : number of numeric metadata features.
    hidden : width of the hidden dense layer.
    n_outs : number of output scores.
    """

    def __init__(self, signal_channel_size, lstm_hidden_size, per_cat_nunique, embed_size, num_size, hidden, n_outs):
        super().__init__()

        self.rnn2 = nn.RNN(signal_channel_size, lstm_hidden_size, batch_first=True, bidirectional=True)

        self.per_cat_nunique = per_cat_nunique
        self.embeds = nn.ModuleList(
            [nn.Embedding(cardinality, embed_size) for cardinality in self.per_cat_nunique]
        )

        # Two directions x two pooling ops (mean and max) -> 4 * hidden size.
        pooled_width = lstm_hidden_size*4
        embedded_width = embed_size*len(per_cat_nunique)
        self.dense1 = nn.Linear(pooled_width + embedded_width + num_size, hidden)
        self.relu = nn.ReLU()
        self.out = nn.Linear(hidden, n_outs)

    def forward(self, signal, num_meta, cat_meta):
        """Return raw scores of shape ``(batch, n_outs)``.

        ``signal`` is viewed as ``(batch, steps, -1)`` before the RNN and
        column ``i`` of ``cat_meta`` is looked up in the ``i``-th embedding.
        """
        n_batch, n_steps = signal.shape[0], signal.shape[1]
        seq_out, _ = self.rnn2(signal.view(n_batch, n_steps, -1))

        # Collapse the sequence axis in two complementary ways.
        mean_feats = seq_out.mean(dim=1)
        peak_feats = seq_out.max(dim=1).values

        embedded = torch.cat(
            [table(cat_meta[:, col].long()) for col, table in enumerate(self.embeds)],
            1,
        )

        features = torch.cat([mean_feats, peak_feats, embedded, num_meta], 1)
        return self.out(self.relu(self.dense1(features)))


    

def prepare_dataloader(signal, meta_num_feats, meta_cat_feats, targetclass,
                       batch_size=128, shuffle=True, num_workers=0):
    """Wrap one preprocessed split in an ``ECGDataset`` and a ``DataLoader``.

    Parameters
    ----------
    signal, meta_num_feats, meta_cat_feats : the signal array and the numeric /
        categorical metadata tables for the split.
    targetclass : class-label table passed to the dataset as ``class_labels``.
    batch_size : samples per batch (default 128, the previously fixed value).
    shuffle : reshuffle every epoch (default True, the previous behaviour).
        Pass ``False`` for validation / test loaders when a stable sample
        order is wanted.
    num_workers : loader worker processes (default 0 = load in the main
        process).

    Returns
    -------
    torch.utils.data.DataLoader
    """
    ds = ECGDataset(signal, meta_num_feats, meta_cat_feats, class_labels=targetclass)

    dl = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        pin_memory=False,
        drop_last=False,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return dl

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False):
    """Run one mixed-precision optimisation pass over ``train_loader``.

    NOTE: this function reads a module-level ``scaler`` (a ``GradScaler``)
    that is not passed in; it must be defined before the call.

    Parameters
    ----------
    epoch : epoch index, used only in the progress-bar text.
    model : called as ``model(signal, num_meta, cat_meta)``.
    loss_fn : called as ``loss_fn(preds, labels)`` with float labels.
    optimizer : optimizer stepped through ``scaler`` once per batch.
    train_loader : yields ``(signal, num_meta, cat_meta, class_labels)``.
    device : device every batch tensor is moved to.
    scheduler : optional LR scheduler.
    schd_batch_update : step the scheduler after every batch when True,
        otherwise once at the end of the epoch.

    Side effects: updates ``model`` / ``optimizer`` / ``scaler`` in place and
    logs every batch loss to wandb under the key ``'loss2'``.
    """
    model.train()

    running_loss = None

    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (signal, num_meta, cat_meta, class_labels) in pbar:
        signal = signal.to(device).float()
        num_meta = num_meta.to(device).float()
        cat_meta = cat_meta.to(device).long()
        # The loss expects float targets (same conversion as in validation).
        labels = class_labels.to(device).float()

        # Only the forward pass and the loss belong under autocast; the
        # PyTorch AMP guide recommends running backward() outside of it.
        with autocast():
            preds = model(signal, num_meta, cat_meta)
            loss = loss_fn(preds, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Read the scalar once: each .item() call forces a device sync.
        loss_value = loss.item()
        if running_loss is None:
            running_loss = loss_value
        else:
            # Exponential moving average, shown in the progress bar only.
            running_loss = running_loss * .99 + loss_value * .01

        wandb.log({ 'loss2': loss_value})
        if scheduler is not None and schd_batch_update:
            scheduler.step()

        description = f'epoch {epoch} loss: {running_loss:.4f}'

        pbar.set_description(description)

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on ``val_loader`` and log metrics to wandb.

    Parameters
    ----------
    epoch : epoch index, logged and shown in the progress bar.
    model : called as ``model(signal, num_meta, cat_meta)``.
    loss_fn : called as ``loss_fn(preds, labels)``. It also selects the
        decision threshold: 0.0 for ``nn.BCEWithLogitsLoss`` (the model then
        emits logits, and logit 0 is probability 0.5), 0.5 for any other loss.
    val_loader : yields ``(signal, num_meta, cat_meta, class_labels)``.
    device : device every batch tensor is moved to.
    scheduler : optional LR scheduler stepped once after evaluation.
    schd_loss_update : when True the scheduler is stepped with the mean
        validation loss, otherwise without arguments.

    Returns
    -------
    (targets_all, preds_all) : stacked numpy arrays of the float targets and
        the raw, un-thresholded model outputs, in loader order.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    preds_all = []
    targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    # Evaluation never needs gradients, whether or not the caller disabled them.
    with torch.no_grad():
        for step, (signal, num_meta, cat_meta, class_labels) in pbar:
            signal = signal.to(device).float()
            num_meta = num_meta.to(device).float()
            cat_meta = cat_meta.to(device).long()
            labels = class_labels.to(device).float()

            preds = model(signal, num_meta, cat_meta)

            preds_all += [preds.detach().cpu().numpy()]
            targets_all += [labels.detach().cpu().numpy()]

            loss = loss_fn(preds, labels)

            # Weight by batch size so a smaller last batch does not skew the mean.
            loss_sum += loss.item()*labels.shape[0]
            sample_num += labels.shape[0]

            description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
            pbar.set_description(description)

    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)
    mean_loss = loss_sum/sample_num

    # Logits must be cut at 0, not 0.5: sigmoid(0.5) is about 0.62, so the old
    # fixed 0.5 cut-off under-predicted every class when the model emits logits.
    threshold = 0.0 if isinstance(loss_fn, nn.BCEWithLogitsLoss) else 0.5
    binary_preds_all = (preds_all >= threshold)

    # Take the class count from the stacked targets, not from the loop variable.
    class_cnt = targets_all.shape[1]

    precision = precision_score(targets_all, binary_preds_all, average='macro')
    f1 = f1_score(targets_all, binary_preds_all, average='macro')
    recall = recall_score(targets_all, binary_preds_all, average='macro')
    # argmax reduces each multi-label row to its first positive class (index 0
    # when no class is positive), so this matrix is a single-label view.
    conf_mat = confusion_matrix(targets_all[:,:class_cnt].argmax(axis=1), binary_preds_all[:,:class_cnt].argmax(axis=1))

    wandb.log({'epoch': epoch, 'loss': mean_loss,
                'accuracy':(targets_all==binary_preds_all).mean(), 'auc-roc':roc_auc_score(targets_all, preds_all, average='macro'),
                  'precision': precision, 'F1 score': f1, 'recall': recall, 'Confusion matrix': conf_mat })

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(mean_loss)
        else:
            scheduler.step()

    return targets_all, preds_all

In [None]:
# Batch the preprocessed training and validation splits.
# NOTE: the training cell below builds both loaders again before it uses them.
train_loader = prepare_dataloader(
    signal=train_signal,
    meta_num_feats=train_meta_num_feats,
    meta_cat_feats=train_meta_cat_feats,
    targetclass=train_class,
)
val_loader = prepare_dataloader(
    signal=valid_signal,
    meta_num_feats=valid_meta_num_feats,
    meta_cat_feats=valid_meta_cat_feats,
    targetclass=valid_class,
)

In [None]:
if __name__ == '__main__':
    # Open the Weights & Biases run that the train / validation helpers log to.
    wandb.init(project="rnn2")

    # Batch the preprocessed training and validation splits.
    train_loader = prepare_dataloader(
        signal=train_signal,
        meta_num_feats=train_meta_num_feats,
        meta_cat_feats=train_meta_cat_feats,
        targetclass=train_class,
    )
    val_loader = prepare_dataloader(
        signal=valid_signal,
        meta_num_feats=valid_meta_num_feats,
        meta_cat_feats=valid_meta_cat_feats,
        targetclass=valid_class,
    )

    # Run settings: number of epochs and the StepLR period (in epochs).
    # NOTE(review): hidden_sie is never read below; the model is built with
    # the literal 128 for both of its hidden sizes.
    epochs, stepsize, hidden_sie = 1, 2, 10

    # Prefer the first GPU and fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    # One embedding table per categorical column, sized by the number of
    # classes its fitted label encoder knows.
    per_cat_nunique = [len(encoder.classes_) for encoder in data_preprocessor.cat_lablers]

    model = LSTM(
        signal_channel_size=train_signal.shape[2],
        lstm_hidden_size=128,
        per_cat_nunique=per_cat_nunique,
        embed_size=30,
        num_size=train_meta_num_feats.shape[1],
        hidden=128,
        n_outs=train_class.shape[1],
    )
    model = model.to(device)

    # train_one_epoch reads `scaler` as a global, so this name must not change.
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    # Multiply the learning rate by 0.1 every `stepsize` scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=stepsize, gamma=0.1)

    # Separate but identically configured loss objects for training and validation.
    loss_tr = nn.BCEWithLogitsLoss().to(device)
    loss_fn = nn.BCEWithLogitsLoss().to(device)

    for epoch in range(epochs):
        # The scheduler is stepped once per epoch inside train_one_epoch.
        train_one_epoch(
            epoch, model, loss_tr, optimizer, train_loader, device,
            scheduler=scheduler, schd_batch_update=False,
        )

        with torch.no_grad():
            # No scheduler is passed, so validation does not touch the learning rate.
            val_targets, val_preds = valid_one_epoch(
                epoch, model, loss_fn, val_loader, device,
                scheduler=None, schd_loss_update=False,
            )

    # Persist only the weights; loading requires rebuilding the same architecture.
    torch.save(model.state_dict(),'C:/Users/sande/OneDrive/Documents/SKOLE/master/pre/pytorch_ecg_rnn2.pth')

    # Drop the training objects and hand cached GPU memory back to the driver.
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()