<a href="https://colab.research.google.com/github/utsb-fmm/FHRMA/blob/master/FS%20training%20python%20sources/2_Training_FSMHR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Train FSMHR model

Associated paper: https://www.preprints.org/manuscript/202207.0131/v1

FHR Morphological Analysis Toolbox Copyright (C) 2022 Samuel Boudet, Faculté de Médecine et Maïeutique, samuel.boudet@gmail.com

This file is part of FHR Morphological Analysis Toolbox

FHR Morphological Analysis Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

FHR Morphological Analysis Toolbox is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

In [1]:
import tensorflow as tf
import array
import numpy as np
import os
import time
import datetime
import pickle
from IPython.core.debugger import set_trace
from tensorflow.keras import layers as lay
from tensorflow.keras import regularizers


In [3]:
# Run name: used for the TensorBoard log subfolder and for the checkpoint folder.
Mname = 'FSMHR00'

# Recurrent layer class used by create_model.
grufn = lay.GRU

# Training data and checkpoints live on Google Drive (Colab only).
from google.colab import drive
drive.mount('drive')
basefolder = "drive/My Drive/FHRMA-Training-FS/"
# TensorBoard logs are written locally only; point this at basefolder to keep them on Drive.
logfolder = "logs/"

Mounted at drive


In [4]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all at start-up.
# If a GPU has already been initialised TensorFlow raises RuntimeError; it is printed
# and otherwise ignored. With no GPU the loop simply does nothing.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)

# Default (no-op) distribution strategy: works on CPU and on a single GPU.
strategy = tf.distribute.get_strategy()

In [5]:
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only load a
# dataset pickle that comes from a trusted source.
with open(basefolder+'dataV8MHR13-13.pkl','rb') as f:
    X_train, Y_train, X_val, Y_val, ListRec,LengthRec = pickle.load(f)

# Number of time samples per recording; the model input and the augmentation are
# built for exactly this length.
samp=X_train.shape[1]
# The whole validation set is consumed as a single batch, and training batches use
# the same size.
batch_size=X_val.shape[0]

# Fail fast on layout mismatches instead of deep inside tf.data / Keras:
# inputs must be (recordings, samp, 6) and targets (recordings, samp, >=4).
assert X_train.shape[1:] == (samp, 6), f"X_train has shape {X_train.shape}, expected (n, {samp}, 6)"
assert X_val.shape[1:] == (samp, 6), f"X_val has shape {X_val.shape}, expected (n, {samp}, 6)"
assert Y_train.shape[1] == samp and Y_train.shape[2] >= 4, f"Y_train has shape {Y_train.shape}, expected (n, {samp}, >=4)"
assert Y_val.shape[1] == samp and Y_val.shape[2] >= 4, f"Y_val has shape {Y_val.shape}, expected (n, {samp}, >=4)"
# steps_per_epoch is X_train.shape[0] // batch_size; it must not be 0.
assert X_train.shape[0] >= batch_size, (
    f"training set ({X_train.shape[0]}) is smaller than one batch ({batch_size})")


In [6]:
@tf.function
def weighted_binary_crossentropy(yTrue, yPred):
    """Summed, weighted binary cross-entropy on the first model output channel.

    Args:
        yTrue: target tensor indexed [batch, time, channel]; channel 3 is the binary
            label (called isFalse in this notebook) and channel 2 the per-sample
            weight (docare).
        yPred: model output indexed [batch, time, channel]; channel 0 is the
            predicted probability.

    Returns:
        Scalar: minus the weighted log-likelihood, summed over batch and time. It is
        deliberately not divided by the total weight.
    """
    # Keep log() finite. The bounds are asymmetric: 1e-5 below, 1 - 1e-6 above.
    prob = tf.clip_by_value(yPred[:, :, 0], 0.00001, 0.999999)
    label = yTrue[:, :, 3]
    weight = yTrue[:, :, 2]
    log_lik = tf.math.log(prob) * label + tf.math.log(1. - prob) * (1. - label)
    return -tf.reduce_sum(log_lik * weight)

