# AAECG

Training procedure for AAECG, an Adversarial Autoencoder adapted to recognize irregular heartbeats.

## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import os
from tqdm import tqdm
import pickle
import multiprocessing
from torch.utils.data.dataset import Dataset
import glob
import random

## Some utility functions

In [None]:
# Number of CPUs reported by the OS; passed as `n_jobs` to joblib's `Parallel` further down.
num_cores = multiprocessing.cpu_count()

def get_beat_info(beat_path):
    """Load the metadata dictionary stored next to a beat file.

    Args:
        beat_path: path of the beat WITHOUT extension; ``".pkl"`` is appended.

    Returns:
        The unpickled metadata with the ``'medications'`` entry removed.
        A record that has no such entry is returned unchanged instead of
        raising ``KeyError``.

    Raises:
        FileNotFoundError: if ``beat_path + ".pkl"`` does not exist.
    """
    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file, so this must only be used on metadata files you trust.
    with open(beat_path+".pkl", "rb") as pkl_handle:
        res = pickle.load(pkl_handle)
    # Drop the medications field; tolerate records that never had it.
    res.pop('medications', None)
    return res
    
def plot_some_beat(beat_list, titles_list = None, suptitle = None, savefig = False, file= None):
    """Plot a set of beats side by side, one column per beat and one row per lead.

    Args:
        beat_list: indexable collection of beats; ``beat_list[i][j]`` is the
            signal plotted in row ``j`` of column ``i``.
        titles_list: optional per-column titles (set on the top row).
        suptitle: optional figure title.
        savefig: when true, the figure is written to ``file`` at 300 dpi.
        file: output path used when ``savefig`` is true.

    Returns:
        The axes exactly as ``plt.subplots(nl, n)`` returns them with its
        default squeezing: a single ``Axes``, a 1-D array or a 2-D array.
    """
    n = len(beat_list)
    nl = len(beat_list[0])
    # Ask for the unsqueezed (nl, n) grid so that indexing is uniform; the
    # default squeezing made the single-column / single-beat cases fail.
    f, grid = plt.subplots(nl, n, sharey=True, squeeze=False)
    if suptitle is not None:
        f.suptitle(suptitle)
    for i in range(n):
        if titles_list is not None:
            grid[0, i].set_title(titles_list[i])
        for j in range(nl):
            grid[j, i].plot(beat_list[i][j])
            grid[j, i].set_ylim(bottom = -1, top = 1)
    if savefig:
        f.savefig(file, dpi = 300)
    # Reproduce matplotlib's own squeezing so callers get the same object
    # shape as before.
    return grid.item() if grid.size == 1 else grid.squeeze()

# ECG dataset class which implements the PyTorch Dataset interface

In [None]:
class Ecgdataset(Dataset):
    """PyTorch ``Dataset`` of ECG beats stored as ``<name>.npy`` + ``<name>.pkl`` pairs.

    Every sample is a ``(beat, label, info)`` tuple where ``beat`` is the
    float tensor loaded from the ``.npy`` file, ``label`` is the one-hot
    encoded sex and ``info`` is the metadata dictionary from the ``.pkl``
    file (see ``get_beat_info``).

    The dataset can also wrap an already loaded list of samples (``db``).
    """

    def __init__(self, folder_path, db = None, upload_on_ram = False, batch_size = 500, shuffle = True):
        """Build the dataset from a folder or from preloaded samples.

        Args:
            folder_path: directory containing the ``.npy`` beats (ignored
                when ``db`` is given).
            db: optional list of preloaded samples; it is used as is and
                shuffled IN PLACE when ``shuffle`` is true.
            upload_on_ram: load every sample in memory up front (in parallel).
            batch_size: joblib task batch size used while loading in memory.
            shuffle: shuffle the in-memory samples with ``random.shuffle``.
        """
        # Always defined, so get_db() returns None instead of raising
        # AttributeError when nothing has been cached.
        self.db = None

        if db is not None:
            self.upload_on_ram = True
            self.db = db
            if shuffle:
                random.shuffle(self.db)
            self.data_len = len(db)
            return

        self.folder_path = folder_path
        # Get ecg list. os.path.join makes a missing trailing separator
        # harmless; sorting makes the order (and therefore the seeded
        # shuffle below) reproducible, since glob order is arbitrary.
        self.beat_list = sorted(glob.glob(os.path.join(folder_path, '*.npy')))

        # Calculate len
        self.data_len = len(self.beat_list)

        # Upload on Ram to speed up (use only if your hardware has enough memory)
        self.upload_on_ram = upload_on_ram

        if self.upload_on_ram:
            print("\n uploading "+folder_path+" on RAM")
            self.db = Parallel(n_jobs=-1, verbose = 1, backend="multiprocessing", batch_size= batch_size)(delayed(self.load_index)(i)
                           for i in range(self.data_len))
            if shuffle:
                random.shuffle(self.db)

    def get_db(self):
        """Return the in-memory sample list, or ``None`` if samples are read lazily."""
        return self.db

    def encode_sex(self, sex):
        """One-hot encode the sex as ``[F, M, other/unknown]`` (float32 tensor)."""
        res = [0, 0, 1]
        if sex == 'F':
            res = [1, 0, 0]
        elif sex == 'M':
            res = [0, 1, 0]
        return torch.tensor(res, dtype = torch.float32)

    def load_index(self, index):
        """Read sample ``index`` from disk and return ``(beat, label, info)``."""
        beat_path = self.beat_list[index]

        # Get ecg data
        beat = torch.from_numpy(np.load(beat_path)).float()

        # Get ecg labels. The file name is of the type Rxxx_xxx.npy.
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. "./data/") no longer truncate it.
        info = get_beat_info(os.path.splitext(beat_path)[0])

        label = self.encode_sex(info['sex'])

        return (beat, label, info)

    def __getitem__(self, index):
        if self.upload_on_ram:
            return self.db[index]
        return self.load_index(index)

    def __len__(self):
        return self.data_len

