# Imports

In [1]:
import random
import sys

import importlib
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, metrics

import data_visualisation as dv
import data_augmentation as da
import models.resnet as resnet

2024-05-25 17:07:35.580271: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-25 17:07:36.049201: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-05-25 17:07:36.049260: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


# Load Dataset

In [2]:
# load data
thismodule = sys.modules[__name__]

with np.load('data/PTB_XL_HB_1s_window.npz', allow_pickle=True) as data:
    for k in data.keys():
        if 'text' in k:
            setattr(thismodule, k, data[k])
        else:
            setattr(thismodule, k, data[k].astype(float))

# Data Augmentation + Pair Generation

In [3]:
def mask_ecg(ecg, mask_ratio=0.1):
    block_size = 20
    
    for lead in range(ecg.shape[1]):
        for i in range(0, ecg.shape[0], block_size):
            if random.random() < mask_ratio:
                ecg[i:i+block_size, lead] = 0
    return ecg

def mask_lead(ecg, mask_ratio=0.1):

    for lead in range(ecg.shape[1]):
        if random.random() < mask_ratio:
            ecg[:, lead] = 0
    return ecg

In [4]:
def augment_ecg_signal(signal):
    drifted_signal, _ = da.add_random_baseline_drift(signal, strength_range=(1.5,2.5), drift_wavelength_range=(300,500))
    noised_drifted_signal = da.add_random_noise(drifted_signal, (0, 0.2))
    # res = mask_ecg(noised_drifted_signal, mask_ratio=0.1) if random.random() < 0.5 else mask_lead(noised_drifted_signal, mask_ratio=0.2)
    res = noised_drifted_signal
    return res

def generate_augmented_pairs(batch):
    augmented_batch_1 = np.array([augment_ecg_signal(x) for x in batch])
    augmented_batch_2 = np.array([augment_ecg_signal(x) for x in batch])
    return augmented_batch_1, augmented_batch_2

def augment_ecg_signal_batch(signals, labels, batch_size):
    while True:
        
        indices = np.random.randint(0, signals.shape[0], size=batch_size)
        batch = signals[indices]
        batch_labels = labels[indices]
        
        augmented_batch = np.array([augment_ecg_signal(sample) for sample in batch])
        
        yield (augmented_batch, batch_labels)

# NT-Xent Loss