@tf.function
def weighted_accuracy(yTrue, yPred):
    """Weighted count of correct thresholded predictions (a sum, not a ratio).

    Args:
        yTrue: target tensor indexed [batch, time, channel]; channel 3 is the binary
            label and channel 2 the per-sample weight.
        yPred: model output indexed [batch, time, channel]; channel 0 is the
            predicted probability, thresholded at 0.5.

    Returns:
        Scalar: sum over batch and time of weight * [prediction == label].
    """
    predicted = tf.cast(yPred[:, :, 0] > 0.5, tf.float32)
    correct = tf.cast(tf.equal(predicted, yTrue[:, :, 3]), tf.float32)
    return tf.reduce_sum(correct * yTrue[:, :, 2])



In [7]:
class Split(tf.keras.layers.Layer):
    """Split a tensor into two equal halves along the batch axis (axis 0).

    Used to separate the forward and time-reversed halves that were stacked on the
    batch axis before a shared recurrent layer. Returns a list of two tensors.
    """
    def __init__(self, **kwargs):
        # Forward the standard Layer kwargs (name, dtype, trainable, ...). Keras
        # re-creates layers as cls(**config) when a saved model is loaded; the
        # previous no-argument __init__ made that fail.
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.split(inputs, num_or_size_splits=2, axis=0)

class RevT(tf.keras.layers.Layer):
    """Reverse a tensor along axis 1 (the time axis for [batch, time, features] inputs)."""
    def __init__(self, **kwargs):
        # Forward the standard Layer kwargs (name, dtype, trainable, ...) so the layer
        # can be re-created from its config when a saved model is loaded.
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.reverse(inputs, axis=[1])



In [8]:
# DiffMF.dat: raw int16 samples (2 bytes each), read in native byte order.
diff_path = basefolder + "DiffMF.dat"
L = os.path.getsize(diff_path) // 2
i16 = np.fromfile(diff_path, dtype=np.int16, count=L)
# Scale by 1/(4*60), presumably into the normalized heart-rate units of the model
# inputs (TODO confirm). Then append the time-reversed copy to double the length.
DiffFM = i16.reshape(-1, 1) / 4 / 60
DiffFM = np.concatenate((DiffFM, np.flip(DiffFM, 0)), 0)
if DiffFM.shape[0] < samp:
    # Without this check the shortfall only shows up later as an obscure broadcasting
    # error inside siglossgenerator while the training dataset is iterated.
    raise ValueError(
        f"DiffMF.dat provides {DiffFM.shape[0]} samples after mirroring; "
        f"at least samp={samp} are required")