# AAECG class

In [None]:
import torch.nn as nn
import torch
import torch.optim as optim
import numpy as np
from statsmodels.distributions.empirical_distribution import ECDF
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, precision_recall_curve, fbeta_score
from tqdm import tqdm
num_cores = multiprocessing.cpu_count()
import pickle
from torch.optim.lr_scheduler import ReduceLROnPlateau


class CBeatAAE:
    """Conditional adversarial autoencoder used as a beat-level ECG anomaly detector.

    Three networks are trained together:

    * ``Encoder``: maps a beat of shape (batch, L, N) to a sampled latent
      code of size ``nz``;
    * ``Decoder``: rebuilds the beat from the latent code concatenated with
      a conditioning vector of size ``nc``;
    * ``Discriminator_prior``: a critic on the latent space, trained with a
      Wasserstein loss plus gradient penalty against a standard-normal prior.

    The anomaly score of a beat is its mean squared reconstruction error;
    scores ``>= self.threeshold`` are reported as abnormal.

    NOTE: no method of this class switches the networks to ``eval()`` mode;
    callers who want BatchNorm to use running statistics at test time must
    do so themselves.
    """

    #ENCODER
    class Encoder(nn.Module):
        """Convolutional encoder producing a reparametrised latent sample."""

        def __init__(self, ngpu, L, N, nef, nz):
            super(CBeatAAE.Encoder, self).__init__()
            self.ngpu = ngpu
            self.nz = nz
            # Temporal sizes in the comments are for the default N = 280.
            self.conv = nn.Sequential(
                # input: L x N
                nn.Conv1d( L, nef, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (nef) x 140
                nn.Conv1d( nef, nef * 2, 4, 2, 1, bias=False),
                nn.BatchNorm1d(nef*2),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (nef*2) x 70
                nn.Conv1d(nef * 2, nef * 4, 4, 2, 1, bias=False),
                nn.BatchNorm1d(nef*4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (nef*4) x 35
                nn.Conv1d( nef * 4, nef * 8, 3, 2, 0, bias=False),
                nn.BatchNorm1d(nef*8),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (nef*8) x 17
                nn.Conv1d( nef * 8, nef * 16, 3, 2, 0, bias=False),
                nn.BatchNorm1d(nef*16),
                nn.LeakyReLU(0.2, inplace=True)
                # state size. (nef*16) x 8
                )
            # Two heads with a kernel spanning the 8 remaining time steps.
            self.mu = nn.Sequential(
                nn.Conv1d( nef * 16, nz, 8, 1, 0),
                # state size. nz x 1
                nn.Flatten()
                )
            self.logvar = nn.Sequential(
                nn.Conv1d( nef * 16, nz, 8, 1, 0),
                # state size. nz x 1
                nn.Flatten()
                )

        def reparametrization(self, mu, logvar):
            """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
            std = torch.exp(logvar/2)
            sampled_z = torch.randn((mu.size(0), self.nz), device = mu.device)
            z = sampled_z * std + mu

            return z

        def features(self, x):
            """Return the convolutional feature map (before the two heads)."""
            return self.conv(x)

        def forward(self, x):
            """Encode ``x`` and return a stochastic latent sample of shape (batch, nz)."""
            f = self.conv(x)
            return self.reparametrization(self.mu(f), self.logvar(f))

        def variance(self, x):
            """Return ``exp(logvar / 2)`` for ``x`` (i.e. the standard deviation)."""
            f = self.conv(x)
            std = torch.exp(self.logvar(f)/2)
            return std

    #DECODER

    class Decoder(nn.Module):
        """Transposed-convolution decoder conditioned on a label vector."""

        def __init__(self, ngpu, nef, nz, nc, L):
            super(CBeatAAE.Decoder, self).__init__()
            self.ngpu = ngpu
            # Comments give the state AFTER each transposed convolution.
            self.upconv = nn.Sequential(
                # input: (nz + nc) x 1
                nn.ConvTranspose1d( nz + nc, nef * 16, 8, 1, 0, bias=False),
                nn.BatchNorm1d(nef * 16),
                nn.ReLU(True),
                # state size. (nef*16) x 8
                nn.ConvTranspose1d(nef * 16, nef * 8, 3, 2, 0, bias=False),
                nn.BatchNorm1d(nef * 8),
                nn.ReLU(True),
                # state size. (nef*8) x 17
                nn.ConvTranspose1d( nef * 8, nef * 4, 3, 2, 0, bias=False),
                nn.BatchNorm1d(nef * 4),
                nn.ReLU(True),
                # state size. (nef*4) x 35
                nn.ConvTranspose1d( nef * 4, nef * 2, 4, 2, 1, bias=False),
                nn.BatchNorm1d(nef*2),
                nn.ReLU(True),
                # state size. (nef*2) x 70
                nn.ConvTranspose1d( nef * 2, nef, 4, 2, 1, bias=False),
                nn.BatchNorm1d(nef),
                nn.ReLU(True),
                # state size. (nef) x 140
                nn.ConvTranspose1d( nef, L, 4, 2, 1, bias=False),
                # state size. (L) x 280
                nn.Tanh()
            )

        def forward(self, z, labels):
            """Decode ``z`` (batch, nz) conditioned on ``labels`` (batch, nc)."""
            l = torch.cat((z,labels), dim = 1)
            return self.upconv(l.unsqueeze(2))


    # DISCRIMINATOR

    class Discriminator_prior(nn.Module):
        """MLP critic on latent codes; outputs one unbounded score per sample."""

        def __init__(self, ngpu, nz, ndf):
            super(CBeatAAE.Discriminator_prior, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                nn.Linear(in_features=(nz), out_features=ndf),
                nn.LeakyReLU(0.2),

                nn.Linear(in_features=ndf, out_features=ndf//2),
                nn.LeakyReLU(0.2),

                nn.Linear(in_features=ndf//2, out_features=ndf//4),
                nn.LeakyReLU(0.2),

                nn.Linear(in_features=ndf//4, out_features=1)
            )

        def forward(self, input):
            return self.main(input)


    def __init__(self, device = torch.device('cpu'), N=280, L=1, nz=3,nc=3,nef=32,ndf=32, ngpu = 0):
        """Create the three networks, move them to ``device`` and initialise them.

        Args:
            device: torch device hosting the networks.
            N: length of a beat (samples).
            L: number of leads.
            nz: size of the latent vector.
            nc: size of the conditioning vector (one-hot encoded sex).
            nef: base number of feature maps in encoder/decoder.
            ndf: width of the first discriminator layer.
            ngpu: stored on the networks; not otherwise used by this class.
        """
        # Lenght of a beat.
        self.N = N

        # Number of leads
        self.L = L

        # Size of z latent vector
        self.nz = nz

        # Size of additional information vector, one hot encoded sex value
        self.nc = nc

        # Size of feature maps in encoder
        self.nef = nef

        # Size of feature maps in discriminator
        self.ndf = ndf

        self.device = device

        self.netE = self.Encoder(ngpu, L, N,nef, nz)
        self.netD = self.Decoder(ngpu, nef, nz, nc, L)
        self.netDis_prior = self.Discriminator_prior(ngpu, nz, ndf)

        self.netE.to(device)
        self.netD.to(device)
        self.netDis_prior.to(device)

        # Initialize
        self.initialize()

        self.trained = False

        # Initialize the threeshold
        self.threeshold = 0

        # Covariance
        self.invSigma = np.diag(np.ones(self.N*self.L))

    def initialize(self):
        """(Re-)initialise the weights of the three networks."""
        self.netE.apply(self.weights_init)
        self.netD.apply(self.weights_init)
        self.netDis_prior.apply(self.weights_init)

    def sample_noise(self, b_s, uniform_noise = False):
        """Draw ``b_s`` prior samples of size ``nz`` (standard normal, or uniform in (-1, 1])."""
        if uniform_noise:
            # Uniform Distribution between -1 and 1
            a = -1
            b = 1
            noise = (a - b) * torch.rand((b_s, self.nz), device=self.device, dtype=torch.float) + b
        else:
            noise = torch.randn((b_s, self.nz), device = self.device, dtype = torch.float)
        return noise

    def weights_init(self, m):
        """DCGAN-style initialisation, dispatched on the module class name."""
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
        elif 'Linear' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def plot_error_mixing(self, errors, labels, best_thr, est_thr = None,):
        """Plot the normalised score histograms of normal/abnormal beats and the threshold(s)."""
        plt.figure()
        h = np.histogram(errors[np.where(labels == 0)])
        plt.plot(h[1][:-1], h[0]/np.sum(h[0]), label = 'normal')
        h = np.histogram(errors[np.where(labels == 1)])
        plt.plot( h[1][:-1], h[0]/np.sum(h[0]), label = 'abnormal')
        plt.ylim([0, 1.001])
        plt.xlim([0, np.max(errors)])
        if est_thr is None:
            plt.vlines(x = best_thr, ymin = 0, ymax = 1, colors = 'b', label = 'threeshold')
        else:
            plt.vlines(x = est_thr, ymin = 0, ymax = 1, colors = 'b', label = 'alpha-thr')
            plt.vlines(x = best_thr, ymin = 0, ymax = 1, colors = 'r', label = 'best f1-thr')
        plt.legend()
        plt.show()

    def get_threeshold(self, errors, labels = None, alpha = 0.05, verbose = False):
        """Estimate decision thresholds on the anomaly scores.

        The quantile threshold is the smallest score whose empirical CDF
        (computed on normal beats only when ``labels`` is given, on all
        scores otherwise) reaches ``1 - alpha``, i.e. it fixes the false
        positive rate to about ``alpha``. When ``labels`` is given, a second
        threshold maximising the F2 score over a grid of candidates is also
        returned.

        Args:
            errors: 1-D array of anomaly scores.
            labels: optional 1-D array, truthy for abnormal beats.
            alpha: target false positive rate of the quantile threshold.
            verbose: print progress and plot the score distributions.

        Returns:
            ``(quantile_thr, best_thr)``; ``best_thr`` is ``None`` when
            ``labels`` is ``None``.
        """
        if verbose:
            print("estimate threeshold with only normal data")
        if labels is None:
            ecdf = ECDF(errors)
        else:
            ecdf = ECDF(errors[np.where(labels == 0)])
        thr = ecdf.x[np.where(ecdf.y >= (1-alpha))[0][0]]
        if labels is None:
            # No supervised threshold can be computed without labels.
            return thr, None
        quantile_thr = thr
        # Estimate threeshold with both
        thrs = np.linspace(np.min(errors), np.max(errors), num=errors.shape[0]*2)
        if verbose:
            print("estimate threeshold with all data")
        # fbeta_score expects (y_true, y_pred): ground truth first.
        e = Parallel(n_jobs=num_cores)(delayed(fbeta_score)(labels, errors >= t, beta = 2)
                               for t in thrs)
        ind = np.argmax(np.array(e))
        best_thr = thrs[ind]
        if verbose:
           self.plot_error_mixing(errors, labels, best_thr, est_thr = thr)

        return quantile_thr, best_thr


    def get_anomaly_score(self, X, label, L=1):
        """Mean squared reconstruction error per beat, averaged over ``L`` stochastic passes.

        Args:
            X: tensor of beats, shape (batch, leads, samples).
            label: conditioning vectors, shape (batch, nc).
            L: number of encode/decode passes to average (the encoder samples).

        Returns:
            numpy array of shape (batch,).
        """
        b_s = X.size(0) # Batch size
        AS = 0
        X_det = X.view(b_s,-1).detach().cpu().numpy()
        # Scores are never back-propagated: skip building the autograd graph.
        with torch.no_grad():
            for _ in range(L):
                X_rec = self.netD(self.netE(X), label).view(b_s,-1).detach().cpu().numpy()
                AS += np.mean(np.square(X_rec - X_det), axis = 1)
        return AS/L

    def get_rec_errors(self, X, label):
        """Point-wise squared reconstruction error, numpy array of shape (batch, leads*samples)."""
        b_s = X.size(0) # Batch size
        with torch.no_grad():
            X_rec = self.netD(self.netE(X), label).view(b_s,-1).detach().cpu().numpy()
        X = X.view(b_s,-1).detach().cpu().numpy() # (b_s, NxL)
        rec_errors = np.square(X_rec-X)
        return rec_errors


    def compute_gradient_penalty(self, real_samples, fake_samples, labels = None):
        """Calculates the gradient penalty loss for WGAN GP on the latent critic.

        Args:
            real_samples: samples from the prior, shape (batch, nz).
            fake_samples: encoder outputs, shape (batch, nz).
            labels: must be ``None``; this class only has a latent-space critic.

        Returns:
            Scalar tensor ``mean((||grad||_2 - 1)^2)``.

        Raises:
            NotImplementedError: if ``labels`` is given (previously this path
                failed with an unbound local variable).
        """
        if labels is not None:
            raise NotImplementedError(
                "gradient penalty with labels needs a conditional critic, "
                "which CBeatAAE does not define")
        # Random weight term for interpolation between real and fake samples
        alpha = torch.rand((real_samples.size(0), 1), device = self.device)
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        d_interpolates = self.netDis_prior(interpolates)
        grad_outputs = torch.ones((real_samples.size(0), 1), device = self.device)
        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty


    def train(self, train_dl, valid_dl, num_epochs, lr = 0.0001, beta1=0.5, beta2 = 0.9, lambda_rec = 1,
              lambda_adv = 1, lambda_gp = 10, lambda_tv = 0.0001,
               n_critic = 5, model_folder = 'models/AAE/'):
        """Train the model and keep the checkpoint with the best validation PR-AUC.

        Each iteration runs ``n_critic`` critic updates on the latent space,
        then one joint encoder/decoder update minimising
        ``lambda_rec * MSE + lambda_tv * TV - lambda_adv * critic(E(x))``.
        After every epoch the threshold is re-estimated on ``valid_dl``, the
        average precision on ``valid_dl`` is computed, and the model is saved
        to ``model_folder`` when that metric does not decrease. The best
        checkpoint is reloaded at the end. Ctrl-C stops training cleanly.

        Args:
            train_dl: loader yielding ``(beats, codes, info)`` batches.
            valid_dl: loader used for threshold estimation and model selection.
            num_epochs: maximum number of epochs.
            lr, beta1, beta2: Adam hyper-parameters shared by the optimizers.
            lambda_rec, lambda_adv, lambda_gp, lambda_tv: loss weights.
            n_critic: critic updates per encoder/decoder update.
            model_folder: directory for checkpoints and per-epoch sample plots.
        """
        beat_list = []  # (epoch, generated beats) snapshots, one every 10 epochs
        E_losses = []
        MSE_losses = []
        ADV_losses = []
        Disc_prior_losses = []
        Disc_visible_losses = []
        TV_losses = []
        best_metric = 0
        if not os.path.exists(model_folder):
            os.makedirs(model_folder)

        # Float dtype so the labels can be concatenated with the float noise
        # and fed to the decoder regardless of the torch version.
        fixed_labels = torch.tensor([[1,0,0],
                      [1,0,0],
                      [1,0,0],
                      [0,1,0],
                      [0,1,0],
                      [0,1,0]], dtype = torch.float).to(self.device)
        fixed_noise = self.sample_noise(len(fixed_labels))

        ## Initialize Loss functions
        MSE = nn.MSELoss().to(self.device)

        # Setup Adam optimizers for both Enc, Dec and Disc
        optimizerE = optim.Adam(self.netE.parameters(), lr=lr, betas = (beta1, beta2))
        optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas = (beta1, beta2))
        optimizerDis_prior = optim.Adam(self.netDis_prior.parameters(), lr=lr, betas = (beta1, beta2))
        schedulers = [ReduceLROnPlateau(optimizerE, patience = 10),
                     ReduceLROnPlateau(optimizerD, patience = 10),
                     ReduceLROnPlateau(optimizerDis_prior, patience = 10)]


        print("Starting Training Loop...")
        # For each epoch
        for epoch in range(num_epochs):
            try:
                # For each batch in the dataloader
                for i, data in enumerate(train_dl, 0):
                    # Format batch
                    X = data[0].to(self.device)
                    label = data[1].to(self.device)
                    b_size = X.size(0)
                    ############################
                    # (2) Update Discriminator network on latent: minimize wasserstein distance
                    ###########################
                    for _ in range(n_critic):
                        self.netDis_prior.zero_grad()
                        data_n = next(iter(train_dl))
                        X_n = data_n[0].to(self.device)
                        # Size the prior batch on the batch actually encoded,
                        # so the gradient penalty can interpolate even when
                        # the outer batch is a smaller, final one.
                        Z_prior = self.sample_noise(X_n.size(0))
                        Z_latent = self.netE(X_n).detach()
                        Disc_z_prior = self.netDis_prior(Z_prior).view(-1)
                        Disc_z_real = torch.mean(Disc_z_prior)
                        Disc_z = self.netDis_prior(Z_latent).view(-1)
                        Disc_z_fake = torch.mean(Disc_z)
                        gp = self.compute_gradient_penalty(Z_prior, Z_latent)
                        W_distance = Disc_z_real - Disc_z_fake
                        loss = -1*W_distance + lambda_gp * gp
                        # Calculate gradients for Disc in backward pass
                        loss.backward()
                        optimizerDis_prior.step()

                    ##############################
                    # (1) Update D and E networks:
                    ##############################
                    self.netE.zero_grad()
                    self.netD.zero_grad()

                    Z = self.netE(X)
                    X_rec = self.netD(Z, label)
                    # Second decode with a detached code: the total-variation
                    # term below only sends gradients to the decoder.
                    X_rec_detached = self.netD(Z.detach(), label)

                    # total variation reg
                    TV_loss = torch.div(torch.abs(X_rec_detached[:,:,1:] - X_rec_detached[:,:,:-1]).sum(), b_size)

                    # rec losses
                    ED_mse =  MSE(X, X_rec)

                    # adv latent loss
                    E_advz =  torch.mean(self.netDis_prior(Z).view(-1))

                    total_loss = lambda_rec * ED_mse + lambda_tv * TV_loss -1* lambda_adv *E_advz
                    total_loss.backward()
                    optimizerE.step()
                    optimizerD.step()

                    ## Output training stats
                    if i % 50 == 0:
                        print('[%d/%d][%d/%d]\n\t mse Loss: %.4f \n\t  Adv loss Encoder: %.4f \n\t Loss Disc prior: %.4f  fake: %.4f real: %.4f \n\t  TV loss %.4f'
                              % (epoch, num_epochs, i, len(train_dl),
                                 ED_mse.item(), E_advz.item(), W_distance.item(),Disc_z_fake.item(), Disc_z_real.item(), TV_loss.item()))

                    # Save Losses for plotting later
                    E_losses.append(E_advz.item())
                    MSE_losses.append(ED_mse.item())
                    Disc_prior_losses.append(W_distance.item())
                    TV_losses.append(TV_loss.item())
            except KeyboardInterrupt:
                print("\n INTERRUPT DETECTED: training stopped")
                break
            # evaluate and save best model
            print("\n Evaluate")
            self.compute_threeshold(valid_dl, optimal=True)
            labels, pred = self.evaluate(valid_dl)
            metric = average_precision_score(labels, pred)
            # Schedulers minimise, so feed them the negated metric.
            for s in schedulers:
                s.step(-1*metric)
            if metric >= best_metric:
                print("\n %.4f -----> %.4f "%(best_metric, metric))
                best_metric = metric
                self.save(model_folder)
            else:
                print("\n %.4f less than best %.4f"%(metric, best_metric))
            # Check how the generator is doing by saving G's output on fixed_noise
            with torch.no_grad():
                fake = self.netD(fixed_noise, fixed_labels).detach().cpu()
                plot_some_beat(fake, savefig=True,
                               file = os.path.join(model_folder, "epoch_"+str(epoch)+".png"))
                if epoch % 10 == 0:
                    beat_list.append((epoch, fake))


        self.load_model(model_folder)

        f, ax = plt.subplots(3, 1, sharex = True)
        ax[0].set_title("Adversarial Losses ")
        ax[0].plot(MSE_losses, label='mse error')
        ax[0].plot(ADV_losses, label = "adv loss")
        ax[0].plot(E_losses, label = 'Adv aggregate error')
        ax[0].legend()
        ax[1].set_title("Discriminator losses")
        ax[1].plot(Disc_prior_losses, label = 'latent')
        ax[1].plot(Disc_visible_losses, label = 'visible')
        ax[1].legend()
        ax[1].set_ylabel("Loss")
        ax[2].plot(TV_losses, label= "total variation")
        ax[2].legend()
        plt.xlabel("iterations")
        plt.show()
        # Title every snapshot with the epoch it was taken at.
        for snap_epoch, snap_beats in beat_list:
            plot_some_beat(snap_beats, suptitle = 'epoch: '+str(snap_epoch))
        self.trained = True

    def save(self, folder = '/'):
        """Write the three state dicts and the current threshold into ``folder``."""
        torch.save(self.netDis_prior.state_dict(), os.path.join(folder, "discriminator_prior.mod"))
        torch.save(self.netE.state_dict(), os.path.join(folder, "encoder.mod"))
        torch.save(self.netD.state_dict(), os.path.join(folder, "decoder.mod"))
        np.save(os.path.join(folder, "thr.npy"), np.array([self.threeshold]))


    def _collect_scores(self, dl, verbose = False):
        """Run ``get_anomaly_score`` over ``dl``; return ``(labels, scores)`` arrays.

        Batches are ``(beats, codes, info)`` where ``info['label']`` holds one
        string per beat; a beat is abnormal when that string is 'abnormal'.
        Returns ``(None, None)`` for an empty loader.
        """
        scores = []
        truths = []
        pbar = tqdm(total = len(dl)) if verbose else None
        for data in dl:
            beats = data[0].to(self.device)
            codes = data[1].to(self.device)
            truths.append(np.array(data[2]['label']) == 'abnormal') # 0 normal, 1 abnormal
            scores.append(self.get_anomaly_score(beats, codes))
            if pbar is not None:
                pbar.update(n = 1)
        if pbar is not None:
            pbar.close()
        if not scores:
            return None, None
        return np.concatenate(truths, axis = 0), np.concatenate(scores, axis = 0)

    def compute_threeshold(self, dl, optimal = False, verbose = False):
        """Estimate ``self.threeshold`` from the beats in ``dl``.

        Args:
            dl: loader yielding ``(beats, codes, info)`` batches.
            optimal: when true use the labels and keep the F2-optimal
                threshold; otherwise keep the label-free quantile threshold.
            verbose: show a progress bar and diagnostic output.

        Raises:
            ValueError: if ``dl`` yields no batch.
        """
        if optimal:
            print("\n\t Computing the optimal threeshold")
        else:
            print("\n\t Computing thr")
        # Iterate the loader that was passed in, not a notebook global.
        labels, pred = self._collect_scores(dl, verbose = verbose)
        if pred is None:
            raise ValueError("cannot compute a threshold from an empty data loader")
        # Compute threeshold using both errors
        if optimal:
            quantile_thr, best_thr = self.get_threeshold(pred, labels = labels, verbose = verbose)
            self.threeshold = best_thr
        else:
            quantile_thr, _ = self.get_threeshold(pred, labels = None, verbose = verbose)
            self.threeshold = quantile_thr


    def evaluate(self, dl, verbose = False):
        """Score every beat in ``dl``.

        Returns:
            ``(labels, pred)``: boolean array (True = abnormal) and the
            matching anomaly scores; ``(None, None)`` for an empty loader.
        """
        # Iterate the loader that was passed in, not a notebook global.
        return self._collect_scores(dl, verbose = verbose)

    def test(self, test_dl, result_folder = 'models/AAE/'):
        """Evaluate on ``test_dl``, print a report and pickle the metrics.

        Writes ``result.pickle`` into ``result_folder`` with the abnormal
        class report, ROC-AUC, PR-AUC, the PR curve and the F2 score.

        Returns:
            The classification report as a dict, or ``None`` when the model
            has neither been trained nor loaded.
        """
        if not self.trained:
            print("train or load a model")
            return None
        print("\n Starting Testing Loop... \n")
        labels,pred = self.evaluate(test_dl)
        flagged = pred >= self.threeshold
        class_names = ['normal', 'abnormal']

        print("\n")
        print(classification_report(labels, flagged, target_names=class_names))
        self.plot_error_mixing(pred, labels, self.threeshold)
        clrp = classification_report(labels, flagged, target_names=class_names,
                                     output_dict = True)
        auc = roc_auc_score(labels, pred)
        pr_auc = average_precision_score(labels, pred)
        result = {'rep':clrp['abnormal'], 'roc_auc':auc, 'pr_auc':pr_auc,
                  'pr_curve':precision_recall_curve( labels, pred),
                  'f2score': fbeta_score(labels, flagged, beta = 2)}

        with open(os.path.join(result_folder, 'result.pickle'), 'wb') as handle:
            pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

        return clrp

    def load_model(self, model_folder='/'):
        """Load the three state dicts and the threshold saved by ``save``."""
        self.netE.load_state_dict(torch.load(os.path.join(model_folder, "encoder.mod"), map_location=self.device))
        self.netD.load_state_dict(torch.load(os.path.join(model_folder, "decoder.mod"), map_location=self.device))
        self.netDis_prior.load_state_dict(torch.load(os.path.join(model_folder, "discriminator_prior.mod"), map_location = self.device))
        # Load the threeshold. It is stored wrapped in an array; unwrap it so
        # the attribute is a scalar again, as after compute_threeshold.
        self.threeshold = float(np.load(os.path.join(model_folder, "thr.npy")).ravel()[0])
        self.trained = True

## Initialize the training and testing step

In [None]:
import os
from tqdm import tqdm
import random
import multiprocessing
from joblib import Parallel, delayed
import pickle
from shutil import copyfile,rmtree
import torch
import torch.nn.parallel
import torch.utils.data
import numpy as np

num_cores = multiprocessing.cpu_count()

# ---- Run configuration ----------------------------------------------------
# Root directory of the beat files, one sub-folder per class
root = "/storage/intra/"
path_normal = root+'normal/'
path_abnormal = root+'abnormal/'

# Batch size during training
batch_size = 256

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Number of training epochs
num_epochs = 100

# ---- Reproducibility ------------------------------------------------------
# Fixed seed; draw one with random.randint(1, 10000) to get new results.
manualSeed = 999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# ---- Data -----------------------------------------------------------------
# Both classes are loaded fully in memory and kept as plain sample lists.
normal_ecg = Ecgdataset(path_normal, upload_on_ram = True).get_db()
abnormal_ecg = Ecgdataset(path_abnormal, upload_on_ram = True).get_db()

# ---- Device ---------------------------------------------------------------
# Use the first GPU only when one was requested and CUDA is available.
use_cuda = ngpu > 0 and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)

## 5-fold cross validation. The results for each fold are automatically saved in a specified folder

In [None]:
# A single model object is reused; its weights are re-initialised per fold.
model = CBeatAAE(device= device, nz = 5)
model_name = 'AAE'

# Raise `start` to skip folds that have already been completed
start = 0
n_fold = 5
normal_valid_perc = 0.05
abnormal_valid_perc = 0.1

for i in range(start, n_fold):
    train_ecg = normal_ecg.copy()
    test_ecg = abnormal_ecg.copy()

    # Fold i of the normal beats is moved from the training pool to the test set
    fold_lo = int(i/n_fold*len(normal_ecg))
    fold_hi = int((i+1)/n_fold*len(normal_ecg))
    test_ecg.extend(train_ecg[fold_lo:fold_hi])
    del train_ecg[fold_lo:fold_hi]

    # Validation set: the head of the remaining normal beats ...
    n_valid_normal = int(normal_valid_perc*len(train_ecg))
    valid_ecg = train_ecg[:n_valid_normal]
    del train_ecg[:n_valid_normal]

    # ... plus the head of the abnormal beats (which open the test list)
    n_valid_abnormal = int(abnormal_valid_perc*len(abnormal_ecg))
    valid_ecg.extend(test_ecg[:n_valid_abnormal])
    del test_ecg[:n_valid_abnormal]

    # Wrap the three splits as datasets, then as loaders
    train_set = Ecgdataset(None, db = train_ecg)
    validation_set = Ecgdataset(None, db = valid_ecg)
    test_set = Ecgdataset(None, db = test_ecg)

    train_dl = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                         shuffle=True, drop_last=True)
    valid_dl = torch.utils.data.DataLoader(validation_set, batch_size=batch_size*10)
    test_dl = torch.utils.data.DataLoader(test_set, batch_size = batch_size*10)

    # Fresh weights, then train; checkpoints of every fold go to the same folder
    model.initialize()
    model.train(train_dl, valid_dl, num_epochs, lambda_rec = 1, lambda_adv = 1,lambda_tv = 0.001,beta1 = 0, model_folder = '/storage/models/')

    # Per-fold folder receiving the test results
    fold_dir = "/storage/models/"+model_name+"/fold_"+str(i)+"/"
    if not os.path.exists(fold_dir):
        os.makedirs(fold_dir)

    model.test(test_dl, result_folder=fold_dir)

    # Release the fold's sample lists before building the next ones
    del train_ecg
    del valid_ecg
    del test_ecg