In [5]:
def nt_xent_loss(x, temperature):
    # cosine similarity
    x_norm = tf.math.l2_normalize(x, axis=1)
    x_cos_sim = tf.matmul(x_norm, x_norm, transpose_b=True)
    
    # mask with -inf on diagonal
    mask = tf.eye(x.shape[0], dtype=tf.bool)
    x_cos_sim = tf.where(mask, -np.inf, x_cos_sim)  
    
    # set targets
    samples = x.shape[0]
    target = tf.range(samples)
    target = tf.concat([target[samples//2:], target[:samples//2]], axis=0)
    # target = tf.range(samples)
    # target = tf.where(target % 2 == 0, target + 1, target - 1)
    
    # temperature scaling
    x_cos_sim /= temperature
    
    # loss
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=x_cos_sim, labels=target)
    loss = tf.reduce_mean(loss)
    
    return loss

# MoCo class

In [13]:
# Define the MoCo Queue
class MoCo:
    def __init__(self, input_shape, projection_dim, queue_size=65536, temperature=0.1):
        self.projection_dim = projection_dim
        self.queue_size = queue_size
        self.temperature = temperature
        self.queue_ptr = 0

        # Initialize the feature extractor and projection head
        self.model = self.create_model(input_shape)

        # Initialize the queue
        self.queue = tf.random.normal([queue_size, projection_dim])

    def enqueue(self, features):
        batch_size = features.shape[0]
        replace_indices = tf.range(self.queue_ptr, self.queue_ptr + batch_size) % self.queue_size
        self.queue = tf.tensor_scatter_nd_update(self.queue, tf.expand_dims(replace_indices, 1), features)
        self.queue_ptr = (self.queue_ptr + batch_size) % self.queue_size

    def dequeue(self):
        return self.queue
    
    def _dequeue_and_enqueue(self, keys):
        keys = concat_all_gather(keys)
        batch_size = tf.shape(keys)[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size].assign(tf.transpose(keys))
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr.assign(ptr)
    
    def projection_head(self, X):
        X = keras.layers.Dense(self.projection_dim, activation='relu', name='dense_proj_1')(X)
        X = keras.layers.Dense(self.projection_dim, name='dense_proj_2')(X)
        return X
    
    def create_model(self, input_shape):
        X_input = keras.Input(input_shape)
        _, fe = resnet.model(X_input, num_classes=5, filters = [16, 16], kernels = [5, 3], layers=10, hidden_units=128)
        out = self.projection_head(fe)
        return keras.Model(inputs=X_input, outputs=out)
    
    def nt_xent_loss(self, queries, keys):
        logits = tf.matmul(queries, keys, transpose_b=True) / self.temperature
        batch_size = tf.shape(logits)[0]
        labels = tf.range(batch_size)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        return tf.reduce_mean(loss)

    @tf.function
    def training_step(self, batch, optimizer):
        with tf.GradientTape() as tape:
            queries = self.model(batch[0], training=True)
            keys = self.model(batch[1], training=True)

            # Normalize
            queries = tf.math.l2_normalize(queries, axis=1)
            keys = tf.math.l2_normalize(keys, axis=1)

            # Get queue keys
            queue_keys = self.dequeue()

            # Compute loss
            all_keys = tf.concat([keys, queue_keys], axis=0)
            loss = self.nt_xent_loss(queries, all_keys)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        # Update queue
        self.enqueue(keys)

        return loss
    
    def train(self, dataset, optimizer, epochs, steps_per_epoch):
        for epoch in range(epochs):
            for step, batch in enumerate(dataset):
                loss = self.training_step(batch, optimizer)
                if step % 100 == 0:
                    print(f"Epoch {epoch+1}, Step {step}/{steps_per_epoch}, Loss: {loss.numpy()}")

# Training

In [8]:
# ecg_dataset = tf.data.Dataset.from_tensor_slices(X_train)
# ecg_dataset = ecg_dataset.shuffle(buffer_size=1024).batch(32)

dataset = tf.data.Dataset.from_tensor_slices(X_train).batch(32)
augmented_dataset = dataset.map(lambda x: tf.py_function(generate_augmented_pairs, [x], [tf.float32, tf.float32])).prefetch(tf.data.AUTOTUNE)

In [9]:
input_shape = X_train.shape[1:]
projection_dim = 128
temperature = 0.1
queue_size = 65536
batch_size = 32
epochs = 10
steps_per_epoch = 1000

In [14]:
moco = MoCo(input_shape, projection_dim, queue_size, temperature)


In [15]:
optim = keras.optimizers.Adam(learning_rate=0.0001)
moco.train(augmented_dataset, optim, epochs, steps_per_epoch)

TypeError: <tf.Tensor 'TensorScatterUpdate:0' shape=(65536, 128) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'TensorScatterUpdate:0' shape=(65536, 128) dtype=float32> was defined here:
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
      app.start()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
      self.io_loop.start()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
      self.asyncio_loop.run_forever()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
      await result
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 359, in execute_request
      await super().execute_request(stream, ident, parent)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 446, in do_execute
      res = shell.run_cell(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
      result = self._run_cell(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
      result = runner(coro)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_22211/4228252922.py", line 2, in <module>
      moco.train(augmented_dataset, optim, epochs, steps_per_epoch)
    File "/tmp/ipykernel_22211/1281425041.py", line 71, in train
      loss = self.training_step(batch, optimizer)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 880, in __call__
      result = self._call(*args, **kwds)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 928, in _call
      self._initialize(args, kwds, add_initializers_to=initializers)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize
      self._variable_creation_fn    # pylint: disable=protected-access
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
      return self._maybe_define_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
      concrete_function = self._create_concrete_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
      func_graph_module.func_graph_from_py_func(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
      func_outputs = python_func(*func_args, **func_kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn
      out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 445, in bound_method_wrapper
      return wrapped_fn(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler
      return autograph.converted_call(
    File "/tmp/ipykernel_22211/1281425041.py", line 64, in training_step
      self.enqueue(keys)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 880, in __call__
      result = self._call(*args, **kwds)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 928, in _call
      self._initialize(args, kwds, add_initializers_to=initializers)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize
      self._variable_creation_fn    # pylint: disable=protected-access
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
      return self._maybe_define_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
      concrete_function = self._create_concrete_function(args, kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
      func_graph_module.func_graph_from_py_func(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
      func_outputs = python_func(*func_args, **func_kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn
      out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 445, in bound_method_wrapper
      return wrapped_fn(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler
      return autograph.converted_call(
    File "/tmp/ipykernel_22211/1281425041.py", line 19, in enqueue
      self.queue = tf.tensor_scatter_nd_update(self.queue, tf.expand_dims(replace_indices, 1), features)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1176, in op_dispatch_handler
      return dispatch_target(*args, **kwargs)
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py", line 6100, in tensor_scatter_nd_update
      return gen_array_ops.tensor_scatter_update(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py", line 11527, in tensor_scatter_update
      _, _, _op, _outputs = _op_def_library._apply_op_helper(
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 795, in _apply_op_helper
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 749, in _create_op_internal
      return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    File "/home/raaif/anaconda3/envs/FYP/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 3798, in _create_op_internal
      ret = Operation(

The tensor <tf.Tensor 'TensorScatterUpdate:0' shape=(65536, 128) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=enqueue, id=128257392638992), which is out of scope.