# (samp, 1) float32 constant used by siglossgenerator.
DiffFM = tf.constant(DiffFM[0:samp, :], dtype=tf.float32)
def siglossgenerator(X,Y):
    """Randomly augment one training recording.

    Used as the `map` function of the training tf.data pipeline, so it is called once
    per element before batching and is traced to a graph (tensor-valued `if`/`for`
    are converted by autograph).

    Args:
        X: input tensor for one recording, indexed X[time, channel] with 6 channels:
            0 FHR, 1 FHR mask, 2 MHR, 3 MHR mask, 4 isStage2 flag, 5 a flag this
            function sets to 1 at a random "cut" position (presumably a recurrent-state
            reset marker; TODO confirm).
        Y: target tensor for one recording, indexed Y[time, channel]: 0 FHR care
            weight, 1 FHR false-signal label, 2 MHR care weight, 3 MHR false-signal
            label.

    Returns:
        (newX, newY) with the same channel layout. Masked samples of the FHR/MHR
        channels are set to 0, and newX is forced to X's static shape.

    Relies on the module-level globals `samp` (recording length in samples) and
    `DiffFM` (a (samp, 1) float32 constant).
    """
    # Per-scale signal-loss parameters (index k = one scale). For each k: the number of
    # candidate gaps is floor(|N(0, p*n[k])|), each candidate is applied with
    # probability p*p[k], and its half-width in samples is |N(p*d[k], p*d[k]/2)|.
    pMHRn=np.array([4,12,24,36,64,120])*2
    pMHRp=np.array([.35,.35,.35,.35,.35,.35])
    pMHRd=np.array([3000,1000,300,60,20,6])
    pFHRn=np.array([4,12,24,36,64,120])*2
    pFHRp=np.array([.25,.25,.25,.25,.25,.25])
    pFHRd=np.array([2000,500,150,60,20,6])
    
    

    pKeepMHR=.15   # probability of leaving MHR un-augmented
    pKeepFHR=.20   # probability of leaving FHR un-augmented
    pCut=.4        # probability of inserting a cut at one random position
    pNoStage2=.1   # probability of zeroing the isStage2 channel

    pMHR0=tf.constant(.1,dtype=tf.float32)    # probability of dropping MHR entirely
    stdMult=tf.constant(.08,dtype=tf.float32) # std of the global amplitude factor S

   
    # Column of time indices 0..samp-1, compared against interval bounds below.
    rg=tf.reshape(tf.range(0,samp,dtype=tf.int32),[samp,1])
    if tf.random.uniform([])<pCut:
        # Cut: at one position, zero input channels 0-4 and all 4 targets, set channel 5 to 1.
        pos=tf.random.uniform([],maxval=samp-1,dtype=tf.int32)
        newX=tf.tensor_scatter_nd_update(X, [[pos,0],[pos,1],[pos,2],[pos,3],[pos,4],[pos,5]], [0,0,0,0,0,1])
        newY=tf.tensor_scatter_nd_update(Y, [[pos,0],[pos,1],[pos,2],[pos,3]], [0,0,0,0])
    else:
        newX=X
        newY=Y  

    # Unpack channels as (samp, 1) columns.
    FHR=newX[:,0:1]
    maskFHR=newX[:,1:2]
    MHR=newX[:,2:3]
    maskMHR=newX[:,3:4]
    isStage2=newX[:,4:5]
    care=newY[:,0:1]
    falsesig=newY[:,1:2]
    careM=newY[:,2:3]
    falsesigM=newY[:,3:4]
    
    if tf.random.uniform([])<pNoStage2:
        isStage2=tf.zeros(tf.shape(isStage2),dtype=tf.dtypes.float32)
            
  
    # ---- MHR augmentation (applied with probability 1 - pKeepMHR) ----
    if tf.random.uniform([])>pKeepMHR:
        # Signal-loss gaps: mask MHR and zero its care weight on [mid-dur, mid+dur).
        for k in range(len(pMHRn)):
            N=tf.cast(tf.floor(tf.math.abs(tf.random.normal([],0,pMHRn[k],dtype=tf.float32))),tf.int32)
            for j in range(N):
                if tf.random.uniform([])<pMHRp[k]: # removing maternal heart rate
                    mid=tf.random.uniform([],maxval=samp,dtype=tf.int32)
                    dur=tf.cast(abs(tf.random.normal([],mean=pMHRd[k],stddev=pMHRd[k]/2)),tf.int32)
                    maskMHR=tf.where( (rg<mid-dur)|(rg>=mid+dur), maskMHR, [0])  
                    careM=tf.where( (rg<mid-dur)|(rg>=mid+dur), careM, [0])    
        
        # Spurious MHR segments. Inside a window of half-width durT around mid, the
        # sub-interval [mid+startFalse, mid+endFalse) (clipped to the window) is altered
        # where MHR is valid. The rest of the window is masked out.
        N=tf.cast(tf.floor(tf.math.abs(tf.random.normal([],0,50,dtype=tf.float32))),tf.int32)
        for j in range(N):
            mid=tf.random.uniform([],maxval=samp,dtype=tf.int32)
            durT=tf.cast(abs(tf.random.normal([],mean=0,stddev=160/2)),tf.int32)
            durPart=tf.cast(abs(tf.random.normal([],mean=0,stddev=80/2)),tf.int32)
            OffsetPart=tf.cast(tf.random.normal([],mean=0,stddev=20),tf.int32)
            endFalse=tf.clip_by_value(durPart+OffsetPart,-durT,durT);
            startFalse=tf.clip_by_value(-durPart+OffsetPart,-durT,durT);

            
            r=tf.random.uniform([])
            # (x+2)*k-2 scales x about an offset of -2 (presumably the zero level of
            # the normalized units; TODO confirm). 80%: double where MHR < 0.1;
            # otherwise halve where MHR > -1.3.
            if r<.8:
                change=tf.where( (rg>=mid+startFalse) & (rg<mid+endFalse) & (maskMHR==1.) & (MHR<0.1), [1.], [0.] )
                MHR=(MHR+2)*(1+change)-2
            else:
                change=tf.where( (rg>=mid+startFalse) & (rg<mid+endFalse) & (maskMHR==1.) & (MHR>-1.3), [1.], [0.] )
                MHR=(MHR+2)*(1-.5*change)-2

            # Label altered samples as false (where they carry weight), keep them unmasked
            # with doubled care weight, and mask/zero-weight the rest of the window.
            falsesigM=tf.where( (change==1) & (falsesigM==0) & (careM>0) ,[1.],falsesigM )
            maskMHR=tf.where( (rg<mid-durT)|(rg>=mid+durT)|(change==1.), maskMHR, [0])
            careM=tf.where( (rg<mid-durT)|(rg>=mid+durT)|(change==1.), careM*(1+change), [0]) 

    # ---- FHR augmentation (applied with probability 1 - pKeepFHR) ----
    if tf.random.uniform([])>pKeepFHR:
        # Signal-loss gaps: mask FHR and zero its care weight on [mid-dur, mid+dur).
        for k in range(len(pFHRn)):
            N=tf.cast(tf.floor(tf.math.abs(tf.random.normal([],0,pFHRn[k],dtype=tf.float32))),tf.int32)
            for j in range(N):
                if tf.random.uniform([])<pFHRp[k]: # removing fetal heart rate
                    mid=tf.random.uniform([],maxval=samp,dtype=tf.int32)
                    dur=tf.cast(abs(tf.random.normal([],mean=pFHRd[k],stddev=pFHRd[k]/2)),tf.int32)
                    maskFHR=tf.where( (rg<mid-dur)|(rg>=mid+dur), maskFHR, [0]) 
                    care=tf.where( (rg<mid-dur)|(rg>=mid+dur), care, [0]) 

        # Spurious FHR segments. Same window / sub-interval scheme as for MHR. The
        # altered region is restricted to samples not yet labeled false with care > 0;
        # for r >= .5 it also requires a valid MHR sample.
        N=tf.cast(tf.floor(tf.math.abs(tf.random.normal([],0,100,dtype=tf.float32))),tf.int32)
        for j in range(N):
            mid=tf.random.uniform([],maxval=samp,dtype=tf.int32)
            durT=tf.cast(abs(tf.random.normal([],mean=120,stddev=120/2)),tf.int32)
            durPart=tf.cast(abs(tf.random.normal([],mean=60,stddev=60/2)),tf.int32)
            OffsetPart=tf.cast(tf.random.normal([],mean=0,stddev=20),tf.int32)
            endFalse=tf.clip_by_value(durPart+OffsetPart,-durT,durT);
            startFalse=tf.clip_by_value(-durPart+OffsetPart,-durT,durT);
            r=tf.random.uniform([])
            if r<.5:
                change=tf.where( (rg>=mid+startFalse) & (rg<mid+endFalse) & (falsesig==0.)& (care>0.), [1.], [0.] )
            else:
                change=tf.where( (rg>=mid+startFalse) & (rg<mid+endFalse) & (falsesig==0.)& (care>0.) & (maskMHR==1.), [1.], [0.] )

            # Label the altered samples false (falsesig was 0 there), double their care
            # weight, and mask/zero-weight the rest of the window.
            falsesig=falsesig+change
            maskFHR=tf.where( (rg<mid-durT)|(rg>=mid+durT)|(change==1.), maskFHR, [0])
            care=tf.where( (rg<mid-durT)|(rg>=mid+durT)|(change==1.), care*(1+change), [0]) 
            # Alteration of the changed samples, chosen by r:
            #   [0, .3) double FHR; [.3, .5) halve FHR;
            #   [.5, .9) replace FHR by MHR + DiffFM; [.9, .95) by its doubled version;
            #   [.95, 1) by its halved version (all scaled about the -2 offset).
            if r<.3:
                FHR=(FHR+2)*(1+change)-2
            elif r<.5:
                FHR=(FHR+2)*(1-.5*change)-2
            elif r<.9:
                FHR=(FHR+2)*(1-change)+(MHR+2+DiffFM)*change-2
            elif r<.95:
                FHR=(FHR+2)*(1-change)+(MHR+2+DiffFM)*2*change-2
            else:
                FHR=(FHR+2)*(1-change)+(MHR+2+DiffFM)*.5*change-2

            # For MHR-derived replacements, with probability .5 also drop an MHR segment
            # near mid (offset ~N(0,120), half-width ~|N(0,240)| samples).
            if r>.5 and tf.random.uniform([])<.5:
                durMHRLost=tf.cast(abs(tf.random.normal([],mean=0,stddev=240)),tf.int32)
                OffsetMHRLost=tf.cast(tf.random.normal([],mean=0,stddev=120),tf.int32)
                maskMHR=tf.where( (rg<mid+OffsetMHRLost-durMHRLost)|(rg>=mid+OffsetMHRLost+durMHRLost), maskMHR, [0])   
                careM=tf.where( (rg<mid+OffsetMHRLost-durMHRLost)|(rg>=mid+OffsetMHRLost+durMHRLost), careM, [0])   

    # Mask FHR values above 2.25 (e.g. produced by doubling above).
    maskFHR=tf.where( FHR<=2.25, maskFHR, [0])   
    # Occasionally drop the maternal channel completely (mask and care weight).
    if tf.random.uniform([])<pMHR0: 
        maskMHR=tf.zeros([samp,1],dtype=tf.float32)
        careM=tf.zeros([samp,1],dtype=tf.float32)    
    
    # Global amplitude factor applied to both FHR and MHR (about the -2 offset).
    S=1.+tf.random.normal([],stddev=stdMult)

    # Reassemble the 6 input channels; masked samples are set to 0 and channel 5 passes through.
    newX=tf.concat([ ((FHR+2)*S-2)*maskFHR , maskFHR, ((MHR+2)*S-2)*maskMHR,maskMHR,isStage2,newX[:,5:6] ],axis=1)
    
    newX=tf.ensure_shape(newX,X.shape)
    newY=tf.concat([ care,falsesig,careM,falsesigM ],axis=1)
    return newX,newY


