In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install third-party packages used below (tensorflow_addons, toolz, scikit-allel).
# NOTE(review): versions are not pinned, so a re-run may pick up newer releases.
!pip install tensorflow-addons
!pip install toolz scikit-allel

Collecting tensorflow-addons
  Downloading tensorflow_addons-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (612 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m612.1/612.1 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow-addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow-addons
Successfully installed tensorflow-addons-0.21.0 typeguard-2.13.3
Collecting scikit-allel
  Downloading scikit_allel-1.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.1/8.1 MB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-allel
Successfully installed scikit-allel-1.3.7


In [3]:
import numpy as np
%tensorflow_version 2.x
import tensorflow as tf
print("Tensorflow version " + tf.__version__)

Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.
Tensorflow version 2.12.0


## Setup

In [4]:
import os
# os.environ["MODIN_CPUS"] = "8"
# from distributed import Client
# client = Client()
import numpy as np
import math
import re
import random
import shutil
import gzip
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow_addons as tfa
from sklearn import metrics
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.applications import efficientnet as efn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from tensorflow.keras.constraints import Constraint
# import allel
from scipy.spatial.distance import squareform
%matplotlib inline
from toolz import interleave
from tqdm import tqdm
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LassoCV, ElasticNetCV
from sklearn.model_selection import KFold,StratifiedKFold

print("Tensorflow version " + tf.__version__)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



Tensorflow version 2.12.0


## Hardware Config

In [5]:
# Detect hardware and pick the matching tf.distribute strategy.
try:
    # TPUClusterResolver raises ValueError when no TPU can be resolved; that is used as the "no TPU" signal.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', TPU.master())
except ValueError:
    # NOTE(review): this branch is taken for any non-TPU runtime, so the message is also printed on CPU-only machines.
    print('Running on GPU')
    TPU = None

if TPU:
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.TPUStrategy(TPU)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

# Number of replicas the strategy keeps in sync; used later to scale the batch size.
N_REPLICAS = strategy.num_replicas_in_sync
# Number of computing cores, is 8 for a TPU V3-8
print(f'N_REPLICAS: {N_REPLICAS}')

Running on TPU  grpc://10.28.108.50:8470
N_REPLICAS: 8


## Prepare the data

In [6]:
class DataLoader:
    """
    Load a reference panel and a target set from (optionally gzipped) tab-separated
    VCF-style files and encode their genotypes as integer arrays.

    If the reference is unphased, cannot handle phased target data, so the valid (ref, target) combinations are:
    (phased, phased), (phased, unphased), (unphased, unphased)
    Important note: for each case, the model should be trained separately

    Attributes set by ``__init__``:
        reference_panel: DataFrame read from the reference file.
        target_set: target DataFrame re-indexed onto the reference variants (by ``ID``);
            variants absent from the target are filled with the missing genotype.
        is_phased: True only when both files use the ``|`` separator.
        VARIANT_COUNT: number of rows in the reference panel.
        SEQ_DEPTH / MISSING_VALUE: number of encoded classes (see ``__init__``).
        genotype_keys, replacement_dict, reverse_replacement_dict: genotype <-> index maps.
        hap_map, r_hap_map, map_preds_2_allele: allele <-> index maps (phased case only).
    """

    # Index of the first per-sample genotype column; the columns before it are copied
    # from the reference panel into the target frame.
    _FIRST_SAMPLE_COL = 9

    def __init__(self, reference_panel_file_path, target_file_path):
        """Read both files and build the genotype encoding tables.

        Args:
            reference_panel_file_path: path to the reference file ('.gz' is read with gzip).
            target_file_path: path to the target file ('.gz' is read with gzip).

        Raises:
            ValueError: if the target is phased while the reference is unphased.
        """
        self.map_values_1_vec = np.vectorize(self.map_hap_2_ind_parent_1)
        self.map_values_2_vec = np.vectorize(self.map_hap_2_ind_parent_2)
        first_sample = self._FIRST_SAMPLE_COL

        print("Rading the reference file...")
        self.ref_n_header_lines, self.ref_n_data_header = self._read_header(reference_panel_file_path)
        self.reference_panel = self._read_table(reference_panel_file_path, self.ref_n_data_header)
        self.VARIANT_COUNT = self.reference_panel.shape[0]
        print(f"{self.VARIANT_COUNT} variants found. Done!")

        print("Rading the target file...")
        self.target_n_header_lines, self.target_n_data_header = self._read_header(target_file_path)
        real_target_set = self._read_table(target_file_path, self.target_n_data_header)
        print(f"{real_target_set.shape[0]} variants found. Done!")

        # Phasing is detected from the first sample of the first variant. Using the first
        # sample column (rather than the second) keeps single-sample files working.
        target_is_phased = "|" in real_target_set.iloc[0, first_sample]
        ref_is_phased = "|" in self.reference_panel.iloc[0, first_sample]
        if target_is_phased and not ref_is_phased:
            raise ValueError("A phased target cannot be combined with an unphased reference panel.")
        self.is_phased = target_is_phased and ref_is_phased

        print("Creating the new target dataframe")
        # Right-merge on ID so the target has exactly one row per reference variant.
        self.target_set = real_target_set.merge(self.reference_panel["ID"], on='ID', how='right')
        fixed_columns = self.reference_panel.columns[:first_sample]
        self.target_set[fixed_columns] = self.reference_panel[fixed_columns]
        # Whatever is still NaN is a genotype the target did not provide.
        self.target_set = self.target_set.fillna(".|." if self.is_phased else "./.")

        print("Extracting genotype information...")
        SEP = "|" if self.is_phased else "/"

        def get_num_allels(g):
            v1, v2 = g.split(SEP)
            return max(int(v1), int(v2)) + 1

        def key_gen(v1, v2):
            return f"{v1}{SEP}{v2}"

        genotype_vals = np.unique(self.reference_panel.iloc[:, first_sample:].values)
        if target_is_phased != ref_is_phased:
            # Phased reference + unphased target: collapse the reference to unordered genotypes.
            phased_to_unphased_dict = {}
            for i in range(genotype_vals.shape[0]):
                key = genotype_vals[i]
                v1, v2 = [int(s) for s in key.split("|")]
                genotype_vals[i] = f"{min(v1, v2)}{SEP}{max(v1, v2)}"
                phased_to_unphased_dict[key] = genotype_vals[i]
            self.reference_panel = self.reference_panel.replace(phased_to_unphased_dict)
        genotype_vals = np.unique(genotype_vals)
        allele_count = max(map(get_num_allels, genotype_vals))
        if self.is_phased:
            # Per-haplotype encoding: allele string -> index, with "." mapped to allele_count.
            self.hap_map = {str(i): i for i in range(allele_count)}
            self.hap_map.update({".": allele_count})
            self.r_hap_map = {i: k for k, i in self.hap_map.items()}
            self.map_preds_2_allele = np.vectorize(lambda x: self.r_hap_map[x])
        self.MISSING_VALUE = self.SEQ_DEPTH = (allele_count + 1) if self.is_phased else (len(genotype_vals) + 1)
        if self.is_phased:
            self.genotype_keys = np.array([key_gen(i, j) for i in range(allele_count) for j in range(allele_count)])
        else:
            self.genotype_keys = genotype_vals
        self.genotype_keys = np.hstack([self.genotype_keys, [".|."] if self.is_phased else ["./."]])
        self.replacement_dict = {g: i for i, g in enumerate(self.genotype_keys)}
        self.reverse_replacement_dict = {i: g for g, i in self.replacement_dict.items()}

    @staticmethod
    def _read_header(file_path):
        """Return (list of leading '##' lines, first line that does not start with '##')."""
        _, ext = os.path.splitext(file_path)
        opener = gzip.open if ext == '.gz' else open
        meta_lines = []
        with opener(file_path, 'rt') as f_in:
            for line in f_in:
                if not line.startswith("##"):
                    return meta_lines, line
                meta_lines.append(line)
        # File ended before any non-'##' line was found.
        return meta_lines, ""

    @staticmethod
    def _read_table(file_path, column_header):
        """Read the data rows, naming the columns after the tab-separated header line."""
        return pd.read_csv(file_path,
                           comment='#',
                           sep='\t',
                           names=column_header.strip().split('\t'))

    def get_max_gap_len(self):
        """Return the longest run of consecutive missing genotypes in the first target sample."""
        self.target_set_vals = self.target_set.iloc[:, self._FIRST_SAMPLE_COL].values
        gap_meter = 0
        max_gap_so_far = 0
        for val in self.target_set_vals:
            if val in (".|.", "./."):
                gap_meter += 1
            else:
                max_gap_so_far = max(gap_meter, max_gap_so_far)
                gap_meter = 0
        # Account for a gap that runs to the end of the sequence.
        return max(gap_meter, max_gap_so_far)

    def map_hap_2_ind_parent_1(self, x):
        """Encode the allele before the '|' of a phased genotype string."""
        return self.hap_map[x.split('|')[0]]

    def map_hap_2_ind_parent_2(self, x):
        """Encode the allele after the '|' of a phased genotype string."""
        return self.hap_map[x.split('|')[1]]

    def __get_forward_data(self, data: pd.DataFrame):
        """Encode a (variants x samples) genotype frame as an int32 array with variants on axis 1.

        Phased data yields two rows per sample (the two haplotypes, interleaved);
        unphased data yields one row per sample of genotype indices.
        """
        if self.is_phased:
            # break it into haplotypes
            _x = np.empty((data.shape[1] * 2, data.shape[0]), dtype=np.int32)
            _x[0::2] = self.map_values_1_vec(data.values.T)
            _x[1::2] = self.map_values_2_vec(data.values.T)
            return _x
        return data.replace(self.replacement_dict).values.T.astype(np.int32)

    @staticmethod
    def _is_valid_range(starting_var_index, ending_var_index):
        """True when both indices are given and satisfy 0 <= start <= end."""
        if starting_var_index is None or ending_var_index is None:
            return False
        return starting_var_index >= 0 and ending_var_index >= starting_var_index

    def _encode_samples(self, frame, starting_var_index, ending_var_index):
        """Encode the sample columns of ``frame`` for the requested variant range (or all of it)."""
        first_sample = self._FIRST_SAMPLE_COL
        if self._is_valid_range(starting_var_index, ending_var_index):
            return self.__get_forward_data(frame.iloc[starting_var_index:ending_var_index, first_sample:])
        print("No variant indices provided or indices not valid, using the whole sequence...")
        return self.__get_forward_data(frame.iloc[:, first_sample:])

    def get_ref_set(self, starting_var_index=None, ending_var_index=None):
        """Return the encoded reference panel for variants [start, end), or all variants
        when the indices are omitted or invalid."""
        return self._encode_samples(self.reference_panel, starting_var_index, ending_var_index)

    def get_target_set(self, starting_var_index=None, ending_var_index=None):
        """Return the encoded target set for variants [start, end), or all variants
        when the indices are omitted or invalid."""
        return self._encode_samples(self.target_set, starting_var_index, ending_var_index)


In [7]:
# Location and file names of the reference (training) panel and the target (test) set.
# root_data_dir is concatenated directly with the file names below, so it must end with a path separator.
root_data_dir = '[data_path]'
train_file_name = "beadchip_reference_all_minaf_05_snps_hwe_1e-2_filtered_train.vcf.gz"
test_file_name = "test_data_beadchip_hwe_filtered.vcf.gz"

In [8]:
# Read both files and build the genotype encodings.
dl = DataLoader(root_data_dir + train_file_name,
                root_data_dir + test_file_name)
# Longest run of consecutive missing genotypes in the first target sample.
max_gap_len = dl.get_max_gap_len()

Rading the reference file...
31143 variants found. Done!
Rading the target file...
7336 variants found. Done!
Creating the new target dataframe
Extracting genotype information...


In [9]:
# Display the longest missing-genotype run computed above.
max_gap_len

310

## Hyperparams

In [10]:
# hyperparameters
inChannel = dl.SEQ_DEPTH  # one-hot depth of the model input
learning_rate = 0.001
weight_decay = 0.00001
dropout_rate = 0.25
attention_range = 100  # half-width (in variants) of the context added around each chunk
chunk_size = 2000  # number of variants handled by each chunk sub-model
max_features_len_per_model = dl.VARIANT_COUNT  # a single model spans every variant
# Display the two derived values.
inChannel, max_features_len_per_model

(3, 31143)

In [11]:
# Display the genotype strings known to the loader (the last entry is the missing genotype).
dl.genotype_keys

array(['0|0', '0|1', '1|0', '1|1', '.|.'], dtype='<U3')

In [12]:
# Display MISSING_VALUE (assigned together with SEQ_DEPTH in DataLoader.__init__).
dl.MISSING_VALUE

3

## Convert to tensorflow dataset

In [13]:
@tf.function()
def add_attention_mask(X_sample, y_sample):
  """Randomly mask 80% of the positions of one sample and one-hot encode input and label.

  Masked positions are overwritten with the class index ``dl.SEQ_DEPTH - 1``.
  Assumes X_sample is a 1-D int32 tensor with a statically known length — TODO confirm.

  Returns:
    (one-hot masked input with depth ``SEQ_DEPTH``, one-hot label with depth ``SEQ_DEPTH - 1``).
  """
  depth = dl.SEQ_DEPTH
  # Number of positions to mask: 80% of the static sequence length, truncated to int.
  mask_size = tf.cast(X_sample.shape[0]*0.8, dtype=tf.int32)
  # Random subset of positions, shaped (mask_size, 1) as required by tensor_scatter_nd_update.
  mask_idx = tf.reshape(tf.random.shuffle(tf.range(X_sample.shape[0]))[:mask_size], (-1, 1))
  # One update per masked position, all equal to depth-1.
  updates = tf.math.add(tf.zeros(shape=(mask_idx.shape[0]), dtype=tf.int32), depth-1)
  X_masked = tf.tensor_scatter_nd_update(X_sample, mask_idx, updates)

  # The label uses one class fewer than the input (no class for index depth-1).
  return tf.one_hot(X_masked, depth), tf.one_hot(y_sample, depth-1)

@tf.function()
def onehot_encode(X_sample):
  """One-hot encode one sample with depth ``dl.SEQ_DEPTH`` (no masking applied)."""
  return tf.one_hot(X_sample, dl.SEQ_DEPTH)

In [14]:
def get_dataset(x, batch_size, offset_before=0, offset_after=0, training=True):
  """Build the (masked input, label) tf.data pipeline used for fitting and validation.

  The label is ``x`` with ``offset_before`` / ``offset_after`` columns trimmed from its
  two ends. With ``training=True`` the samples are shuffled and repeated indefinitely.
  ``add_attention_mask`` is applied in both modes. Incomplete final batches are dropped.
  """
  autotune = tf.data.AUTOTUNE
  labels = x[:, offset_before:x.shape[1]-offset_after]
  pipeline = tf.data.Dataset.from_tensor_slices((x, labels))

  if training:
    pipeline = pipeline.shuffle(x.shape[0], reshuffle_each_iteration=True).repeat()

  # Mask and one-hot encode each sample.
  pipeline = pipeline.map(add_attention_mask, num_parallel_calls=autotune, deterministic=False)
  # Prefetch mapped samples ahead of batching.
  pipeline = pipeline.prefetch(autotune)
  return pipeline.batch(batch_size, drop_remainder=True, num_parallel_calls=autotune)

def get_test_dataset(x, batch_size):
  """Build the inference pipeline: one-hot encode each sample of ``x`` and batch.

  No masking or shuffling is applied, the order is deterministic, and the final
  partial batch is kept.
  """
  autotune = tf.data.AUTOTUNE
  pipeline = tf.data.Dataset.from_tensor_slices((x))
  # One-hot encode each sample.
  pipeline = pipeline.map(onehot_encode, num_parallel_calls=autotune, deterministic=True)
  # Prefetch mapped samples ahead of batching.
  pipeline = pipeline.prefetch(autotune)
  return pipeline.batch(batch_size, drop_remainder=False, num_parallel_calls=autotune)

## Custom Layers


In [15]:
# Axis passed as `attention_axes` to every MultiHeadAttention layer below.
# Note: `(1)` is the plain integer 1, not a one-element tuple.
ATTENTION_AXES=(1)

In [16]:
class CrossAttentionLayer(layers.Layer):
  """Cross-attention block: a slice of the local input queries the global input.

  Called with ``[local, global]``. Both are layer-normalised; the local tensor is
  cropped along axis 1 by ``start_offset`` / ``end_offset`` and used as the query, the
  global tensor serves as key and value. The attention output gets a residual
  connection, a layer norm, and a two-layer feed-forward network with a second residual.

  ``global_dim`` and ``dropout_rate`` are stored but not used by this layer.
  """

  def __init__(self, local_dim, global_dim,
               start_offset=0, end_offset=0,
               activation=tf.nn.gelu, dropout_rate=0.1,
               n_heads=8):
    super().__init__()
    self.local_dim = local_dim
    self.global_dim = global_dim
    self.dropout_rate = dropout_rate
    self.activation = activation
    self.start_offset = start_offset
    self.end_offset = end_offset
    self.num_heads = n_heads
    # Sub-layers are created in a fixed order so auto-generated layer names stay stable.
    self.layer_norm00 = layers.LayerNormalization()
    self.layer_norm01 = layers.LayerNormalization()
    self.layer_norm1 = layers.LayerNormalization()
    feed_forward_layers = [
        layers.Dense(self.local_dim//2, activation=self.activation),
        layers.Dense(self.local_dim, activation=self.activation),
    ]
    self.ffn = tf.keras.Sequential(feed_forward_layers)
    self.add0 = layers.Add()
    self.add1 = layers.Add()
    self.attention = layers.MultiHeadAttention(
        num_heads=self.num_heads,
        key_dim=self.local_dim,
        attention_axes=ATTENTION_AXES,
    )

  def call(self, inputs, training):
    normed_local = self.layer_norm00(inputs[0])
    normed_global = self.layer_norm01(inputs[1])
    stop = normed_local.shape[1] - self.end_offset
    query = normed_local[:, self.start_offset:stop, :]

    # Query attends over the global representation (used as both key and value).
    attended = self.attention(query, normed_global, normed_global)
    # Residual connection 1, then layer norm.
    attended = self.layer_norm1(self.add0([attended, query]))
    # Feed-forward network with residual connection 2.
    return self.add1([self.ffn(attended), attended])

class MaskedTransformerBlock(layers.Layer):
  """Stack of self-attention blocks whose last block crops the sequence.

  The first ``attn_block_repeats - 1`` blocks are full-length pre-norm self-attention
  blocks. In the final block only positions
  ``[start_offset, length - end_offset)`` are used as queries (keys and values still
  span the whole input), so the output is shorter than the input by
  ``start_offset + end_offset`` along axis 1.

  Args:
    embed_dim: per-head key dimension and output width of the feed-forward network.
    num_heads: number of attention heads.
    ff_dim: hidden width of the feed-forward network.
    attention_range: half-width of the band used to build ``self.attention_mask``.
    start_offset, end_offset: positions cropped from the two ends by the final block.
    attn_block_repeats: total number of attention blocks.
    activation: activation of the feed-forward layers.
    dropout_rate: stored only; no dropout is applied by this layer.
    use_ffn: whether each block has a feed-forward network.

  NOTE(review): ``build`` creates a banded ``self.attention_mask`` but ``call`` never
  passes it to the attention layers, so attention is not actually restricted to the band.
  """

  def __init__(self, embed_dim, num_heads, ff_dim, attention_range, start_offset=0, end_offset=0, attn_block_repeats=1, activation=tf.nn.gelu, dropout_rate=0.1, use_ffn=True):
    super(MaskedTransformerBlock, self).__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.ff_dim = ff_dim
    self.start_offset = start_offset
    self.end_offset = end_offset
    self.attention_range = attention_range
    self.attn_block_repeats = attn_block_repeats
    self.activation = activation
    self.dropout_rate = dropout_rate
    self.use_ffn = use_ffn
    self.att0 = [layers.MultiHeadAttention(num_heads=self.num_heads,
                                           key_dim=self.embed_dim,
                                           attention_axes=ATTENTION_AXES) for _ in range(attn_block_repeats)]
    if self.use_ffn:
      self.ffn = [tf.keras.Sequential(
          [
            layers.Dense(self.ff_dim, activation=self.activation),
            layers.Dense(self.embed_dim, activation=self.activation),
          ]
      ) for _ in range(attn_block_repeats)]
    self.layer_norm0 = [layers.LayerNormalization() for _ in range(attn_block_repeats)]
    self.layer_norm1 = [layers.LayerNormalization() for _ in range(attn_block_repeats)]

  def build(self, input_shape):
    assert(self.end_offset >= 0)
    self.feature_size = input_shape[1]
    # Banded mask: query row i may see key columns in [i - range, i + range), for the
    # query rows kept by the final block only. Built with broadcasting rather than a
    # Python loop over rows.
    query_rows = np.arange(self.start_offset, self.feature_size - self.end_offset)[:, None]
    key_cols = np.arange(self.feature_size)[None, :]
    band = (key_cols >= query_rows - self.attention_range) & (key_cols < query_rows + self.attention_range)
    self.attention_mask = tf.constant(band)

  def call(self, inputs, training):
    x = inputs
    # Full-length blocks.
    for i in range(self.attn_block_repeats-1):
      x = self.layer_norm0[i](x)
      attn_output = self.att0[i](x, x)
      out1 = self.layer_norm1[i](x + attn_output)
      x = out1 + self.ffn[i](out1) if self.use_ffn else out1

    # Final block: feed it the output of the preceding blocks. (Previously it
    # re-normalised `inputs`, silently discarding every earlier block; with
    # attn_block_repeats == 1 the two are the same tensor.)
    x = self.layer_norm0[-1](x)
    query = x[:, self.start_offset:x.shape[1]-self.end_offset, :]
    attn_output = self.att0[-1](query, x)
    out1 = self.layer_norm1[-1](query + attn_output)
    if self.use_ffn:
      return out1 + self.ffn[-1](out1)
    return out1

class GenoEmbeddings(layers.Layer):
  """Project one-hot genotypes to ``embedding_dim`` and add a learned position embedding.

  The last input axis is taken as the number of allele classes and the second-to-last
  as the number of positions. ``call`` expects a rank-3 input (einsum ``'ijk,kl->ijl'``).
  """

  def __init__(self, embedding_dim,
               embeddings_initializer='glorot_uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)

  def build(self, input_shape):
    self.num_of_allels = input_shape[-1]
    self.n_snps = input_shape[-2]
    # One learned vector per position.
    self.position_embedding = layers.Embedding(
        input_dim=self.n_snps,
        output_dim=self.embedding_dim,
    )
    # One learned vector per allele class.
    self.embedding = self.add_weight(
        shape=(self.num_of_allels, self.embedding_dim),
        initializer=self.embeddings_initializer,
        trainable=True,
        name='geno_embeddings',
        regularizer=self.embeddings_regularizer,
        constraint=self.embeddings_constraint,
        experimental_autocast=False,
    )
    self.positions = tf.range(start=0, limit=self.n_snps, delta=1)

  def call(self, inputs):
    # The genotype-only embedding is kept on the instance, as before.
    self.immediate_result = tf.einsum('ijk,kl->ijl', inputs, self.embedding)
    positional = self.position_embedding(self.positions)
    return self.immediate_result + positional


class Chunker(layers.Layer):
  """Split the sequence axis into chunks, run a ``SelfAttnChunk`` on each, and re-join.

  Each chunk of ``chk_size`` positions is extended by up to ``attention_range``
  positions on both sides (clipped to the sequence bounds) before it is passed to its
  own ``SelfAttnChunk``; the per-chunk start/end offsets tell that layer how many
  positions to crop again. Outputs are concatenated along axis -2.

  Args:
    embed_dim, num_heads, ff_dim: forwarded to every ``SelfAttnChunk``.
    chk_size: number of positions per chunk (defaults to the module-level ``chunk_size``).
    activation, dropout_rate: stored only.
    attn_block_repeats: forwarded to every ``SelfAttnChunk``.
    attention_range: context added around each chunk (defaults to the module-level
      ``attention_range``).
    include_embedding_layer: forwarded to every ``SelfAttnChunk``.
  """

  def __init__(self, embed_dim, num_heads, ff_dim, chk_size=chunk_size,
               activation=tf.nn.gelu, dropout_rate=0.25, attn_block_repeats=1,
               attention_range=attention_range, include_embedding_layer=False):
    super(Chunker, self).__init__()
    self.concat = layers.Concatenate(axis=-2)
    self.chunk_size = chk_size
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.ff_dim = ff_dim
    self.activation = activation
    self.dropout_rate = dropout_rate
    self.attention_range = attention_range
    self.attn_block_repeats = attn_block_repeats
    self.include_embedding_layer = include_embedding_layer

  def build(self, input_shape):
    seq_len = input_shape[1]
    self.chunk_starts = list(range(0, seq_len, self.chunk_size))
    self.chunk_ends = [min(cs + self.chunk_size, seq_len) for cs in self.chunk_starts]
    self.mask_starts = [max(0, cs - self.attention_range) for cs in self.chunk_starts]
    self.mask_ends = [min(ce + self.attention_range, seq_len) for ce in self.chunk_ends]
    # Use this layer's own attention_range (not the module-level global) so the
    # sub-layers agree with the offsets computed above when a non-default value is given.
    self.chunkers = [SelfAttnChunk(self.embed_dim, self.num_heads, self.ff_dim,
                                   self.attention_range,
                                   include_embedding_layer=self.include_embedding_layer,
                                   start_offset=cs - self.mask_starts[i],
                                   end_offset=self.mask_ends[i] - self.chunk_ends[i],
                                   attn_block_repeats=self.attn_block_repeats)
                     for i, cs in enumerate(self.chunk_starts)]

  def call(self, inputs, training):
    chunks = [chunker(inputs[:, self.mask_starts[i]:self.mask_ends[i]])
              for i, chunker in enumerate(self.chunkers)]
    return self.concat(chunks)


class SelfAttnChunk(layers.Layer):
  """Optional ``GenoEmbeddings`` followed by a single ``MaskedTransformerBlock``.

  ``attn_block_repeats`` is stored on the instance, but the inner block is always
  created with ``attn_block_repeats=1``.
  """

  def __init__(self, embed_dim, num_heads, ff_dim, attention_range,
               start_offset=0, end_offset=0,
               attn_block_repeats=1,
               include_embedding_layer=False):
    super().__init__()
    self.attention_range = attention_range
    self.ff_dim = ff_dim
    self.num_heads = num_heads
    self.embed_dim = embed_dim
    self.attn_block_repeats = attn_block_repeats
    self.include_embedding_layer = include_embedding_layer

    self.attention_block = MaskedTransformerBlock(
        self.embed_dim, self.num_heads, self.ff_dim,
        attention_range, start_offset, end_offset,
        attn_block_repeats=1,
    )
    if include_embedding_layer:
      self.embedding = GenoEmbeddings(embed_dim)

  def build(self, input_shape):
    """Nothing to build here; sub-layers build themselves on first call."""

  def call(self, inputs, training):
    features = self.embedding(inputs) if self.include_embedding_layer else inputs
    return self.attention_block(features)

class CrossAttnChunk(layers.Layer):
  """Thin wrapper that builds a ``CrossAttentionLayer`` once the input widths are known.

  Called with ``[local, global]``; the last dimension of each input determines the
  ``local_dim`` / ``global_dim`` of the wrapped layer.
  """

  def __init__(self, start_offset=0, end_offset=0, n_heads = 8):
    super().__init__()
    # Taken from the module-level `attention_range`; there is no constructor argument for it.
    self.attention_range = attention_range
    self.start_offset = start_offset
    self.end_offset = end_offset
    self.n_heads = n_heads

  def build(self, input_shape):
    local_shape, global_shape = input_shape[0], input_shape[1]
    self.local_dim = local_shape[-1]
    self.global_dim = global_shape[-1]
    self.attention_block = CrossAttentionLayer(
        self.local_dim, self.global_dim,
        self.start_offset, self.end_offset,
        n_heads=self.n_heads,
    )

  def call(self, inputs, training):
    return self.attention_block(inputs)


## Modules

In [17]:
class ConvBlock(layers.Layer):
  """Two parallel Conv1D branches merged by addition, followed by conv/BN/depthwise/BN/GELU.

  A kernel-3 convolution feeds two branches (kernels 5 then 7, and kernels 7 then 15);
  the branch outputs are summed and passed through a kernel-3 convolution, batch norm,
  a depthwise convolution, a second batch norm, and a GELU activation. All Conv1D
  layers use ``embed_dim`` filters, 'same' padding, and GELU.
  """

  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.const = None  # kernel constraint shared by all Conv1D layers (none)

    def gelu_conv(kernel_size):
      return layers.Conv1D(embed_dim, kernel_size, padding='same',
                           activation=tf.nn.gelu,
                           kernel_constraint=self.const)

    # Creation order is kept fixed so auto-generated layer names stay stable.
    self.conv000 = gelu_conv(3)
    self.conv010 = gelu_conv(5)
    self.conv011 = gelu_conv(7)
    self.conv020 = gelu_conv(7)
    self.conv021 = gelu_conv(15)
    self.add = layers.Add()
    self.conv100 = gelu_conv(3)
    self.bn0 = layers.BatchNormalization()
    self.bn1 = layers.BatchNormalization()
    # NOTE(review): DepthwiseConv1D's first positional argument is kernel_size, so this
    # appears to create a kernel of size embed_dim with stride 1 — confirm this is intended.
    self.dw_conv = layers.DepthwiseConv1D(embed_dim, 1, padding='same')
    self.activation = layers.Activation(tf.nn.gelu)

  def call(self, inputs, training):
    # Could add skip connection here?
    stem = self.conv000(inputs)
    narrow_branch = self.conv011(self.conv010(stem))
    wide_branch = self.conv021(self.conv020(stem))

    merged = self.add([narrow_branch, wide_branch])
    out = self.conv100(merged)
    out = self.bn0(out)
    out = self.dw_conv(out)
    out = self.bn1(out)
    return self.activation(out)

def chunk_module(embed_dim, num_heads, input_len, input_channels, attention_range,
               start_offset=0, end_offset=0,
               attn_block_repeats=1, include_embedding=False):
  """Build the functional sub-model that processes one chunk of embedded positions.

  Input shape is ``(input_len, embed_dim)``. The chunk goes through a cropping
  ``SelfAttnChunk``, then two paths of ``ConvBlock``s — one of which cross-attends back
  to the self-attention output — whose results are concatenated on the channel axis.

  ``input_channels``, ``attn_block_repeats`` and ``include_embedding`` are accepted but
  not used by this function.
  """
  projection_dim = embed_dim
  chunk_input = layers.Input(shape=(input_len, embed_dim))

  attended = SelfAttnChunk(projection_dim, num_heads, projection_dim//2, attention_range,
                           start_offset, end_offset, 1, include_embedding_layer=False)(chunk_input)

  features = ConvBlock(projection_dim)(attended)
  skip_path = ConvBlock(projection_dim)(features)

  main_path = layers.Dense(projection_dim, activation=tf.nn.gelu)(features)
  main_path = ConvBlock(projection_dim)(main_path)
  # Cross-attend to the self-attention output, without further cropping.
  main_path = CrossAttnChunk(0, 0)([main_path, attended])
  main_path = layers.Dropout(0.25)(main_path)
  main_path = ConvBlock(projection_dim)(main_path)

  combined = layers.Concatenate(axis=-1)([skip_path, main_path])
  return keras.Model(inputs=chunk_input, outputs=combined)

## Model

In [18]:
class SplitTransformer(keras.Model):
  """Chunked model: embed, run one ``chunk_module`` per chunk, concatenate, predict.

  The sequence is split into chunks of ``chunk_size`` positions, each extended by up to
  ``attention_range`` positions of context on both sides. The chunk outputs are
  concatenated along the sequence axis, passed through two Conv1D layers (the last one
  with a softmax over ``inChannel - 1`` classes), and finally cropped by
  ``offset_before`` / ``offset_after`` positions.

  ``activation``, ``dropout_rate`` and ``attn_block_repeats`` are stored but not used
  when the sub-layers are built.
  """

  def __init__(
      self,
      embed_dim,
      num_heads,
      offset_before=0,
      offset_after=0,
      chunk_size=chunk_size,
      activation=tf.nn.gelu,
      dropout_rate=0.25,
      attn_block_repeats=1,
      attention_range=attention_range):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.chunk_size = chunk_size
    self.activation = activation
    self.dropout_rate = dropout_rate
    self.attn_block_repeats = attn_block_repeats
    self.attention_range = attention_range
    self.offset_before = offset_before
    self.offset_after = offset_after

  def build(self, input_shape):
    total_len = input_shape[1]
    self.seq_len = total_len
    self.chunk_starts = list(range(0, total_len, self.chunk_size))
    self.chunk_ends = [min(start + self.chunk_size, total_len) for start in self.chunk_starts]
    # Each chunk sees up to attention_range extra positions on both sides.
    self.mask_starts = [max(0, start - self.attention_range) for start in self.chunk_starts]
    self.mask_ends = [min(end + self.attention_range, total_len) for end in self.chunk_ends]

    chunk_models = []
    for idx, start in enumerate(self.chunk_starts):
      chunk_models.append(
          chunk_module(self.embed_dim, self.num_heads,
                       self.mask_ends[idx] - self.mask_starts[idx],
                       inChannel, self.attention_range,
                       start_offset=start - self.mask_starts[idx],
                       end_offset=self.mask_ends[idx] - self.chunk_ends[idx],
                       attn_block_repeats=1, include_embedding=True))
    self.chunkers = chunk_models

    self.concat_layer = layers.Concatenate(axis=-2)
    self.embedding = GenoEmbeddings(self.embed_dim)
    self.slice_layer = layers.Lambda(lambda x: x[:, self.offset_before:self.seq_len-self.offset_after], name="output_slicer")
    self.after_concat_layer = layers.Conv1D(self.embed_dim//2, 5, padding='same', activation=tf.nn.gelu)
    self.last_conv = layers.Conv1D(inChannel - 1, 5, padding='same', activation=tf.nn.softmax)
    super().build(input_shape)

  def call(self, inputs):
    embedded = self.embedding(inputs)
    # Run every chunk model on its (context-extended) window of the sequence.
    chunk_outputs = [
        self.chunkers[idx](embedded[:, self.mask_starts[idx]:self.mask_ends[idx]])
        for idx in range(len(self.chunkers))
    ]
    merged = self.concat_layer(chunk_outputs)
    merged = self.after_concat_layer(merged)
    probabilities = self.last_conv(merged)
    return self.slice_layer(probabilities)


In [19]:
class MyCustomLoss(tf.keras.losses.Loss):
  """Sum of categorical cross-entropy and KL divergence, each reduced with SUM."""

  def call(self, y_true, y_pred):
    predictions = tf.convert_to_tensor(y_pred)
    targets = tf.cast(y_true, predictions.dtype)
    sum_reduction = tf.keras.losses.Reduction.SUM

    cross_entropy = tf.keras.losses.CategoricalCrossentropy(reduction=sum_reduction)
    cat_term = cross_entropy(targets, predictions)

    kl_divergence = tf.keras.losses.KLDivergence(reduction=sum_reduction)
    kl_term = kl_divergence(targets, predictions)

    return cat_term + kl_term

In [20]:
def create_model(offset_before=0, offset_after=0):
  """Create and compile a ``SplitTransformer`` (128-dim embeddings, 16 heads).

  The model is compiled with the LAMB optimizer at the module-level ``learning_rate``,
  ``MyCustomLoss`` and categorical accuracy. ``offset_before`` / ``offset_after`` are
  forwarded to the model and control how many output positions it crops at each end.
  """
  split_transformer = SplitTransformer(
      embed_dim=128,
      num_heads=16,
      attn_block_repeats=1,
      chunk_size=chunk_size,
      activation="gelu",
      offset_before=offset_before,
      offset_after=offset_after,
  )
  lamb = tfa.optimizers.LAMB(learning_rate=learning_rate)
  split_transformer.compile(lamb, loss=MyCustomLoss(), metrics=tf.keras.metrics.CategoricalAccuracy())
  return split_transformer


In [21]:
# Build the model once on the full variant length to inspect its layers and parameter counts.
# The output is cropped by 2*attention_range positions at both ends.
model = create_model(offset_before=2*attention_range, offset_after=2*attention_range)
model.build((1, max_features_len_per_model, inChannel))
model.summary()

Model: "split_transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 model (Functional)          (None, 2000, 256)         4327936   
                                                                 
 model_1 (Functional)        (None, 2000, 256)         4327936   
                                                                 
 model_2 (Functional)        (None, 2000, 256)         4327936   
                                                                 
 model_3 (Functional)        (None, 2000, 256)         4327936   
                                                                 
 model_4 (Functional)        (None, 2000, 256)         4327936   
                                                                 
 model_5 (Functional)        (None, 2000, 256)         4327936   
                                                                 
 model_6 (Functional)        (None, 2000, 256)   

In [22]:
# Quantity monitored by the training callbacks.
METRIC = "val_loss"

def create_callbacks(kfold=0, metric = METRIC):
    """Return the training callbacks: learning-rate reduction and early stopping.

    Both callbacks monitor ``metric``. The learning rate is halved after 3 epochs
    without improvement; training stops after 10 such epochs and the best weights are
    restored. ``kfold`` is accepted but not used.
    """
    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=metric,
        mode='auto',
        factor=0.5,
        patience=3,
        verbose=0,
    )
    stopper = tf.keras.callbacks.EarlyStopping(
        monitor=metric,
        mode='auto',
        patience=10,
        verbose=1,
        restore_best_weights=True,
    )
    return [lr_schedule, stopper]

## Training

In [23]:
# Base path for the per-window prediction files saved by the training loop.
save_dir = "[save_path]"

# exist_ok=True keeps the cell idempotent on re-run without a racy
# exists()-then-create check; it still raises if the path exists but is not
# a directory, instead of silently continuing.
os.makedirs(save_dir, exist_ok=True)

In [24]:
# Base batch size that is multiplied by the replica count on TPU. A TPU v3-8
# has 8 computing cores, so the global batch size there is 2 x 8 = 16
# (matching the value displayed below).
BATCH_SIZE_BASE = 2

# Training configuration: scale with N_REPLICAS on TPU, otherwise use a fixed 32.
if TPU:
  BATCH_SIZE = BATCH_SIZE_BASE * N_REPLICAS
else:
  BATCH_SIZE = 32
BATCH_SIZE

16

In [25]:
# Display the loader's VARIANT_COUNT; the cells below use it as the upper
# bound when computing the window break points.
dl.VARIANT_COUNT

31143

## Timing test

The first call to `predict` takes longer (possibly because the model is rebuilt on the local machine), but subsequent calls are much faster.

In [26]:
# 128 - 32
# Timing run: same windowed pipeline as the full training loop, but every
# window is fitted for a single epoch and nothing is written to disk, so that
# the duration of the first predict() call can be observed.
NUM_EPOCHS = 1000  # not used in this cell: fit() below runs exactly 1 epoch
results = None

# Split row indices (not the data itself) once, so that every window uses the
# same 90/10 train/validation partition.
x_train_indices, x_valid_indices = train_test_split(
    range(dl.get_ref_set(0, 1).shape[0]),
    test_size=0.10,
    random_state=2022,
    shuffle=True,
)
steps_per_epoch = len(x_train_indices)//BATCH_SIZE
validation_steps = len(x_valid_indices)//BATCH_SIZE

# Window edges: 0, L, 2L, ... with the total variant count as the closing edge.
break_points = list(np.arange(0, dl.VARIANT_COUNT, max_features_len_per_model)) + [dl.VARIANT_COUNT]

for w in range(len(break_points)-1):
  print(f"Doing part {w} out of {len(break_points)-2}")
  # Widen the window by 2*attention_range on each side, clamped to the data bounds.
  final_start_pos = max(0, break_points[w]-2*attention_range)
  final_end_pos = min(dl.VARIANT_COUNT, break_points[w+1]+2*attention_range)
  # Padding actually applied on each side (smaller at the edges because of the clamp).
  offset_before = break_points[w] - final_start_pos
  offset_after = final_end_pos - break_points[w+1]
  ref_set = dl.get_ref_set(final_start_pos, final_end_pos).astype(np.int32)
  print(f"Data shape: {ref_set.shape}")
  train_dataset = get_dataset(ref_set[x_train_indices], BATCH_SIZE,
                              offset_before=offset_before,
                              offset_after=offset_after)
  valid_dataset = get_dataset(ref_set[x_valid_indices], BATCH_SIZE,
                              offset_before=offset_before,
                              offset_after=offset_after, training=False)
  # The datasets hold what they need; drop the full array before building the model.
  del ref_set
  K.clear_session()
  # BATCH_SIZE treats TPU as possibly falsy, so only reset the TPU system
  # when a TPU is actually configured.
  if TPU:
    tf.tpu.experimental.initialize_tpu_system(TPU)
  callbacks = create_callbacks()
  with strategy.scope():
    model = create_model(offset_before=offset_before,
                         offset_after=offset_after)
    history = model.fit(train_dataset, steps_per_epoch=steps_per_epoch, epochs=1,
                        validation_data=valid_dataset,
                        validation_steps=validation_steps,
                        callbacks=callbacks, verbose=1)

  # Predict on the target set for the same (padded) window; `model` and
  # `test_dataset` are reused by the next cell to time a second predict() call.
  test_dataset_np = dl.get_target_set(final_start_pos, final_end_pos).astype(np.int32)
  test_dataset = get_test_dataset(test_dataset_np, BATCH_SIZE*4)
  predict_onehot = model.predict(test_dataset, verbose=1)


Doing part 0 out of 0
Data shape: (4808, 31143)






In [27]:
# Second predict() call on the `model` and `test_dataset` left over from the
# previous cell, for comparison with the slower first call.
predict_onehot = model.predict(test_dataset, verbose=1)



## Training loop

In [None]:
# 128 - 32
# Full training run: for every window of variants, train a fresh model,
# predict the target set for that window and save the predictions to disk.
NUM_EPOCHS = 1000  # upper bound; the EarlyStopping callback may stop training earlier
results = None

# Split row indices (not the data itself) once, so that every window uses the
# same 90/10 train/validation partition.
x_train_indices, x_valid_indices = train_test_split(
    range(dl.get_ref_set(0, 1).shape[0]),
    test_size=0.10,
    random_state=2022,
    shuffle=True,
)
steps_per_epoch = len(x_train_indices)//BATCH_SIZE
validation_steps = len(x_valid_indices)//BATCH_SIZE

# Window edges: 0, L, 2L, ... with the total variant count as the closing edge.
break_points = list(np.arange(0, dl.VARIANT_COUNT, max_features_len_per_model)) + [dl.VARIANT_COUNT]

for w in range(len(break_points)-1):
  print(f"Doing part {w} out of {len(break_points)-2}")
  # Widen the window by 2*attention_range on each side, clamped to the data bounds.
  final_start_pos = max(0, break_points[w]-2*attention_range)
  final_end_pos = min(dl.VARIANT_COUNT, break_points[w+1]+2*attention_range)
  # Padding actually applied on each side (smaller at the edges because of the clamp).
  offset_before = break_points[w] - final_start_pos
  offset_after = final_end_pos - break_points[w+1]
  ref_set = dl.get_ref_set(final_start_pos, final_end_pos).astype(np.int32)
  print(f"Data shape: {ref_set.shape}")
  train_dataset = get_dataset(ref_set[x_train_indices], BATCH_SIZE,
                              offset_before=offset_before,
                              offset_after=offset_after)
  valid_dataset = get_dataset(ref_set[x_valid_indices], BATCH_SIZE,
                              offset_before=offset_before,
                              offset_after=offset_after, training=False)
  # The datasets hold what they need; drop the full array before building the model.
  del ref_set
  K.clear_session()
  # BATCH_SIZE treats TPU as possibly falsy, so only reset the TPU system
  # when a TPU is actually configured.
  if TPU:
    tf.tpu.experimental.initialize_tpu_system(TPU)
  callbacks = create_callbacks()
  with strategy.scope():
    model = create_model(offset_before=offset_before,
                         offset_after=offset_after)
    history = model.fit(train_dataset, steps_per_epoch=steps_per_epoch, epochs=NUM_EPOCHS,
                        validation_data=valid_dataset,
                        validation_steps=validation_steps,
                        callbacks=callbacks, verbose=1)

  # os.path.join places the file inside save_dir whether or not save_dir ends
  # with a path separator (plain string concatenation would write next to it).
  save_name = os.path.join(save_dir, f"beadchip_hmr_probs_window_{w+1}")
  test_dataset_np = dl.get_target_set(final_start_pos, final_end_pos).astype(np.int32)
  test_dataset = get_test_dataset(test_dataset_np, BATCH_SIZE)
  predict_onehot = model.predict(test_dataset, verbose=1)
  predict_onehot = predict_onehot.astype(np.float32)

  # np.save appends the ".npy" extension to save_name.
  np.save(save_name, predict_onehot)
  print("=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-")

Doing part 0 out of 0
Data shape: (4808, 31143)




Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

In [None]:
# Display the loader's reverse mapping from integer codes to strings such as
# '0|0' (see the output below).
dl.reverse_replacement_dict

{0: '0|0', 1: '0|1', 2: '1|0', 3: '1|1', 4: '.|.'}

In [28]:
# Release the Colab runtime once the cells above have finished, so the
# accelerator is not held after the run.
# NOTE(review): this presumably ends the session and discards anything not
# already written to persistent storage — confirm before running.
from google.colab import runtime
runtime.unassign()