# Training pipeline: shuffle the full training set on every pass, repeat forever,
# apply the random augmentation per recording, then batch and prefetch.
traindataset=tf.data.Dataset.from_tensor_slices(    (  tf.constant(X_train,dtype=tf.float32)  ,  tf.constant(Y_train,dtype=tf.float32)  )    )
traindataset=traindataset.shuffle(X_train.shape[0],reshuffle_each_iteration=True).repeat()
# siglossgenerator is expensive per element (many masked updates over the whole
# recording), so let tf.data run it on several elements in parallel. The output
# order is still deterministic by default.
traindataset=traindataset.map(siglossgenerator,num_parallel_calls=tf.data.experimental.AUTOTUNE)
traindataset=traindataset.batch(batch_size,drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

# Validation pipeline: no augmentation. Cast to float32 like the training data;
# batch_size equals the validation-set size, so each batch is the whole set.
valdataset=tf.data.Dataset.from_tensor_slices(
    (tf.constant(X_val,dtype=tf.float32), tf.constant(Y_val,dtype=tf.float32))
).cache().repeat().batch(batch_size,drop_remainder=True)


In [9]:
class Sparsity(tf.keras.constraints.Constraint):
    """Block-diagonal constraint for a GRU recurrent kernel of shape (h, 3*h).

    The h units are split into consecutive groups of n. In each of the three gate
    blocks, a unit only receives recurrent input from units of its own group. If h is
    not a multiple of n, the trailing h % n units get no recurrent input at all.

    Args:
        n: group size.
        h: number of units.
    """
    def __init__(self, n,h):
        self.n = n
        self.h = h
        m = np.zeros([h, h])
        for i in range(int(h / n)):
            m[i*n:i*n+n, i*n:i*n+n] = 1.
        self.mask = tf.constant(np.concatenate([m, m, m], axis=1), dtype=tf.float32)

    def __call__(self, w):
        # Return the projected weights; the optimizer assigns them back to the
        # variable, so no in-place assign is needed here.
        return w * self.mask

    def get_config(self):
        # Lets Keras re-create the constraint when a saved full model is reloaded.
        return {"n": self.n, "h": self.h}

class SparsityKernel(tf.keras.constraints.Constraint):
    """Input-kernel constraint for a GRU: block sparsity plus pinned last-input row.

    The kernel has shape (m, 3*n): m inputs, n units, three gate blocks of n columns.
    Only the sub-blocks listed in `slices` are trainable (same pattern in every gate
    block). In addition, the last input row of the first two gate blocks (update and
    reset gates in Keras' z, r, h column order) is pinned to -1e30. An input of 1 on
    that channel therefore saturates both gates' sigmoid activations to 0.

    Args:
        m: input size.
        n: state (units) size.
        slices: array-like of rows [row_start, row_end, col_start, col_end]; each row
            enables kernel[row_start:row_end, col_start:col_end] in every gate block.
    """
    def __init__(self, m,n,slices):
        slices = np.asarray(slices)
        self.m = m
        self.n = n
        self.slices = slices.tolist()   # JSON-friendly copy for get_config
        ma1 = np.zeros([m, n], dtype=np.float32)
        for i in range(slices.shape[0]):
            ma1[slices[i,0]:slices[i,1], slices[i,2]:slices[i,3]] = 1.
        keep = np.concatenate([ma1, ma1, ma1], axis=1)
        ma2 = np.zeros([m, 3*n])
        ma2[m-1, n:2*n] = -1.e30
        ma2[m-1, 0:n] = -1.e30
        # Zero the pinned entries before the offset is added, so each call SETS them to
        # -1e30. Previously, when `slices` enabled the last row, every call added another
        # -1e30 (-1e30, -2e30, ...) until the value overflowed to -inf, after which
        # 0 * -inf = NaN whenever that input channel is 0.
        keep[m-1, 0:2*n] = 0.
        self.mask = tf.constant(keep, dtype=tf.float32)
        self.maskReset = tf.constant(ma2, dtype=tf.float32)

    def __call__(self, w):
        # Return the projected weights; the optimizer assigns them back to the variable.
        return w * self.mask + self.maskReset

    def get_config(self):
        # Lets Keras re-create the constraint when a saved full model is reloaded.
        return {"m": self.m, "n": self.n, "slices": self.slices}

class Reseter(tf.keras.constraints.Constraint):
    """Kernel constraint that pins kernel[m-1, 0:2*n] to -1e30 and leaves the rest free.

    Intended for a GRU input kernel of shape (m, 3*n): the last input row of the
    first two gate blocks is fixed; all other weights are unconstrained.

    Args:
        m: input size.
        n: state (units) size.
    """
    def __init__(self, m,n):
        self.m = m
        self.n = n
        ma = np.zeros([m, 3*n])
        ma[m-1, n:2*n] = -1.e30
        ma[m-1, 0:n] = -1.e30
        self.mask = tf.constant(ma, dtype=tf.float32)
        # 1 where the weight is free, 0 where it is pinned. Zeroing before adding the
        # offset SETS the pinned entries; the former `w + mask` added -1e30 on every
        # call, eventually overflowing to -inf (and 0 * -inf = NaN in the forward pass).
        self.keep = tf.constant((ma == 0).astype(np.float32))

    def __call__(self, w):
        # Return the projected weights; the optimizer assigns them back to the variable.
        return w * self.keep + self.mask

    def get_config(self):
        # Lets Keras re-create the constraint when a saved full model is reloaded.
        return {"m": self.m, "n": self.n}

class BiasConstraintCallback(tf.keras.callbacks.Callback):
    """After every training batch, tie a kernel row of each listed GRU to its input bias.

    For each layer, with L = number of units, the last kernel row restricted to
    columns 2*L:3*L (W[0][-1][2L:3L]) is overwritten with the negated first bias row
    of the same columns (-W[2][0][2L:3L]). For an input step whose only non-zero
    channel is the last one, with value 1, the input-side pre-activation of those
    columns (kernel row + input bias) is then exactly zero.

    Args:
        layer_names: names of the GRU layers to adjust. Defaults to ("GRU1MHR",).
    """
    def __init__(self, layer_names=("GRU1MHR",)):
        super().__init__()
        self.layer_names = tuple(layer_names)

    def on_train_batch_end(self, batch, logs=None):
        for layer_name in self.layer_names:
            layer = self.model.get_layer(layer_name)
            W = layer.get_weights()
            # Assumes the Keras GRU weight list [kernel, recurrent_kernel, bias] with a
            # two-row bias (input bias, recurrent bias), as with reset_after=True.
            L = len(W[0][-1]) // 3
            # An alternative that also subtracts the recurrent bias
            # (-W[2][1][2*L:3*L]) was tried and left disabled.
            # NOTE (translated from the original French): for CuDNN it would presumably
            # be W[2][2*L:3*L]-W[2][5*L:6*L], but the order of the 3 gate blocks still
            # needs to be checked.
            W[0][-1][2*L:3*L] = -W[2][0][2*L:3*L]
            layer.set_weights(W)
            

In [10]:

def create_model(batch_size,siglen):
    """Build the FSMHR network: one weight-shared GRU run forwards and backwards in time.

    Args:
        batch_size: fixed batch dimension of the input layer.
        siglen: number of time samples per recording.

    Returns:
        An uncompiled tf.keras.Model mapping inputs of shape (batch_size, siglen, 6)
        to a per-time-step sigmoid output of shape (batch_size, siglen, 1).
    """
    sig_in = lay.Input(batch_shape=(batch_size, siglen, 6), name="SigInput")

    # Stack the signal and its time-reversed copy on the batch axis, so the same GRU
    # weights process both reading directions in a single pass.
    both_dirs = lay.concatenate([sig_in, RevT()(sig_in)], axis=0)

    # Kernel sparsity: only input channels 2..5 connect to the 12 units. The recurrent
    # kernel is block-diagonal with groups of 4 units.
    kernel_slices = np.array([[2, 6, 0, 12]])
    gru_seq = grufn(12, return_sequences=True, recurrent_initializer='glorot_uniform',
                    stateful=False,
                    recurrent_constraint=Sparsity(4, 12),
                    kernel_constraint=SparsityKernel(6, 12, kernel_slices),
                    name='GRU1MHR')(both_dirs)
    gru_seq = lay.GaussianDropout(.2)(gru_seq)

    # Undo the stacking, flip the backward half back into forward time, and join the
    # two directions feature-wise at every time step.
    fwd_seq, bwd_seq = Split()(gru_seq)
    merged = lay.concatenate([fwd_seq, RevT()(bwd_seq)], axis=2)
    p_mat = lay.Dense(1, activation='sigmoid', name="DensePmat")(merged)

    return tf.keras.Model(inputs=[sig_in], outputs=[p_mat])

In [11]:
callbacks = [
    # TensorBoard run folder: logfolder/Mname.
    tf.keras.callbacks.TensorBoard(log_dir=logfolder + Mname, histogram_freq=0),
    # Keep a checkpoint whenever val_loss reaches a new minimum.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=basefolder + Mname + "/best-{epoch:05d}-{val_loss:.4f}.h5",
        save_best_only=True,
        verbose=1),
    # Periodic snapshot. Per the Keras docs, an integer save_freq counts batches,
    # not epochs.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=basefolder + Mname + "/save-{epoch:05d}.h5",
        save_freq=200,
        verbose=1),
]
def scheduler(epoch, lr=None):
    """Learning rate for `epoch`: 2e-3 * 10**(-epoch/200) + 1e-3.

    Starts at 3e-3 and decays towards a floor of 1e-3; the decaying part loses one
    decade every 200 epochs.

    Args:
        epoch: zero-based epoch index supplied by LearningRateScheduler.
        lr: current learning rate; ignored. Accepted so Keras versions that call
            schedule(epoch, lr) need no fallback.

    Returns:
        The learning rate as a plain Python float. The previous tf.Tensor return
        value is rejected by newer Keras LearningRateScheduler versions.
    """
    # 10 ** (-0.005 * epoch) == exp(-0.005 * ln(10) * epoch), the original formulation.
    return 2.e-3 * 10.0 ** (-.005 * epoch) + 1.e-3
    
callbacks.extend([
    # Per-epoch learning rate taken from scheduler().
    tf.keras.callbacks.LearningRateScheduler(scheduler),
    # Re-applies the GRU kernel/bias tie after every training batch.
    BiasConstraintCallback(),
])

In [13]:
initial_epoch = 0
# Build and compile inside the strategy scope so the model variables are placed by it.
with strategy.scope():
    # create_model takes the per-replica batch size.
    replica_batch = int(batch_size / strategy.num_replicas_in_sync)
    model = create_model(replica_batch, samp)
    # To resume a run, load a periodic checkpoint here, e.g.:
    # model.load_weights(basefolder+Mname+"/save-00000.h5")
    # The initial learning rate is overridden each epoch by the LearningRateScheduler callback.
    model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),
                  loss=weighted_binary_crossentropy,
                  metrics=[weighted_accuracy])

In [None]:
#One epoch to create logs folders for tensorboard
model.fit(traindataset,
    epochs=1,
    steps_per_epoch=np.floor(X_train.shape[0]/batch_size),
    validation_freq=1,
    initial_epoch=initial_epoch,
    validation_steps=1,
    validation_data=valdataset,
    callbacks=callbacks 
)

%load_ext tensorboard
%tensorboard --logdir logs/

In [None]:
# Main training run, continuing after the epoch completed above.
model.fit(traindataset,
    epochs=25000,
    # Keras expects an integer step count (np.floor would pass a float64).
    steps_per_epoch=X_train.shape[0] // batch_size,
    validation_freq=1,
    initial_epoch=initial_epoch+1,
    # One validation batch already covers the whole validation set.
    validation_steps=1,
    validation_data=valdataset,
    callbacks=callbacks
)
# Do not worry if val_loss stays poor (0.5) during the first ~500 epochs; it goes down afterwards.

In [None]:
#If you want to load and test the trained model

!wget https://github.com/utsb-fmm/FHRMA/raw/master/FS%20training%20python%20sources/FSMHR.h5
model.load_weights("FSMHR.h5")
model.evaluate(X_val,Y_val)