In [None]:
import os
import pandas as pd
from tensorflow import keras
from sklearn.model_selection import train_test_split
from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

In [None]:
#import os
#import pandas as pd
#from tensorflow import keras
#from sklearn.model_selection import train_test_split
import numpy as np

import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K

class GlobalAttention(keras.layers.Layer):
    
    '''
    Recevies two inputs:
    1. A global representation (of some fixed dimension)
    2. A sequence (of any length, and some fixed dimension)
    The global representation is used to construct a global query that attends to all the positions in the sequence (independently
    for any of the heads).
    '''
    
    def __init__(self, n_heads, d_key, d_value, **kwargs):
        self.n_heads = n_heads
        self.d_key = d_key
        self.sqrt_d_key = np.sqrt(self.d_key)
        self.d_value = d_value
        self.d_output = n_heads * d_value
        super(GlobalAttention, self).__init__(**kwargs)
        
    def compute_output_shape(self, input_shapes):
        # input_shapes: (batch_size, d_global_input), (batch_size, length, d_seq_input)
        (batch_size, _), _ = input_shapes
        return (batch_size, self.d_output)

    def build(self, input_shapes):
        # input_shapes: (batch_size, d_global_input), (batch_size, length, d_seq_input)
        (_, self.d_global_input), (_, _, self.d_seq_input) = input_shapes
        # Wq: (n_heads, d_global_input, d_key)
        self.Wq = self.add_weight(name = 'Wq', shape = (self.n_heads, self.d_global_input, self.d_key), \
                initializer = 'glorot_uniform', trainable = True)
        # Wk: (n_heads, d_seq_input, d_key)
        self.Wk = self.add_weight(name = 'Wk', shape = (self.n_heads, self.d_seq_input, self.d_key), \
                initializer = 'glorot_uniform', trainable = True)
        # Wv: (n_heads, d_seq_input, d_value)
        self.Wv = self.add_weight(name = 'Wv', shape = (self.n_heads, self.d_seq_input, self.d_value), \
                initializer = 'glorot_uniform', trainable = True)
        super(GlobalAttention, self).build(input_shapes)

    def call(self, inputs):
    
        # X: (batch_size, d_global_input)
        # S: (batch_size, length, d_seq_input)
        X, S = inputs
        _, length, _ = K.int_shape(S)
    
        # (batch_size, n_heads, length, d_value)
        VS = K.permute_dimensions(keras.activations.gelu(K.dot(S, self.Wv)), (0, 2, 1, 3))
        # (batch_size * n_heads, length, d_value)
        VS_batched_heads = K.reshape(VS, (-1, length, self.d_value))
        
        Z_batched_heads = self.calculate_attention(inputs)
        # (batch_size * n_heads, d_value)
        Y_batched_heads = K.batch_dot(Z_batched_heads, VS_batched_heads)
        # (batch_size, n_heads * d_value)
        Y = K.reshape(Y_batched_heads, (-1, self.d_output))
        
        return Y
        
    def calculate_attention(self, inputs):
    
        # X: (batch_size, d_global_input)
        # S: (batch_size, length, d_seq_input)
        X, S = inputs
        _, length, _ = K.int_shape(S)
                
        # (batch_size, n_heads, d_key)
        QX = K.tanh(K.dot(X, self.Wq))
        # (batch_size * n_heads, d_key)
        QX_batched_heads = K.reshape(QX, (-1, self.d_key))
        
        # (batch_size, n_heads, d_key, length)
        KS = K.permute_dimensions(K.tanh(K.dot(S, self.Wk)), (0, 2, 3, 1))
        # (batch_size * n_heads, d_key, length)
        KS_batched_heads = K.reshape(KS, (-1, self.d_key, length))
                
        # (batch_size * n_heads, length)
        Z_batched_heads = K.softmax(K.batch_dot(QX_batched_heads, KS_batched_heads) / self.sqrt_d_key)
        return Z_batched_heads
    
def create_model(seq_len, vocab_size, n_annotations, d_hidden_seq = 128, d_hidden_global = 512, n_blocks = 6, n_heads = 4, \
         d_key = 64, conv_kernel_size = 9, wide_conv_dilation_rate = 5, activation = 'gelu'):
    
    '''
    seq_len is required to create the model, but all the weights are independent of the length and can be re-used with
    different lengths.
    '''
    
    assert d_hidden_global % n_heads == 0
    d_value = d_hidden_global // n_heads
    
    input_seq = keras.layers.Input(shape = (seq_len,), dtype = np.int32, name = 'input-seq')
    input_annotations = keras.layers.Input(shape = (n_annotations,), dtype = np.float32, name = 'input-annotations')
    
    hidden_seq = keras.layers.Embedding(vocab_size, d_hidden_seq, name = 'embedding-seq-input')(input_seq)
    hidden_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'dense-global-input')(input_annotations)
    
    for block_index in range(1, n_blocks + 1):
        
        seqed_global = keras.layers.Dense(d_hidden_seq, activation = activation, name = 'global-to-seq-dense-block%d' % block_index)(hidden_global)
        seqed_global = keras.layers.Reshape((1, d_hidden_seq), name = 'global-to-seq-reshape-block%d' % block_index)(seqed_global)
        
        narrow_conv_seq = keras.layers.Conv1D(filters = d_hidden_seq, kernel_size = conv_kernel_size, strides = 1, \
                padding = 'same', dilation_rate = 1, activation = activation, name = 'narrow-conv-block%d' % block_index)(hidden_seq)
        wide_conv_seq = keras.layers.Conv1D(filters = d_hidden_seq, kernel_size = conv_kernel_size, strides = 1, \
                padding = 'same', dilation_rate = wide_conv_dilation_rate, activation = activation, name = 'wide-conv-block%d' % \
                block_index)(hidden_seq)
        
        hidden_seq = keras.layers.Add(name = 'seq-merge1-block%d' % block_index)([hidden_seq, seqed_global, narrow_conv_seq, wide_conv_seq])
        hidden_seq = keras.layers.LayerNormalization(name = 'seq-merge1-norm-block%d' % block_index)(hidden_seq)
        
        dense_seq = keras.layers.Dense(d_hidden_seq, activation = activation, name = 'seq-dense-block%d' % block_index)(hidden_seq)
        hidden_seq = keras.layers.Add(name = 'seq-merge2-block%d' % block_index)([hidden_seq, dense_seq])
        hidden_seq = keras.layers.LayerNormalization(name = 'seq-merge2-norm-block%d' % block_index)(hidden_seq)
        
        dense_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'global-dense1-block%d' % block_index)(hidden_global)
        attention = GlobalAttention(n_heads, d_key, d_value, name = 'global-attention-block%d' % block_index)([hidden_global, hidden_seq])
        hidden_global = keras.layers.Add(name = 'global-merge1-block%d' % block_index)([hidden_global, dense_global, attention])
        hidden_global = keras.layers.LayerNormalization(name = 'global-merge1-norm-block%d' % block_index)(hidden_global)
        
        dense_global = keras.layers.Dense(d_hidden_global, activation = activation, name = 'global-dense2-block%d' % block_index)(hidden_global)
        hidden_global = keras.layers.Add(name = 'global-merge2-block%d' % block_index)([hidden_global, dense_global])
        hidden_global = keras.layers.LayerNormalization(name = 'global-merge2-norm-block%d' % block_index)(hidden_global)
        
    output_seq = keras.layers.Dense(vocab_size, activation = 'softmax', name = 'output-seq')(hidden_seq)
    output_annotations = keras.layers.Dense(n_annotations, activation = 'sigmoid', name = 'output-annotations')(hidden_global)

    return keras.models.Model(inputs = [input_seq, input_annotations], outputs = [output_seq, output_annotations])
    
def get_model_with_hidden_layers_as_outputs(model):
    
    _, seq_len, _ = model.outputs[0].shape
    
    seq_layers = [layer.output for layer in model.layers if len(layer.output.shape) == 3 and \
            tuple(layer.output.shape)[:2] == (None, seq_len) and (layer.name in ['input-seq-encoding', 'dense-seq-input', 'output-seq'] or \
            isinstance(layer, keras.layers.LayerNormalization))]
    global_layers = [layer.output for layer in model.layers if len(layer.output.shape) == 2 and (layer.name in ['input_annotations', \
            'dense-global-input', 'output-annotations'] or isinstance(layer, keras.layers.LayerNormalization))]
    
    concatenated_seq_output = keras.layers.Concatenate(name = 'all-seq-layers')(seq_layers)
    concatenated_global_output = keras.layers.Concatenate(name = 'all-global-layers')(global_layers)
    
    return keras.models.Model(inputs = model.inputs, outputs = [concatenated_seq_output, concatenated_global_output])
    


In [2]:
import numpy as np
import pandas as pd

import os
import itertools
from datetime import datetime, timedelta
import pickle

import numpy as np
import pandas as pd
import h5py

from tensorflow import keras


DEFAULT_EPISODE_SETTINGS = [
    # seq_len, batch_size
    (128, 128),
    (512, 64),
    (1024, 32),
]

def run_pretraining(create_model_function, epoch_generator, h5_dataset_file_path, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, \
        other_optimizer_kwargs = {}, annots_loss_weight = 1, autosave_manager = None, weights_dir = None, resume_from = None, n_epochs = None, fit_callbacks = []):

    np.random.seed(0)
    
    with h5py.File(h5_dataset_file_path, 'r') as h5f:
        n_annotations = len(h5f['included_annotations'])
    
    model_generator = PretrainingModelGenerator(create_model_function, n_annotations, create_model_kwargs = create_model_kwargs, optimizer_class = optimizer_class, lr = lr, \
            other_optimizer_kwargs = other_optimizer_kwargs, annots_loss_weight = annots_loss_weight)
    model_trainer = ModelTrainer(model_generator, epoch_generator, autosave_manager = autosave_manager, weights_dir = weights_dir, fit_callbacks = fit_callbacks)

    with h5py.File(h5_dataset_file_path, 'r') as h5f:
        model_trainer.setup(DatasetHandler(h5f), resume_from = resume_from)
        model_trainer.train(n_epochs = n_epochs)
        
    return model_trainer
    
class ModelTrainer:
    
    def __init__(self, model_generator, epoch_generator, autosave_manager = None, weights_dir = None, fit_callbacks = []):
        
        self.model_generator = model_generator
        self.epoch_generator = epoch_generator
        self.autosave_manager = autosave_manager
        self.weights_dir = weights_dir
        self.fit_callbacks = fit_callbacks
        
        if self.autosave_manager is not None:
            self.autosave_manager.n_annotations = self.model_generator.n_annotations
        
    def setup(self, dataset_handler, resume_from = None):
        
        if resume_from is None:
            self.current_epoch_index = 0
            start_sample_index = 0
            resumed_weights_file_path = None
        else:
            self.current_epoch_index, start_sample_index = resume_from
            self.current_epoch_index += 1
            resumed_weights_file_path = os.path.join(self.weights_dir, 'epoch_%d_sample_%d.pkl' % resume_from)
        
        starting_episode = self.epoch_generator.setup(dataset_handler, start_sample_index)
        self.model_generator.dummy_epoch = self.epoch_generator.create_dummpy_epoch()[:2]
        log('Starting with episode with seq_len = %d.' % starting_episode.seq_len)
        
        if resumed_weights_file_path is not None:
            with open(resumed_weights_file_path, 'rb') as f:
                n_annotations, self.model_generator.model_weights, self.model_generator.optimizer_weights = pickle.load(f)
                assert n_annotations == self.model_generator.n_annotations
                log('Loaded weights from %s.' % resumed_weights_file_path)
        
        self.model = self.model_generator.create_model(starting_episode.seq_len)
        self.model.summary()
                    
    def train(self, n_epochs = None, autosave = True):
        for _ in (itertools.count() if n_epochs is None else range(n_epochs)):
            self.train_next_epoch(autosave = autosave)
        
    def train_next_epoch(self, autosave = True):
    
        changed_episode, episode = self.epoch_generator.determine_episode_and_ready_next_epoch()
        
        if changed_episode:
            log('Starting a new episode with seq_len = %d.' % episode.seq_len)
            self.model_generator.dummy_epoch = self.epoch_generator.create_dummpy_epoch()[:2]
            self.model_generator.update_state(self.model)
            self.model = self.model_generator.create_model(episode.seq_len)
        
        X, Y, sample_weights = self.epoch_generator.create_next_epoch()
        log('Epoch %d (current sample %d):' % (self.current_epoch_index, self.epoch_generator.current_sample_index))
        self.model.fit(X, Y, sample_weight = sample_weights, batch_size = episode.batch_size, callbacks = self.fit_callbacks)
        
        if autosave and self.autosave_manager is not None:
            self.autosave_manager.on_epoch_end(self.model, self.current_epoch_index, self.epoch_generator.current_sample_index)
            
        self.current_epoch_index += 1

class EpochGenerator:
    
    def __init__(self, n_batches_per_epoch = 100, p_seq_noise = 0.05, p_no_input_annot = 0.5, p_annot_noise_positive = 0.25, \
            p_annot_noise_negative = 1e-04, load_chunk_size = 100000, min_time_per_episode = timedelta(minutes = 15), \
            episode_settings = DEFAULT_EPISODE_SETTINGS):
        
        self.n_batches_per_epoch = n_batches_per_epoch
        self.p_seq_noise = p_seq_noise
        self.p_no_input_annot = p_no_input_annot
        self.p_annot_noise_positive = p_annot_noise_positive
        self.p_annot_noise_negative = p_annot_noise_negative
        self.load_chunk_size = load_chunk_size
        self.min_time_per_episode = min_time_per_episode
        
        self.episode_managers = [EpisodeDataManager(seq_len, batch_size, self.n_batches_per_epoch) for seq_len, batch_size in \
                episode_settings]
        self.episode_seq_lens = np.array([episode_manager.seq_len for episode_manager in self.episode_managers])
        
    def setup(self, dataset_handler, start_sample_index = 0):
        self.dataset_handler = dataset_handler
        self.current_sample_index = start_sample_index % self.dataset_handler.total_size
        self._load_chunk()
        self._select_new_episode()
        return self._current_episode
    
    def determine_episode_and_ready_next_epoch(self):
        
        if self._episode_selection_time + self.min_time_per_episode <= datetime.now():
            old_episode = self._current_episode
            self._select_new_episode()
            changed_episode = (self._current_episode is not old_episode)
        else:
            changed_episode = False
            
        while not self._current_episode.is_epoch_ready():
            self._load_chunk()

        return changed_episode, self._current_episode
        
    def create_next_epoch(self):
        return self._encode_epoch(*self.create_next_epoch_Y())
        
    def create_dummpy_epoch(self, size = 1):
        return self._encode_epoch(*self.create_next_dummy_epoch_Y(size))
        
    def create_next_epoch_Y(self):
        assert self._current_episode.is_epoch_ready()
        return self._current_episode.encode_next_epoch()
    
    def create_next_dummy_epoch_Y(self, size = 1):
        
        while not self._current_episode.is_epoch_ready(size):
            self._load_chunk()
            
        return self._current_episode.encode_dummy_epoch(size)
    
    def _select_new_episode(self):
        self._current_episode = max(self.episode_managers, key = lambda episode_manager: len(episode_manager.sample_cache))
        self._episode_selection_time = datetime.now()
            
    def _load_chunk(self):
        
        chunk_sample_cache = self.dataset_handler[self.current_sample_index:(self.current_sample_index + self.load_chunk_size)]
        self.current_sample_index += self.load_chunk_size
        
        if self.current_sample_index >= self.dataset_handler.total_size:
            self.current_sample_index = 0
            
        self._assign_samples(chunk_sample_cache)
        
    def _assign_samples(self, sample_cache):
        
        seq_lens = np.array(list(map(len, sample_cache.seqs))) + ADDED_TOKENS_PER_SEQ
        assigned_episode_indices = self._select_episodes_to_assign(seq_lens)
        
        for episode_manager_index, episode_manager in enumerate(self.episode_managers):
            sample_indices_for_episode, = np.where(assigned_episode_indices == episode_manager_index)
            episode_manager.sample_cache.extend(sample_cache.slice_indices(sample_indices_for_episode))
        
    def _select_episodes_to_assign(self, seq_lens, gamma = 1):
        # The smaller the distance between a sample's sequence length to an episode's maximum sequence length, the higher the chance
        # that it will be assigned to that episode.
        samples_by_episodes_seq_len_ratio = seq_lens.reshape(-1, 1) / self.episode_seq_lens.reshape(1, -1)
        samples_by_episodes_seq_len_symmetric_ratio = np.maximum(samples_by_episodes_seq_len_ratio, 1 / samples_by_episodes_seq_len_ratio)
        raw_samples_by_episodes_probs = np.exp(-gamma * samples_by_episodes_seq_len_symmetric_ratio)
        samples_by_episodes_probs = raw_samples_by_episodes_probs / raw_samples_by_episodes_probs.sum(axis = -1).reshape(-1, 1)
        samples_by_episodes_cum_probs = samples_by_episodes_probs.cumsum(axis = -1)
        assigned_episode_indices = (np.random.rand(len(seq_lens), 1) <= samples_by_episodes_cum_probs).argmax(axis = 1)
        return assigned_episode_indices
    
    def _encode_epoch(self, encoded_seqs, encoded_annotation_masks):
        
        seqs_noise_mask = np.random.choice([True, False], encoded_seqs.shape, p = [1 - self.p_seq_noise, self.p_seq_noise])
        random_seq_tokens = np.random.randint(0, n_tokens, encoded_seqs.shape)
        noisy_encoded_seqs = np.where(seqs_noise_mask, encoded_seqs, random_seq_tokens)

        noisy_annotations_when_positive = np.random.choice([True, False], encoded_annotation_masks.shape, \
                p = [1 - self.p_annot_noise_positive, self.p_annot_noise_positive])
        noisy_annotations_when_negative = np.random.choice([True, False], encoded_annotation_masks.shape, \
                p = [self.p_annot_noise_negative, 1 - self.p_annot_noise_negative])
        noisy_annotation_masks = np.where(encoded_annotation_masks, noisy_annotations_when_positive, \
                noisy_annotations_when_negative)
        noisy_annotation_masks[np.random.choice([True, False], len(noisy_annotation_masks), p = [self.p_no_input_annot, \
                1 - self.p_no_input_annot]), :] = False

        seq_weights = (encoded_seqs != additional_token_to_index['<PAD>']).astype(float)
        # When a protein has no annotations at all, we don't know whether it's because such annotations don't exist or just not found,
        # so it's safer to set the loss weight of those annotations to zero.
        annotation_weights = encoded_annotation_masks.any(axis = -1).astype(float)
        
        X = [noisy_encoded_seqs, noisy_annotation_masks.astype(np.int8)]
        Y = [np.expand_dims(encoded_seqs, axis = -1), encoded_annotation_masks.astype(np.int8)]
        sample_weigths = [seq_weights, annotation_weights]
        
        return X, Y, sample_weigths

class EpisodeDataManager:
    
    def __init__(self, seq_len, batch_size, n_batches_per_epoch):
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.n_batches_per_epoch = n_batches_per_epoch
        self.epoch_size = self.n_batches_per_epoch * self.batch_size
        self.sample_cache = SampleCache()
        
    def is_epoch_ready(self, n_required_samples = None):
        return len(self.sample_cache) >= self._resolve_epoch_size(n_required_samples)
    
    def get_next_raw_epoch(self, size = None):
        return self.sample_cache.pop(self._resolve_epoch_size(size))
    
    def peek_raw_epoch(self, size = None):
        return self.sample_cache.slice_first(self._resolve_epoch_size(size))
    
    def encode_next_epoch(self, log_length_dist = True):
        
        seq_lengths, encoded_seqs, encoded_annotation_masks = self._encode_epoch(self.get_next_raw_epoch())
        
        if log_length_dist:
            log('Epoch sequence length distribution (for seq_len = %d): %s' % (self.seq_len, \
                    ', '.join('%s: %s' % item for item in pd.Series(seq_lengths).describe().iteritems())))
        
        return encoded_seqs, encoded_annotation_masks
    
    def encode_dummy_epoch(self, size = 1):
        seq_lengths, encoded_seqs, encoded_annotation_masks = self._encode_epoch(self.peek_raw_epoch(size))
        return encoded_seqs, encoded_annotation_masks
    
    def _encode_epoch(self, epoch_sample_cache):
        
        pad_token_index = additional_token_to_index['<PAD>']
        tokenized_seqs = list(map(tokenize_seq, epoch_sample_cache.seqs))
        seq_lengths = np.array(list(map(len, tokenized_seqs)))
        max_offsets = np.maximum(seq_lengths - self.seq_len, 0)
        chosen_offsets = (np.random.rand(self.epoch_size) * (max_offsets + 1)).astype(int)
        trimmed_tokenized_seqs = [seq_tokens[chosen_offset:(chosen_offset + self.seq_len)] for seq_tokens, chosen_offset in \
                zip(tokenized_seqs, chosen_offsets)]
        encoded_seqs = np.array([seq_tokens + max(self.seq_len - len(seq_tokens), 0) * [pad_token_index] for seq_tokens in \
                trimmed_tokenized_seqs]).astype(np.int8)
        
        encoded_annotation_masks = np.concatenate([annotation_mask.reshape(1, -1) for annotation_mask in \
                epoch_sample_cache.annotation_masks], axis = 0).astype(bool)
        
        # We hide the annotations of test-set samples to avoid "cheating" on downstream fine-tuning tests. Note that by removing all of the annotations,
        # EpochGenerator._encode_epoch will then set the annotation_weights for these records to 0, meaning they will not be part of the loss function.
        encoded_annotation_masks[epoch_sample_cache.test_set_mask, :] = False
        
        return seq_lengths, encoded_seqs, encoded_annotation_masks
    
    def _resolve_epoch_size(self, size):
        if size is None:
            return self.epoch_size
        else:
            return size

class DatasetHandler:
    
    def __init__(self, dataset_h5f):
        self.dataset_h5f = dataset_h5f
        self.total_size = len(dataset_h5f['seq_lengths'])
        
    def __getitem__(self, slicing):
        return SampleCache(list(map(parse_seq, self.dataset_h5f['seqs'][slicing])), self.dataset_h5f['annotation_masks'][slicing], \
                self.dataset_h5f['test_set_mask'][slicing])

class SampleCache:
    
    def __init__(self, seqs = [], annotation_masks = [], test_set_mask = []):
        self.seqs = list(seqs)
        self.annotation_masks = list(annotation_masks)
        self.test_set_mask = list(test_set_mask)
        
    def extend(self, other_cache):
        self.seqs.extend(other_cache.seqs)
        self.annotation_masks.extend(other_cache.annotation_masks)
        self.test_set_mask.extend(other_cache.test_set_mask)
        
    def pop(self, n):
        popped_sample_cache = self.slice_first(n)
        self.seqs = self.seqs[n:]
        self.annotation_masks = self.annotation_masks[n:]
        self.test_set_mask = self.test_set_mask[n:]
        return popped_sample_cache
    
    def slice_first(self, n):
        return SampleCache(self.seqs[:n], self.annotation_masks[:n], self.test_set_mask[:n])
        
    def slice_indices(self, indices):
        return SampleCache([self.seqs[i] for i in indices], [self.annotation_masks[i] for i in indices], \
                [self.test_set_mask[i] for i in indices])
    
    def __len__(self):
        assert len(self.seqs) == len(self.annotation_masks) == len(self.test_set_mask)
        return len(self.seqs)
        
class AutoSaveManager:
    
    def __init__(self, directory, every_epochs_to_save = 10, every_saves_to_keep = 25):
        self.directory = directory
        self.every_epochs_to_save = every_epochs_to_save
        self.every_saves_to_keep = every_saves_to_keep
        self.last_saved_path_to_delete = None
        self.n_saves = 0
    
    def on_epoch_end(self, model, epoch_index, sample_index):
        
        if epoch_index % self.every_epochs_to_save != 0:
            return
        
        save_path = os.path.join(self.directory, 'epoch_%d_sample_%d.pkl' % (epoch_index, sample_index))
        _save_model_state(model, self.n_annotations, save_path)
        self.n_saves += 1
        
        if self.last_saved_path_to_delete is not None:
            os.remove(self.last_saved_path_to_delete)
            
        if self.n_saves % self.every_saves_to_keep == 0:
            self.last_saved_path_to_delete = None
        else:
            self.last_saved_path_to_delete = save_path

def _save_model_state(model, n_annotations, path):
    with open(path, 'wb') as f:
        pickle.dump((n_annotations, model.get_weights(), model.optimizer.get_weights()), f)
        

class ModelGenerator:

    def __init__(self, optimizer_class = keras.optimizers.Adam, lr = 2e-04, other_optimizer_kwargs = {}, model_weights = None, optimizer_weights = None):
        self.optimizer_class = optimizer_class
        self.lr = lr
        self.other_optimizer_kwargs = other_optimizer_kwargs
        self.model_weights = model_weights
        self.optimizer_weights = optimizer_weights
        
    def train(self, encoded_train_set, encoded_valid_set, seq_len, batch_size, n_epochs, lr = None, callbacks = [], **create_model_kwargs):
    
        train_X, train_Y, train_sample_weigths = encoded_train_set
        self.dummy_epoch = (_slice_arrays(train_X, slice(0, 1)), _slice_arrays(train_Y, slice(0, 1)))
        model = self.create_model(seq_len, **create_model_kwargs)
        
        if lr is not None:
            model.optimizer.lr = lr
        
        model.fit(train_X, train_Y, sample_weight = train_sample_weigths, batch_size = batch_size, epochs = n_epochs, validation_data = encoded_valid_set, \
                callbacks = callbacks)
        self.update_state(model)
        
    # def update_state(self, model):
    #     self.model_weights = copy_weights([w.numpy() for w in model.variables])
    #     self.optimizer_weights = copy_weights([w.numpy() for w in model.optimizer.variables()])
        
    def update_state(self, model):
        self.model_weights = copy_weights([w.numpy() for w in model.variables])
        self.optimizer_weights = copy_weights([w.numpy() for w in model.optimizer.variables])
        
    def _init_weights(self, model):
    
        if self.optimizer_weights is not None:
            # For some reason keras requires this strange little hack in order to properly initialize a new model's optimizer, so that
            # the optimizer's weights can be reloaded from an existing state.
            self._train_for_a_dummy_epoch(model)
            
        if self.model_weights is not None:
            print(len(copy_weights(self.model_weights)))
            model.set_weights(copy_weights(self.model_weights))
        
        if self.optimizer_weights is not None:
            if len(self.optimizer_weights) == len(model.optimizer.variables()):
                model.optimizer.set_weights(copy_weights(self.optimizer_weights))
            else:
                log('Incompatible number of optimizer weights - will not initialize them.')
            
    def _train_for_a_dummy_epoch(self, model):
        X, Y = self.dummy_epoch
        model.fit(X, Y, batch_size = 1, verbose = 0)
        
class PretrainingModelGenerator(ModelGenerator):

    def __init__(self, create_model_function, n_annotations, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, other_optimizer_kwargs = {}, \
            annots_loss_weight = 1, model_weights = None, optimizer_weights = None):
        
        ModelGenerator.__init__(self, optimizer_class = optimizer_class, lr = lr, other_optimizer_kwargs = other_optimizer_kwargs, model_weights = model_weights, \
                optimizer_weights = optimizer_weights)
        
        self.create_model_function = create_model_function
        self.n_annotations = n_annotations
        self.create_model_kwargs = create_model_kwargs
        self.annots_loss_weight = annots_loss_weight
        
    def create_model(self, seq_len, compile = True, init_weights = True):
        
        clear_session()
        model = self.create_model_function(seq_len, n_tokens, self.n_annotations, **self.create_model_kwargs)
        
        if compile:
            model.compile(optimizer =self.optimizer_class(learning_rate = self.lr, **self.other_optimizer_kwargs), loss = ['sparse_categorical_crossentropy', 'binary_crossentropy'], \
                    loss_weights = [1, self.annots_loss_weight])
        
        if init_weights:
            self._init_weights(model)
        
        return model
        
class FinetuningModelGenerator(ModelGenerator):

    def __init__(self, pretraining_model_generator, output_spec, pretraining_model_manipulation_function = None, dropout_rate = 0.5, optimizer_class = None, \
            lr = None, other_optimizer_kwargs = None, model_weights = None, optimizer_weights = None):
        
        if other_optimizer_kwargs is None:
            if optimizer_class is None:
                other_optimizer_kwargs = pretraining_model_generator.other_optimizer_kwargs
            else:
                other_optimizer_kwargs = {}
        
        if optimizer_class is None:
            optimizer_class = pretraining_model_generator.optimizer_class
            
        if lr is None:
            lr = pretraining_model_generator.lr
            
        ModelGenerator.__init__(self, optimizer_class = optimizer_class, lr = lr, other_optimizer_kwargs = other_optimizer_kwargs, model_weights = model_weights, \
                optimizer_weights = optimizer_weights)
        
        self.pretraining_model_generator = pretraining_model_generator
        self.output_spec = output_spec
        self.pretraining_model_manipulation_function = pretraining_model_manipulation_function
        self.dropout_rate = dropout_rate
                    
    def create_model(self, seq_len, freeze_pretrained_layers = False):
        
        model = self.pretraining_model_generator.create_model(seq_len, compile = False, init_weights = (self.model_weights is None))
            
        if self.pretraining_model_manipulation_function is not None:
            model = self.pretraining_model_manipulation_function(model)
            
        if freeze_pretrained_layers:
            for layer in model.layers:
                layer.trainable = False
        
        model_inputs = model.input
        pretraining_output_seq_layer, pretraining_output_annoatations_layer = model.output
        last_hidden_layer = pretraining_output_seq_layer if self.output_spec.output_type.is_seq else pretraining_output_annoatations_layer
        last_hidden_layer = keras.layers.Dropout(self.dropout_rate)(last_hidden_layer)
        
        if self.output_spec.output_type.is_categorical:
            output_layer = keras.layers.Dense(len(self.output_spec.unique_labels), activation = 'softmax')(last_hidden_layer)
            loss = 'sparse_categorical_crossentropy'
        elif self.output_spec.output_type.is_binary:
            output_layer = keras.layers.Dense(1, activation = 'sigmoid')(last_hidden_layer)
            loss = 'binary_crossentropy'
        elif self.output_spec.output_type.is_numeric:
            output_layer = keras.layers.Dense(1, activation = None)(last_hidden_layer)
            loss = 'mse'
        else:
            raise ValueError('Unexpected global output type: %s' % self.output_spec.output_type)
                
        model = keras.models.Model(inputs = model_inputs, outputs = output_layer)
        model.compile(loss = loss, optimizer =self.optimizer_class(learning_rate = self.lr, **self.other_optimizer_kwargs))
        
        self._init_weights(model)
                
        return model
                        
class InputEncoder:

    def __init__(self, n_annotations):
        self.n_annotations = n_annotations

    def encode_X(self, seqs, seq_len):
        return [
            tokenize_seqs(seqs, seq_len),
            np.zeros((len(seqs), self.n_annotations), dtype = np.int8)
        ]
        
def load_pretrained_model_from_dump(dump_file_path, create_model_function, create_model_kwargs = {}, optimizer_class = keras.optimizers.Adam, lr = 2e-04, \
        other_optimizer_kwargs = {}, annots_loss_weight = 1, load_optimizer_weights = False):
    
    with open(dump_file_path, 'rb') as f:
        n_annotations, model_weights, optimizer_weights = pickle.load(f)
        
    if not load_optimizer_weights:
        optimizer_weights = None
    
    model_generator = PretrainingModelGenerator(create_model_function, n_annotations, create_model_kwargs = create_model_kwargs, optimizer_class = optimizer_class, lr = lr, \
            other_optimizer_kwargs = other_optimizer_kwargs, annots_loss_weight = annots_loss_weight, model_weights = model_weights, optimizer_weights = optimizer_weights)
    input_encoder = InputEncoder(n_annotations)
    
    return model_generator, input_encoder

def tokenize_seqs(seqs, seq_len):
    # Note that tokenize_seq already adds <START> and <END> tokens.
    return np.array([seq_tokens + (seq_len - len(seq_tokens)) * [additional_token_to_index['<PAD>']] for seq_tokens in map(tokenize_seq, seqs)], dtype = np.int32)
    
def clear_session():
    import tensorflow.keras.backend as K
    K.clear_session()
    
def copy_weights(weights):
    return [_copy_number_or_array(w) for w in weights]
    
def _copy_number_or_array(variable):
    if isinstance(variable, np.ndarray):
        return variable.copy()
    elif isinstance(variable, Number):
        return variable
    else:
        raise TypeError('Unexpected type %s' % type(variable))
    
def _slice_arrays(arrays, slicing):
    if isinstance(arrays, list) or isinstance(arrays, tuple):
        return [array[slicing] for array in arrays]
    else:
        return arrays[slicing]

ALL_AAS = 'ACDEFGHIKLMNPQRSTUVWXY'
ADDITIONAL_TOKENS = ['<OTHER>', '<START>', '<END>', '<PAD>']

# Each sequence is added <START> and <END> tokens
ADDED_TOKENS_PER_SEQ = 2

n_aas = len(ALL_AAS)
aa_to_token_index = {aa: i for i, aa in enumerate(ALL_AAS)}
additional_token_to_index = {token: i + n_aas for i, token in enumerate(ADDITIONAL_TOKENS)}
token_to_index = {**aa_to_token_index, **additional_token_to_index}
index_to_token = {index: token for token, index in token_to_index.items()}
n_tokens = len(token_to_index)

def tokenize_seq(seq):
    other_token_index = additional_token_to_index['<OTHER>']
    return [additional_token_to_index['<START>']] + [aa_to_token_index.get(aa, other_token_index) for aa in parse_seq(seq)] + \
            [additional_token_to_index['<END>']]
            
def parse_seq(seq):
    if isinstance(seq, str):
        return seq
    elif isinstance(seq, bytes):
        return seq.decode('utf8')
    else:
        raise TypeError('Unexpected sequence type: %s' % type(seq))


import sys
import os
import re
import gc
import importlib
from collections import defaultdict
from functools import reduce
from datetime import datetime, timedelta
import json

import numpy as np
import pandas as pd


### Logging ###

def log(*message, **kwargs):
    
    global _log_file
    
    end = kwargs.get('end', '\n')
    
    if len(message) == 1:
        message, = message
    
    full_message = '[%s] %s' % (format_now(), message)
    
    print(full_message, end = end)
    sys.stdout.flush()
    
    if log_file_open():
        _log_file.write(full_message + end)
        _log_file.flush()

def start_log(log_dir, log_file_base_name):
    
    global _log_file
    
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        
    log_file_name = '%s__%d__%s.txt' % (log_file_base_name, os.getpid(), format_now())
    
    if not log_file_open():
        print('Creating log file: %s' % log_file_name)
        _log_file = open(os.path.join(log_dir, log_file_name), 'w')
        
def close_log():
    
    global _log_file
    
    if log_file_open():
        _log_file.close()
        del _log_file
    
def restart_log():
    close_log()
    start_log()
    
def log_file_open():
    global _log_file
    return '_log_file' in globals()
    
def create_time_measure_if_verbose(opening_statement, verbose):
    if verbose:
        return TimeMeasure(opening_statement)
    else:
        return DummyContext()


### General ###

def get_nullable(value, default_value):
    if pd.isnull(value):
        return default_value
    else:
        return value
        
        
### Reflection ###

def load_object(full_object_name):
    name_parts = full_object_name.split('.')
    object_name = name_parts[-1]
    module_name = '.'.join(name_parts[:-1])
    module = importlib.import_module(module_name)
    return getattr(module, object_name)
    
    
### Strings ###

def trim(string, max_length, trim_suffix = '...'):
    if len(string) <= max_length:
        return string
    else:
        return string[:(max_length - len(trim_suffix))] + trim_suffix
        
def break_to_lines(text, max_line_len):
    
    lines = ['']
    
    for word in text.split():
        
        if len(lines[-1]) + len(word) > max_line_len:
            lines.append('')
            
        if lines[-1] != '':
            lines[-1] += ' '
            
        lines[-1] += word
        
    return '\n'.join(lines)
    
    
### IO ###

def safe_symlink(src, dst, post_creation_hook = lambda created_symlink: None):
    if os.path.exists(dst):
        log('%s: already exists.' % dst)
    else:
        try:
            os.symlink(src, dst)
            post_creation_hook(dst)
            log('Created link: %s -> %s' % (src, dst))
        except OSError as e:
            if e.errno == 17:
                log('%s: already exists after all.' % dst)
            else:
                raise e
        
def safe_mkdir(path):
    try:
        os.mkdir(path)
    except OSError as e:
        assert 'File exists' in str(e), str(e)
        
def format_size_in_bytes(size):
    
    UNIT_RATIO = 1024
    UNITS = ['B', 'KB', 'MB', 'GB', 'TB']
    
    for unit_index in range(len(UNITS)):
        if size < UNIT_RATIO:
            break
        else:
            size /= UNIT_RATIO
            
    return '%.1f%s' % (size, UNITS[unit_index])
    
def get_recognized_files_in_dir(dir_path, file_parser, log_unrecognized_files = True):
    
    recognized_files = []
    unrecognized_files = []
    
    for file_name in os.listdir(dir_path):
        try:
            recognized_files.append((file_parser(file_name), file_name))
        except:
            if log_unrecognized_files:
                unrecognized_files.append(file_name)
                
    if log_unrecognized_files and len(unrecognized_files) > 0:
        log('%s: %d unrecognized files: %s' % (dir_path, len(unrecognized_files), ', '.join(unrecognized_files)))
        
    return list(sorted(recognized_files))
    
def monitor_memory(min_bytes_to_log = 1e08, max_elements_to_check = 100, collect_gc = True, del_output_variables = True, \
        list_like_types = [list, tuple, np.ndarray, pd.Series], dict_like_types = [dict, defaultdict]):
    
    already_monitored_object_ids = set()
    
    def _is_of_any_type(obj, types):
        
        for t in types:
            if isinstance(obj, t):
                return True
            
        return False
        
    def _check_len_limit(obj):
        try:
            return len(obj) <= max_elements_to_check
        except:
            return False

    def _log_object_if_needed(name, obj):

        size = sys.getsizeof(obj)

        if size >= min_bytes_to_log:
            log('%s: %s' % (name, format_size_in_bytes(size)))
            
    def _monitor_object(name, obj):
        if id(obj) not in already_monitored_object_ids:
            
            already_monitored_object_ids.add(id(obj))
            _log_object_if_needed(name, obj)
                        
            if _is_of_any_type(obj, list_like_types) and _check_len_limit(obj):
                for i, element in enumerate(obj):
                    _monitor_object('%s[%d]' % (name, i), element)

            if _is_of_any_type(obj, dict_like_types) and _check_len_limit(obj):
                for key, value in obj.items():
                    _monitor_object('%s[%s]' % (name, repr(key)), value)
            
            
    for module_name, module in sys.modules.items():
        for variable_name in dir(module):
            
            full_variable_name = variable_name if module_name == '__main__' else '%s.%s' % (module_name, variable_name)
            _monitor_object(full_variable_name, getattr(module, variable_name))

            if del_output_variables and module_name == '__main__' and re.match(r'^_[\d_]+$', variable_name):
                delattr(module, variable_name)
                
    if del_output_variables:
        sys.modules['__main__'].Out = dict()
        sys.modules['__main__']._oh = dict()

    if collect_gc:
        gc.collect()


### Date & time ###

def format_now():
    return datetime.now().strftime('%Y_%m_%d-%H:%M:%S')
    

### Iterators & collections ###

def compare_list_against_collection(input_list, collection):
    collection_set = set(collection)
    return [element for element in input_list if element in collection_set], [element for element in input_list if element not in collection_set]

def get_chunk_slice(size, n_chunks, chunk_index):
    assert size >= n_chunks
    chunk_size = size / n_chunks
    start_index = int(chunk_index * chunk_size)
    end_index = int((chunk_index + 1) * chunk_size)
    return start_index, end_index

def get_chunk_intervals(size, chunk_size):
    for start_index in range(0, size, chunk_size):
        end_index = min(start_index + chunk_size, size)
        yield start_index, end_index
        
def to_chunks(iterable, chunk_size):
    
    chunk = []
    
    for element in iterable:
        
        chunk.append(element)
        
        if len(chunk) >= chunk_size:
            yield chunk
            chunk = []
            
    if len(chunk) > 0:
        yield chunk
        
def get_job_and_subjob_indices(n_jobs, n_tasks, task_index):
    
    '''
    For example, if there are 170 tasks for working on 50 jobs, than each job will be divided to 3-4 tasks.
    Since 170 % 50 = 20, the 20 first jobs will receive 4 tasks and the last 30 jobs will receive only 3 tasks.
    In total, the first 80 tasks will be dedicated to jobs with 4 tasks each, and the 90 last tasks will be
    dedicated to jobs with 3 tasks each. Hence, tasks 0-3 will go to job 0, tasks 4-7 will go to job 1, and so on;
    tasks 80-82 will go to job 21, tasks 83-85 will job to job 22, and so on.  
    '''
    
    assert n_tasks >= n_jobs
    n_tasks_in_unprivileged_jobs = n_tasks // n_jobs
    n_tasks_in_privileged_jobs = n_tasks_in_unprivileged_jobs + 1
    n_privileged_jobs = n_tasks % n_jobs
    n_tasks_of_privileged_jobs = n_tasks_in_privileged_jobs * n_privileged_jobs
    
    if task_index < n_tasks_of_privileged_jobs:
        job_index = task_index // n_tasks_in_privileged_jobs
        index_within_job = task_index % n_tasks_in_privileged_jobs
        n_tasks_in_job = n_tasks_in_privileged_jobs
    else:
        task_index_in_unprivileged_group = task_index - n_tasks_of_privileged_jobs
        job_index = n_privileged_jobs + task_index_in_unprivileged_group // n_tasks_in_unprivileged_jobs
        index_within_job = task_index_in_unprivileged_group % n_tasks_in_unprivileged_jobs
        n_tasks_in_job = n_tasks_in_unprivileged_jobs
        
    return job_index, index_within_job, n_tasks_in_job
    
def choose_from_cartesian_product(list_of_values, i, total = None):
    
    n = int(np.prod(list(map(len, list_of_values))))
    
    if total is not None:
        assert n == total
    
    chosen_elements = []
    
    for values in list_of_values:
        n //= len(values)
        chosen_elements.append(values[i // n])
        i %= n

    return chosen_elements

def calc_overlap_between_segments(ordered_segments1, ordered_segments2):
    
    '''
    Calculates the total overlap size between a pair of ordered and disjoint groups of segments.
    Each group of segment is given by: [(start1, end1), (start2, end2), ...]. 
    '''
    
    from interval_tree import IntervalTree
    
    if len(ordered_segments1) == 0 or len(ordered_segments2) == 0:
        return 0
    
    if len(ordered_segments1) > len(ordered_segments2):
        ordered_segments1, ordered_segments2 = ordered_segments2, ordered_segments1
    
    min_value = min(ordered_segments1[0][0], ordered_segments2[0][0])
    max_value = max(ordered_segments1[-1][1], ordered_segments2[-1][1])
    interval_tree1 = IntervalTree([segment + (segment,) for segment in ordered_segments1], min_value, max_value)
    total_overlap = 0
    
    for segment in ordered_segments2:
        for overlapping_segment in interval_tree1.find_range(segment):
            overlapping_start = max(segment[0], overlapping_segment[0])
            overlapping_end = min(segment[1], overlapping_segment[1])
            assert overlapping_start <= overlapping_end, 'Reported overlap between %d..%d to %d..%d.' % (segment + \
                    overlapping_segment)
            total_overlap += (overlapping_end - overlapping_start + 1)
            
    return total_overlap
    
def merge_lists_with_compatible_relative_order(lists):
    
    '''
    Given a list of lists with compatible relative ordering (i.e. for every two sublists, the subset of elements that exist in the two
    sublists will have the same relative order), returns a merging of these sublists into a single grand list that contains all the
    elements (each element only once), and preserves the same ordering.
    '''
    
    def merge_two_sublists(list1, list2):
        
        value_to_index = {value: float(i) for i, value in enumerate(list1)}
        unique_list2_index = {}
        last_identified_index = len(list1)
        
        for i, value in list(enumerate(list2))[::-1]:
            if value in value_to_index:
                last_identified_index = value_to_index[value]
            else:
                unique_list2_index[value] = last_identified_index - 1 + i / len(list2)
                
        value_to_index.update(unique_list2_index)
        return sorted(value_to_index.keys(), key = value_to_index.get)
    
    return reduce(merge_two_sublists, lists, [])
    
    
### argparse ###

def get_parser_bool_type(parser):

    def _bool_type(value):
        if isinstance(value, bool):
           return value
        if value.lower() in ['yes', 'true', 't', 'y', '1']:
            return True
        elif value.lower() in ['no', 'false', 'f', 'n', '0']:
            return False
        else:
            raise parser.error('"%s": unrecognized boolean value.' % value)
            
    return _bool_type

def get_parser_file_type(parser, must_exist = False):

    def _file_type(path):
    
        path = os.path.expanduser(path)
    
        if must_exist:
            if not os.path.exists(path):
                parser.error('File doesn\'t exist: %s' % path)
            elif not os.path.isfile(path):
                parser.error('Not a file: %s' % path)
            else:
                return path
        else:
        
            dir_path = os.path.dirname(path)
        
            if dir_path and not os.path.exists(dir_path):
                parser.error('Parent directory doesn\'t exist: %s' % dir_path)
            else:
                return path
    
    return _file_type

def get_parser_directory_type(parser, create_if_not_exists = False):
    
    def _directory_type(path):
    
        path = os.path.expanduser(path)
    
        if not os.path.exists(path):
            if create_if_not_exists:
            
                parent_path = os.path.dirname(path)
            
                if parent_path and not os.path.exists(parent_path):
                    parser.error('Cannot create empty directory (parent directory doesn\'t exist): %s' % path)
                else:
                    os.mkdir(path)
                    return path
            else:
                parser.error('Path doesn\'t exist: %s' % path)
        elif not os.path.isdir(path):
            parser.error('Not a directory: %s' % path)
        else:
            return path
        
    return _directory_type
    
def add_parser_task_arguments(parser):
    parser.add_argument('--task-index', dest = 'task_index', metavar = '<0,...,N_TASKS-1>', type = int, default = None, help = 'If you want to ' + \
            ' distribute this process across multiple computation resources (e.g. on a cluster) you can specify the total number of tasks ' + \
            '(--total-tasks) to split it into, and the index of the current task to run (--task-index).')
    parser.add_argument('--total-tasks', dest = 'total_tasks', metavar = '<N_TASKS>', type = int, default = None, help = 'See --task-index.')
    parser.add_argument('--task-index-env-variable', dest = 'task_index_env_variable', metavar = '<e.g. SLURM_ARRAY_TASK_ID>', type = str, default = None, \
            help = 'Instead of specifying a hardcoded --task-index, you can specify an environtment variable to take it from (e.g. SLURM_ARRAY_TASK_ID ' + \
            'if you use SLURM to distribute the jobs).')
    parser.add_argument('--total-tasks-env-variable', dest = 'total_tasks_env_variable', metavar = '<e.g. SLURM_ARRAY_TASK_COUNT>', type = str, \
            default = None, help = 'Instead of specifying a hardcoded --total-tasks, you can specify an environtment variable to take it from (e.g. ' + \
            'SLURM_ARRAY_TASK_COUNT if you use SLURM to distribute the jobs).')
            
def determine_parser_task_details(args):
    
    if args.task_index is not None and args.task_index_env_variable is not None:
        parser.error('You must choose between --task-index and --task-index-env-variable.')
    
    if args.task_index is not None:
        task_index = args.task_index
    elif args.task_index_env_variable is not None:
        task_index = int(os.getenv(args.task_index_env_variable))
    else:
        task_index = None
        
    if args.total_tasks is not None and args.total_tasks_env_variable is not None:
        parser.error('You must choose between --total-tasks and --total-tasks-env-variable.')
        
    if args.total_tasks is not None:
        total_tasks = args.total_tasks
    elif args.total_tasks_env_variable is not None:
        total_tasks = int(os.getenv(args.total_tasks_env_variable))
    else:
        total_tasks = None

    if task_index is None and total_tasks is None:
        task_index = 0
        total_tasks = 1
    elif task_index is None or total_tasks is None:
        parser.error('Task index and total tasks must either be specified or unspecified together.')
    
    if task_index < 0 or task_index >= total_tasks:
        parser.error('Task index must be in the range 0,...,(total tasks)-1.')
    
    return task_index, total_tasks

    
### Numpy ###

def normalize(x):

    if isinstance(x, list):
        x = np.array(x)

    u = np.mean(x)
    sigma = np.std(x)
    
    if sigma == 0:
        return np.ones_like(x)
    else:
        return (x - u) / sigma
    
def random_mask(size, n_trues):
    assert n_trues <= size
    mask = np.full(size, False)
    mask[:n_trues] = True
    np.random.shuffle(mask)
    return mask
    
def indices_to_masks(n, indices):
    positive_mask = np.zeros(n, dtype = bool)
    positive_mask[indices] = True
    negative_mask = np.ones(n, dtype = bool)
    negative_mask[indices] = False
    return positive_mask, negative_mask
    
def as_hot_encoding(values, value_to_index, n_values = None):

    if n_values is None:
        n_values = len(value_to_index)
        
    result = np.zeros(n_values)
    
    try:
        values = iter(values)
    except TypeError:
        values = iter([values])
        
    for value in values:
        result[value_to_index[value]] += 1
        
def is_full_rank(matrix):
    return np.linalg.matrix_rank(matrix) == min(matrix.shape)
    
def find_linearly_independent_columns(matrix):
    
    '''
    The calculation is fasciliated by the Gram Schmidt process, everytime taking the next column and removing its projections
    from all next columns, getting rid of columns which end up zero.
    '''
    
    n_rows, n_cols = matrix.shape
    
    if np.linalg.matrix_rank(matrix) == n_cols:
        return np.arange(n_cols)
    
    orthogonalized_matrix = matrix.copy().astype(float)
    independent_columns = []
    
    for i in range(n_cols):
        if not np.isclose(orthogonalized_matrix[:, i], 0).all():
            
            independent_columns.append(i)
            
            if len(independent_columns) >= n_rows:
                break
            
            orthogonalized_matrix[:, i] = orthogonalized_matrix[:, i] / np.linalg.norm(orthogonalized_matrix[:, i])
            
            if i < n_cols - 1:
                # Remove the projection of the ith column from all next columns
                orthogonalized_matrix[:, (i + 1):] -= np.dot(orthogonalized_matrix[:, i], \
                        orthogonalized_matrix[:, (i + 1):]).reshape(1, -1) * orthogonalized_matrix[:, i].reshape(-1, 1)
            
    return np.array(independent_columns)

def transpose_dataset(src, dst, max_memory_bytes, flush_func = None):
    
    n_rows, n_cols = src.shape[:2]
    entry_nbytes = src[:1, :1].nbytes
    ideal_entries_per_chunk = max_memory_bytes / entry_nbytes
    ideal_chunk_size = np.sqrt(ideal_entries_per_chunk)
    
    if n_rows <= n_cols:
        row_chunk_size = min(int(ideal_chunk_size), n_rows)
        col_chunk_size = min(int(ideal_entries_per_chunk / row_chunk_size), n_cols)
    else:
        col_chunk_size = min(int(ideal_chunk_size), n_cols)
        row_chunk_size = min(int(ideal_entries_per_chunk / col_chunk_size), n_rows)
        
    log('Will use chunks of size %dx%d to transpose a %dx%d matrix.' % (row_chunk_size, col_chunk_size, n_rows, n_cols))
    
    for row_start, row_end in get_chunk_intervals(n_rows, row_chunk_size):
        for col_start, col_end in get_chunk_intervals(n_cols, col_chunk_size):
            
            log('Transposing chunk (%d..%d)x(%d..%d)...' % (row_start, row_end - 1, col_start, col_end - 1))
            dst[col_start:col_end, row_start:row_end] = src[row_start:row_end, col_start:col_end].transpose()
            
            if flush_func is not None:
                flush_func()
                
    log('Finished transposing.')


### Pandas ###

def summarize(df, n = 5, sample = False):
    
    from IPython.display import display
    
    if sample:
        display(df.sample(n))
    else:
        display(df.head(n))
    
    print('%d records' % len(df))
    
def nullable_idxmin(series):
    
    result = series.idxmin()
    
    if pd.isnull(result):
        if len(series) == 0:
            return np.nan
        else:
            return series.index[0]
    else:
        return result
    
def get_first_value(df):
    '''
    Will return a Series with the same index. For each row the value will be that of the first column which is not null.
    '''
    col_idxs = np.argmax(pd.notnull(df).values, axis = 1)
    return pd.Series(df.values[np.arange(len(df)), col_idxs], index = df.index)
    
def slice_not_in_index(df_or_series, index_to_exclude):
    mask = pd.Series(True, index = df_or_series.index)
    mask.loc[index_to_exclude] = False
    return df_or_series.loc[mask]
    
def swap_series_index_and_value(series):
    return pd.Series(series.index, index = series.values)
    
def concat_dfs_with_partial_columns(dfs):
    columns = max([df.columns for df in dfs], key = len)
    assert all([set(df.columns) <= set(columns) for df in dfs])
    return pd.concat(dfs, sort = False)[columns]
    
def concat_dfs_with_compatible_columns(dfs):
    columns = merge_lists_with_compatible_relative_order([df.columns for df in dfs])
    return pd.concat(dfs, sort = False)[columns]

def safe_get_df_group(df_groupby, group_name):
    if group_name in df_groupby.groups:
        return df_groupby.get_group(group_name)
    else:
        _, some_group_df = next(iter(df_groupby))
        return pd.DataFrame(columns = some_group_df.columns)
        
def bin_groupby(df, series_or_col_name, n_bins):
    
    if len(df) == 0:
        return df
    
    if isinstance(series_or_col_name, str):
        series = df[series_or_col_name]
    else:
        series = series_or_col_name
        
    min_value, max_value = series.min(), series.max()
    bin_size = (max_value - min_value) / n_bins
    
    bind_ids = ((series - min_value) / bin_size).astype(int)
    bind_ids[bind_ids >= n_bins] = n_bins - 1
    
    return df.groupby(bind_ids)
    
def value_df_to_hot_encoding_df(value_df, value_headers = {}):
    
    flat_values = value_df.values.flatten()
    all_values = sorted(np.unique(flat_values[pd.notnull(flat_values)]))
    value_to_index = {value: i for i, value in enumerate(all_values)}
    hot_encoding_matrix = np.zeros((len(value_df), len(all_values)))
    
    for _, column_values in value_df.iteritems():
        row_position_to_value_index = column_values.reset_index(drop = True).dropna().map(value_to_index)
        hot_encoding_matrix[row_position_to_value_index.index.values, row_position_to_value_index.values] = 1
    
    headers = [value_headers.get(value, value) for value in all_values]
    return pd.DataFrame(hot_encoding_matrix, index = value_df.index, columns = headers)
    
def set_series_to_hot_encoding_df(set_series, value_headers = {}):

    all_values = sorted(set.union(*set_series))
    value_to_index = {value: i for i, value in enumerate(all_values)}
    hot_encoding_matrix = np.zeros((len(set_series), len(all_values)))
    
    for i, record_values in enumerate(set_series):
        hot_encoding_matrix[i, [value_to_index[value] for value in record_values]] = 1
        
    headers = [value_headers.get(value, value) for value in all_values]
    return pd.DataFrame(hot_encoding_matrix, index = set_series.index, columns = headers)
    
def resolve_dummy_variable_trap(hot_encoding_df, validate_completeness = True, inplace = False, verbose = True):

    '''
    When using one-hot-encoding in regression, there is a problem of encoding all possible variables if also using an intercept/const variable,
    because then the variables end up linearly dependent (a singular matrix is problematic with many implementations of regression). See for
    example: https://www.algosome.com/articles/dummy-variable-trap-regression.html
    To resolve this issue, this function will remove the most frequent column (to minimize the chance of any subset of the rows resulting a
    matrix which is not fully ranked).
    '''
    
    # Validate we are indeed dealing with one-hot-encoding.
    assert set(np.unique(hot_encoding_df.values).astype(float)) <= {0.0, 1.0}
    
    if validate_completeness:
        assert (hot_encoding_df.sum(axis = 1) == 1).all()
    else:
        assert (hot_encoding_df.sum(axis = 1) <= 1).all()
    
    most_frequent_variable = hot_encoding_df.sum().idxmax()
    
    if verbose:
        log('To avoid the "dummy variable trap", removing the %s column (%d matching records).' % (most_frequent_variable, \
                hot_encoding_df[most_frequent_variable].sum()))
    
    if inplace:
        del hot_encoding_df[most_frequent_variable]
    else:
        return hot_encoding_df[[column_name for column_name in hot_encoding_df.columns if column_name != most_frequent_variable]]
    
def set_constant_row(df, row_mask, row_values):
    df[row_mask] = np.tile(row_values, (row_mask.sum(), 1))
    
def construct_df_from_rows(row_repertoire, row_indexer):
    
    result = pd.DataFrame(index = row_indexer.index, columns = row_repertoire.columns)
    
    for row_index, row_values in row_repertoire.iterrows():
        set_constant_row(result, row_indexer == row_index, row_values)
        
    return result
    
def get_row_last_values(df):
    
    result = pd.Series(np.nan, index = df.index)

    for column in df.columns[::-1]:
        result = result.where(pd.notnull(result), df[column])

    return result
    
def are_close_dfs(df1, df2, rtol = 1e-05, atol = 1e-08):
    
    assert (df1.dtypes == df2.dtypes).all()
    
    for column, dtype in df1.dtypes.iteritems():
        
        if np.issubdtype(dtype, np.float):
            cmp_series = np.isclose(df1[column], df2[column], rtol = rtol, atol = atol) | (pd.isnull(df1[column]) & \
                    pd.isnull(df2[column]))
        else:
            cmp_series = (df1[column] == df2[column])
            
        if not cmp_series.all():
            return False
        
    return True
    
def append_df_to_excel(excel_writer, df, sheet_name, index = True):
    
    header_format = excel_writer.book.add_format({'bold': True})
     
    df.to_excel(excel_writer, sheet_name, index = index)
    worksheet = excel_writer.sheets[sheet_name]
    
    for column_index, column_name in enumerate(df.columns):
        worksheet.write(0, column_index + int(index), column_name, header_format)
        
    if index:
        for row_index_number, row_index_value in enumerate(df.index):
            worksheet.write(row_index_number + 1, 0, row_index_value)
        
def is_binary_series(series):

    # First validating that the type of the series is convertable to float.
    try:
        float(series.iloc[0])
    except TypeError:
        return False

    return set(series.unique().astype(float)) <= {0.0, 1.0}
    
def resolve_quasi_complete_separation_by_removing_binary_columns(X, y):
    
    '''
    When performing logistic regression of y against X, the matrix X must be of full rank; otherwise (i.e. if the columns of X are
    linearly dependent), then statsmodel's Logit model gives a singular-matrix error. It also appears that quasi-complete separation
    causes troubles, namely if the columns of X are linearly dependent conditioned on y. In other words, assuming that y is binary,
    we need that X[y, :] would still be of full rank (we assume that the vast majority of records have a negative y value, and only
    a small fraction have a positive value, so given that X is of full rank we need not worry about X[~y, :]). To resolve this problem,
    this function will remove binary columns of X until X[y, :] is of full rank. Whenever a column of X is removed, we also remove the
    corresponding records (rows of X and y) that have those values (so if a removed column represent some covariate, e.g. a certain
    batch, we also remove all the samples from this batch in order for not having any covariates not accounted for).
    @param X (pd.DataFrame): The exogenous variables (rows are records, columns are variables).
    @pram y (pd.Series): The endogenous variable (must have the same index as X).
    '''
    
    row_mask = pd.Series(True, index = X.index)
    
    if not is_binary_series(y):
        return X, y, X.columns, set(), row_mask
        
    boolean_y = y.astype(bool)
    all_kept_binary_columns = np.array([column_name for column_name in X.columns if is_binary_series(X[column_name])])
    # We sort the binary columns by how common they are, so when we start removing them, we will give priority to the more common ones
    # (i.e. remove the least frequent first).
    all_kept_binary_columns = X[all_kept_binary_columns].sum().sort_values(ascending = False).index
    all_removed_binary_columns = set()
    
    while len(all_kept_binary_columns) > 0:
    
        positive_X = X.loc[row_mask & boolean_y, all_kept_binary_columns]
        old_all_kept_binary_columns = all_kept_binary_columns
        all_kept_binary_columns = all_kept_binary_columns[find_linearly_independent_columns(positive_X.values)]
        columns_to_remove = set(old_all_kept_binary_columns) - set(all_kept_binary_columns)
        
        for column_name in columns_to_remove:
            log('Removing the columns %s (%d occurances) to avoid quasi-complete separation.' % (column_name, X[column_name].sum()))
            all_removed_binary_columns.add(column_name)
            row_mask &= (~X[column_name].astype(bool))
            
        if len(columns_to_remove) == 0:
            break

    if not row_mask.all():
        log('Overall removed %d columns occuring in %d records to avoid quasi-complete separation.' % (len(all_removed_binary_columns), \
                (~row_mask).sum()))
        
    retained_columns = [column_name for column_name in X.columns if column_name not in all_removed_binary_columns]
    X = X.loc[row_mask, retained_columns]
    y = y.loc[row_mask]
    
    return X, y, retained_columns, all_removed_binary_columns, row_mask

    
### Statistics ###

def to_normal_z_values(raw_values):

    from scipy.stats import rankdata, norm
    
    pvals = (rankdata(raw_values) - 0.5) / len(raw_values)
    normal_z_values = norm.ppf(pvals)
    
    if isinstance(raw_values, pd.Series):
        return pd.Series(normal_z_values, index = raw_values.index)
    else:
        return normal_z_values

def multipletests_with_nulls(values, method = 'fdr_bh'):

    from statsmodels.stats.multitest import multipletests
    
    significance = np.zeros(len(values), dtype = bool)
    qvals = np.nan * np.empty(len(values))
    mask = pd.notnull(values)
    
    if mask.any():
        significance[np.array(mask)], qvals[np.array(mask)], _, _ = multipletests(values[mask], method = method)
    
    return significance, qvals
    
def test_enrichment(mask1, mask2):

    from scipy.stats import fisher_exact

    assert len(mask1) == len(mask2)
    
    n1 = mask1.sum()
    n2 = mask2.sum()
    n_both = (mask1 & mask2).sum()
    n_total = len(mask1)
    n_expected = n1 * n2 / n_total
    enrichment_factor = n_both / n_expected
    
    contingency_table = np.array([
        [(mask1 & mask2).sum(), (mask1 & (~mask2)).sum()],
        [((~mask1) & mask2).sum(), ((~mask1) & (~mask2)).sum()],
    ])
    _, pval = fisher_exact(contingency_table)
    
    return n1, n2, n_both, n_total, n_expected, enrichment_factor, contingency_table, pval
    
def test_enrichment_sets(set1, set2, n_total):

    from scipy.stats import fisher_exact
    
    n1 = len(set1)
    n2 = len(set2)
    n_both = len(set1 & set2)
    n_expected = n1 * n2 / n_total
    enrichment_factor = n_both / n_expected
    
    contingency_table = np.array([
        [n_both, n1 - n_both],
        [n2 - n_both, n_total - n1 - n2 + n_both],
    ])
    _, pval = fisher_exact(contingency_table)
    
    return n1, n2, n_both, n_total, n_expected, enrichment_factor, contingency_table, pval
    
    
### h5f ###
    
def flush_h5_file(h5f):
    h5f.flush()
    os.fsync(h5f.id.get_vfd_handle())
    
def transpose_h5f_dataset(h5f, src_name, dst_name, max_memory_bytes):
    flush_func = lambda: flush_h5_file(h5f)
    src = h5f[src_name]
    nrows, ncols = src.shape[:2]
    dst = h5f.create_dataset(dst_name, shape = (ncols, nrows), dtype = src.dtype)
    transpose_dataset(src, dst, max_memory_bytes, flush_func)
    
    
### Matplotlib ###

def draw_rectangle(ax, start_x, end_x, start_y, end_y, **kwargs):
    from matplotlib import patches
    ax.add_patch(patches.Rectangle((start_x, start_y), end_x - start_x, end_y - start_y, **kwargs))
    
def set_ax_border_color(ax, color):

    import matplotlib.pyplot as plt

    for child in ax.get_children():
        if isinstance(child, plt.matplotlib.spines.Spine):
            child.set_color(color)
    
def plot_prediction_scatter(y_pred, y_true, value = 'value'):

    import matplotlib.pyplot as plt
    
    log(pearsonr(y_pred, y_true))
    log(spearmanr(y_pred, y_true))

    fig, ax = plt.subplots(figsize = (10, 6))
    ax.scatter(y_pred, y_true)
    ax.set_xlabel('Predicted %s' % value)
    ax.set_ylabel('Actual %s' % value)
    
def draw_pvals_qq_plot(pvals, max_density = 100, min_pval = None, ax = None, figsize = (7, 7), scatter_options = {}, \
        xlabel = 'Expected p-values (-log10)', ylabel = 'Observed p-values (-log10)'):
    
    import matplotlib.pyplot as plt
    
    if 'color' not in scatter_options:
        scatter_options['color'] = '#2e75b6'
    
    pvals = np.array(pvals)
    
    if min_pval is not None:
        pvals = np.maximum(pvals, min_pval)
    
    n_total_pvals = len(pvals)
    sorted_mlog_pvals = np.sort(-np.log10(pvals))
    max_mlog_pval = sorted_mlog_pvals.max()
    
    if ax is None:
        _, ax = plt.subplots(figsize = figsize)
    
    ax.plot([0, max_mlog_pval], [0, max_mlog_pval], color = 'red', linestyle = '--', alpha = 0.5)
    ax.set_xlim((0, max_mlog_pval))
    ax.set_ylim((0, max_mlog_pval))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    
    for upper_limit in range(1, int(max_mlog_pval + 3)):
        
        n_remained_pvals = len(sorted_mlog_pvals)
        i = np.searchsorted(sorted_mlog_pvals, upper_limit)
        range_pvals = sorted_mlog_pvals[:i]
        sorted_mlog_pvals = sorted_mlog_pvals[i:]
        
        if len(range_pvals) > 0:
                    
            if len(range_pvals) <= max_density:
                range_chosen_indices = np.arange(len(range_pvals))
            else:
                # We want to choose the p-values uniformly in the space of their expected frequencies (i.e. sampling more towards the higher end of the
                # spectrum).
                range_min_mlog_freq = -np.log10(n_remained_pvals / n_total_pvals)
                range_max_mlog_freq = -np.log10((n_remained_pvals - len(range_pvals) + 1) / n_total_pvals)
                range_chosen_mlog_freqs = np.linspace(range_min_mlog_freq, range_max_mlog_freq, max_density)
                range_chosen_freqs = np.power(10, -range_chosen_mlog_freqs)
                # Once having the desired freqs, reverse the function to get the indices that provide them
                range_chosen_indices = np.unique((n_remained_pvals - n_total_pvals * range_chosen_freqs).astype(int))

            range_pvals = range_pvals[range_chosen_indices]
            range_freqs = (n_remained_pvals - range_chosen_indices) / n_total_pvals
            range_mlog_freqs = -np.log10(range_freqs)
            ax.scatter(range_mlog_freqs, range_pvals, **scatter_options)
            
def draw_manhattan_plot(gwas_results, significance_treshold = 5e-08, max_results_to_plot = 1e06, \
        pval_threshold_to_force_inclusion = 1e-03, min_pval = 1e-300, ax = None, figsize = (12, 6), \
        s = 1.5, chrom_to_color = None):
    
    '''
    gwas_results (pd.DataFrame): Should have the following columns:
    - chromosome (str)
    - position (int)
    - pval (float)
    '''
    
    import matplotlib.pyplot as plt
        
    CHROMS = list(map(str, range(1, 23))) + ['X', 'Y']
    CHROM_TO_COLOR = {'1': '#0100fb', '2': '#ffff00', '3': '#00ff03', '4': '#bfbfbf', '5': '#acdae9', '6': '#a020f1',
            '7': '#ffa502', '8': '#ff00fe', '9': '#fe0000', '10': '#90ee90', '11': '#a52929', '12': '#000000', 
            '13': '#ffbfcf', '14': '#4484b2', '15': '#b63063', '16': '#f8816f', '17': '#ed84f3', '18': '#006401',
            '19': '#020184', '20': '#ced000', '21': '#cd0001', '22': '#050098', 'X': '#505050', 'Y': '#ff8000'}
    
    if chrom_to_color is None:
        chrom_to_color = CHROM_TO_COLOR
    
    if len(gwas_results) > max_results_to_plot:
        mask = pd.Series(random_mask(len(gwas_results), int(max_results_to_plot)), index = gwas_results.index)
        mask[gwas_results['pval'] <= pval_threshold_to_force_inclusion] = True
        gwas_results = gwas_results[mask]
    
    max_pos_per_chrom = gwas_results.groupby('chromosome')['position'].max()
    accumulating_pos = 0
    chrom_accumulating_positions = []
    
    for chrom in CHROMS:
        if chrom in max_pos_per_chrom.index:
            chrom_accumulating_positions.append((chrom, accumulating_pos + 1, accumulating_pos + max_pos_per_chrom[chrom]))
            accumulating_pos += max_pos_per_chrom[chrom]
            
    chrom_accumulating_positions = pd.DataFrame(chrom_accumulating_positions, columns = ['chromosome', \
            'accumulating_start_position', 'accumulating_end_position']).set_index('chromosome', drop = True)
    chrom_middle_accumulating_positions = (chrom_accumulating_positions['accumulating_start_position'] + \
            chrom_accumulating_positions['accumulating_end_position']) / 2
        
    if ax is None:
        _, ax = plt.subplots(figsize = figsize)
    
    ax.set_facecolor('white')
    plt.setp(ax.spines.values(), color = '#444444')
    ax.grid(False)
    
    if significance_treshold is not None:
        ax.axhline(y = -np.log10(significance_treshold), linestyle = '--', linewidth = 1, color = 'red')
    
    gwas_results_per_chrom = gwas_results.groupby('chromosome')
    max_y = 0
    
    for chrom in chrom_accumulating_positions.index:
        chrom_gwas_results = gwas_results_per_chrom.get_group(chrom)
        chrom_gwas_accumulating_positions = chrom_accumulating_positions.loc[chrom, 'accumulating_start_position'] + \
                chrom_gwas_results['position']
        chrom_gwas_minus_log_pval = -np.log10(np.maximum(chrom_gwas_results['pval'], min_pval))
        max_y = max(max_y, chrom_gwas_minus_log_pval.max())
        ax.scatter(chrom_gwas_accumulating_positions, chrom_gwas_minus_log_pval, color = chrom_to_color[chrom], s = s)
        
    ax.set_xlabel('Chromosome')
    ax.set_ylabel('-log10(p-value)')
    ax.set_xticks(chrom_middle_accumulating_positions)
    ax.set_xticklabels(chrom_middle_accumulating_positions.index)
    ax.set_xlim(1, accumulating_pos)
    ax.set_ylim(0, max_y + 1)
    
    return ax
    
    
### Biopython Helper Functions ###

def as_biopython_seq(seq):

    from Bio.Seq import Seq

    if isinstance(seq, Seq):
        return seq
    elif isinstance(seq, str):
        return Seq(seq)
    else:
        raise Exception('Cannot resolve type %s as Biopython Seq' % type(seq))
            
            
### Slurm ###

def get_slurm_job_array_ids(parse_total_tasks_by_max_variable = True, log_ids = True, verbose = True, task_index_remapping_json_file_path = None):

    job_id = int(os.getenv('SLURM_ARRAY_JOB_ID'))
    task_index = int(os.getenv('SLURM_ARRAY_TASK_ID'))
    
    if 'TASK_ID_OFFSET' in os.environ:
        
        task_offset = int(os.getenv('TASK_ID_OFFSET'))
        
        if verbose:
            log('Raw task index %d with offset %d.' % (task_index, task_offset))
        
        task_index += task_offset
        
    if task_index_remapping_json_file_path is not None:
        
        with open(task_index_remapping_json_file_path, 'r') as f:
            task_index_remapping = json.load(f)
            
        remapped_task_index = task_index_remapping[task_index]
        
        if verbose:
            log('Remapped task index %d into %d.' % (task_index, remapped_task_index))
        
        task_index = remapped_task_index
    
    if 'TOTAL_TASKS' in os.environ:
        total_tasks = int(os.getenv('TOTAL_TASKS'))
    elif parse_total_tasks_by_max_variable:
        total_tasks = int(os.getenv('SLURM_ARRAY_TASK_MAX')) + 1
    else:
        total_tasks = int(os.getenv('SLURM_ARRAY_TASK_COUNT'))
    
    if log_ids:
        log('Running job %s, task %d of %d.' % (job_id, task_index, total_tasks))
    
    return job_id, total_tasks, task_index 


### Liftover ###

def liftover_locus(liftover, chrom, pos):
    try:
            
        pos = int(pos)

        if not isinstance(chrom, str) or not chrom.startswith('chr'):
            chrom = 'chr%s' % chrom

        (new_chrom, new_pos, _, _), = liftover.convert_coordinate(chrom, pos)

        if new_chrom.startswith('chr'):
            new_chrom = new_chrom[3:]

        return new_chrom, new_pos
    except:
        return np.nan, np.nan

def liftover_loci_in_df(df, chrom_column = 'chromosome', pos_column = 'position', source_ref_genome = 'hg38', \
        target_ref_genome = 'hg19'):
    
    from pyliftover import LiftOver
    
    liftover = LiftOver(source_ref_genome, target_ref_genome)
    new_loci = []
    
    for _, (chrom, pos) in df[[chrom_column, pos_column]].iterrows():
        new_loci.append(liftover_locus(liftover, chrom, pos))
            
    new_chroms, new_positions = (pd.Series(list(values), index = df.index) for values in zip(*new_loci))
    return pd.concat([new_chroms.rename(chrom_column) if column == chrom_column else (new_positions.rename(pos_column) if \
            column == pos_column else df[column]) for column in df.columns], axis = 1)    
    
    
### Helper classes ###

class DummyContext(object):

    def __enter__(self):
        pass
        
    def __exit__(self, exc_type, exc_value, exc_traceback):
        pass

class TimeMeasure(object):

    def __init__(self, opening_statement):
        self.opening_statement = opening_statement

    def __enter__(self):
        self.start_time = datetime.now()
        log(self.opening_statement)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.finish_time = datetime.now()
        self.elapsed_time = self.finish_time - self.start_time
        log('Finished after %s.' % self.elapsed_time)
        
class Profiler(object):

    def __init__(self):
        self.creation_time = datetime.now()
        self.profiles = defaultdict(Profiler.Profile)
        
    def measure(self, profile_name):
        return self.profiles[profile_name].measure()
        
    def format(self, delimiter = '\n'):
        all_profiles = list(self.profiles.items()) + [('Total', Profiler.Profile(total_invokes = 1, total_time = datetime.now() - self.creation_time))]
        sorted_profiles = sorted(all_profiles, key = lambda profile_tuple: profile_tuple[1].total_time, reverse = True)
        return delimiter.join(['%s: %s' % (profile_name, profile) for profile_name, profile in sorted_profiles])
        
    def __repr__(self):
        return self.format()
        
    class Profile(object):
    
        def __init__(self, total_invokes = 0, total_time = timedelta(0)):
            self.total_invokes = total_invokes
            self.total_time = total_time
            
        def measure(self):
            return Profiler._Measurement(self)
            
        def __repr__(self):
            return '%s (%d times)' % (self.total_time, self.total_invokes)
        
    class _Measurement(object):
    
        def __init__(self, profile):
            self.profile = profile
        
        def __enter__(self):
            self.start_time = datetime.now()
            
        def __exit__(self, exc_type, exc_value, exc_traceback):
            self.profile.total_time += (datetime.now() - self.start_time)
            self.profile.total_invokes += 1


class OutputType:
    
    def __init__(self, is_seq, output_type):
        self.is_seq = is_seq
        self.output_type = output_type
        self.is_numeric = (output_type == 'numeric')
        self.is_binary = (output_type == 'binary')
        self.is_categorical = (output_type == 'categorical')
        
    def __str__(self):
        if self.is_seq:
            return '%s sequence' % self.output_type
        else:
            return 'global %s' % self.output_type
            
class OutputSpec:

    def __init__(self, output_type, unique_labels = None):
        
        if output_type.is_numeric:
            assert unique_labels is None
        elif output_type.is_binary:
            if unique_labels is None:
                unique_labels = [0, 1]
            else:
                assert unique_labels == [0, 1]
        elif output_type.is_categorical:
            assert unique_labels is not None
        else:
            raise ValueError('Unexpected output type: %s' % output_type)
        
        self.output_type = output_type
        self.unique_labels = unique_labels
        
        if unique_labels is not None:
            self.n_unique_labels = len(unique_labels)
            
def finetune(model_generator, input_encoder, output_spec, train_seqs, train_raw_Y, valid_seqs = None, valid_raw_Y = None, seq_len = 512, batch_size = 32, \
        max_epochs_per_stage = 40, lr = None, begin_with_frozen_pretrained_layers = True, lr_with_frozen_pretrained_layers = None, n_final_epochs = 1, \
        final_seq_len = 1024, final_lr = None, callbacks = []):
        
    encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(train_seqs, train_raw_Y, valid_seqs, valid_raw_Y, input_encoder, output_spec, seq_len)
        
    if begin_with_frozen_pretrained_layers:
        log('Training with frozen pretrained layers...')
        model_generator.train(encoded_train_set, encoded_valid_set, seq_len, batch_size, max_epochs_per_stage, lr = lr_with_frozen_pretrained_layers, \
                callbacks = callbacks, freeze_pretrained_layers = True)
     
    log('Training the entire fine-tuned model...')
    model_generator.train(encoded_train_set, encoded_valid_set, seq_len, batch_size, max_epochs_per_stage, lr = lr, callbacks = callbacks, \
            freeze_pretrained_layers = False)
                
    if n_final_epochs > 0:
        log('Training on final epochs of sequence length %d...' % final_seq_len)
        final_batch_size = max(int(batch_size / (final_seq_len / seq_len)), 1)
        encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(train_seqs, train_raw_Y, valid_seqs, valid_raw_Y, input_encoder, output_spec, final_seq_len)
        model_generator.train(encoded_train_set, encoded_valid_set, final_seq_len, final_batch_size, n_final_epochs, lr = final_lr, callbacks = callbacks, \
                freeze_pretrained_layers = False)
                
    model_generator.optimizer_weights = None

def evaluate_by_len(model_generator, input_encoder, output_spec, seqs, raw_Y, start_seq_len = 512, start_batch_size = 32, increase_factor = 2):
    
    assert model_generator.optimizer_weights is None
    
    dataset = pd.DataFrame({'seq': seqs, 'raw_y': raw_Y})
        
    results = []
    results_names = []
    y_trues = []
    y_preds = []
    
    for len_matching_dataset, seq_len, batch_size in split_dataset_by_len(dataset, start_seq_len = start_seq_len, start_batch_size = start_batch_size, \
            increase_factor = increase_factor):

        X, y_true, sample_weights = encode_dataset(len_matching_dataset['seq'], len_matching_dataset['raw_y'], input_encoder, output_spec, \
                seq_len = seq_len, needs_filtering = False)
        
        assert set(np.unique(sample_weights)) <= {0.0, 1.0}
        y_mask = (sample_weights == 1)
        
        model = model_generator.create_model(seq_len)
        y_pred = model.predict(X, batch_size = batch_size)
        
        y_true = y_true[y_mask].flatten()
        y_pred = y_pred[y_mask]
        
        if output_spec.output_type.is_categorical:
            y_pred = y_pred.reshape((-1, y_pred.shape[-1]))
        else:
            y_pred = y_pred.flatten()
        
        results.append(get_evaluation_results(y_true, y_pred, output_spec))
        results_names.append(seq_len)
        
        y_trues.append(y_true)
        y_preds.append(y_pred)
        
    y_true = np.concatenate(y_trues, axis = 0)
    y_pred = np.concatenate(y_preds, axis = 0)
    all_results, confusion_matrix = get_evaluation_results(y_true, y_pred, output_spec, return_confusion_matrix = True)
    results.append(all_results)
    results_names.append('All')
    
    results = pd.DataFrame(results, index = results_names)
    results.index.name = 'Model seq len'
    
    return results, confusion_matrix

def get_evaluation_results(y_true, y_pred, output_spec, return_confusion_matrix = False):

    from scipy.stats import spearmanr
    from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
            
    results = {}
    results['# records'] = len(y_true)
            
    if output_spec.output_type.is_numeric:
        results['Spearman\'s rank correlation'] = spearmanr(y_true, y_pred)[0]
        confusion_matrix = None
    else:
    
        str_unique_labels = list(map(str, output_spec.unique_labels))
        
        if output_spec.output_type.is_binary:
            
            y_pred_classes = (y_pred >= 0.5)
            
            if len(np.unique(y_true)) == 2:
                results['AUC'] = roc_auc_score(y_true, y_pred)
            else:
                results['AUC'] = np.nan
        elif output_spec.output_type.is_categorical:
            y_pred_classes = y_pred.argmax(axis = -1)
            results['Accuracy'] = accuracy_score(y_true, y_pred_classes)
        else:
            raise ValueError('Unexpected output type: %s' % output_spec.output_type)
                    
        confusion_matrix = pd.DataFrame(confusion_matrix(y_true, y_pred_classes, labels = np.arange(output_spec.n_unique_labels)), index = str_unique_labels, \
                    columns = str_unique_labels)
         
    if return_confusion_matrix:
        return results, confusion_matrix
    else:
        return results
        
def encode_train_and_valid_sets(train_seqs, train_raw_Y, valid_seqs, valid_raw_Y, input_encoder, output_spec, seq_len):
    
    encoded_train_set = encode_dataset(train_seqs, train_raw_Y, input_encoder, output_spec, seq_len = seq_len, needs_filtering = True, \
            dataset_name = 'Training set')
    
    if valid_seqs is None and valid_raw_Y is None:
        encoded_valid_set = None
    else:
        encoded_valid_set = encode_dataset(valid_seqs, valid_raw_Y, input_encoder, output_spec, seq_len = seq_len, needs_filtering = True, \
                dataset_name = 'Validation set')

    return encoded_train_set, encoded_valid_set
        
def encode_dataset(seqs, raw_Y, input_encoder, output_spec, seq_len = 512, needs_filtering = True, dataset_name = 'Dataset', verbose = True):
    
    if needs_filtering:
        dataset = pd.DataFrame({'seq': seqs, 'raw_Y': raw_Y})
        dataset = filter_dataset_by_len(dataset, seq_len = seq_len, dataset_name = dataset_name, verbose = verbose)
        seqs = dataset['seq']
        raw_Y = dataset['raw_Y']
    
    X = input_encoder.encode_X(seqs, seq_len)
    Y, sample_weigths = encode_Y(raw_Y, output_spec, seq_len = seq_len)
    return X, Y, sample_weigths

def encode_Y(raw_Y, output_spec, seq_len = 512):
    if output_spec.output_type.is_seq:
        return encode_seq_Y(raw_Y, seq_len, output_spec.output_type.is_binary, output_spec.unique_labels)
    elif output_spec.output_type.is_categorical:
        return encode_categorical_Y(raw_Y, output_spec.unique_labels), np.ones(len(raw_Y))
    elif output_spec.output_type.is_numeric or output_spec.output_type.is_binary:
        return raw_Y.values.astype(float), np.ones(len(raw_Y))
    else:
        raise ValueError('Unexpected output type: %s' % output_spec.output_type)

def encode_seq_Y(seqs, seq_len, is_binary, unique_labels):

    label_to_index = {str(label): i for i, label in enumerate(unique_labels)}

    Y = np.zeros((len(seqs), seq_len), dtype = int)
    sample_weigths = np.zeros((len(seqs), seq_len))
    
    for i, seq in enumerate(seqs):
        
        for j, label in enumerate(seq):
            # +1 to account for the <START> token at the beginning.
            
            Y[i, j + 1] = label_to_index[label]
            
        sample_weigths[i, 1:(len(seq) + 1)] = 1
        
    if is_binary:
        Y = np.expand_dims(Y, axis = -1)
        sample_weigths = np.expand_dims(sample_weigths, axis = -1)
    
    return Y, sample_weigths
    
def encode_categorical_Y(labels, unique_labels):
    
    label_to_index = {label: i for i, label in enumerate(unique_labels)}
    Y = np.zeros(len(labels), dtype = int)
    
    for i, label in enumerate(labels):
        Y[i] = label_to_index[label]
        
    return Y
    
def filter_dataset_by_len(dataset, seq_len = 512, seq_col_name = 'seq', dataset_name = 'Dataset', verbose = True):
    
    max_allowed_input_seq_len = seq_len - ADDED_TOKENS_PER_SEQ
    filtered_dataset = dataset[dataset[seq_col_name].str.len() <= max_allowed_input_seq_len]
    n_removed_records = len(dataset) - len(filtered_dataset)
    
    if verbose:
        log('%s: Filtered out %d of %d (%.1f%%) records of lengths exceeding %d.' % (dataset_name, n_removed_records, len(dataset), 100 * n_removed_records / len(dataset), \
                max_allowed_input_seq_len))
    
    return filtered_dataset
    
def split_dataset_by_len(dataset, seq_col_name = 'seq', start_seq_len = 512, start_batch_size = 32, increase_factor = 2):

    seq_len = start_seq_len
    batch_size = start_batch_size
    
    while len(dataset) > 0:
        max_allowed_input_seq_len = seq_len - ADDED_TOKENS_PER_SEQ
        len_mask = (dataset[seq_col_name].str.len() <= max_allowed_input_seq_len)
        len_matching_dataset = dataset[len_mask]
        yield len_matching_dataset, seq_len, batch_size
        dataset = dataset[~len_mask]
        seq_len *= increase_factor
        batch_size = max(batch_size // increase_factor, 1)


AttributeError: module 'numpy' has no attribute 'typeDict'

In [None]:
# Step 1: Load your dataset
# Replace these file paths with the paths to your own training and testing data
TRAIN_DATA_PATH = 'train_protein.csv'
TEST_DATA_PATH = 'test_protein.csv'

# Load the dataset
train_set = pd.read_csv(TRAIN_DATA_PATH).dropna().drop_duplicates()
test_set = pd.read_csv(TEST_DATA_PATH).dropna().drop_duplicates()

In [None]:

# Split the training set into train and validation sets
train_set, valid_set = train_test_split(train_set, test_size=0.2, random_state=42)

print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')


In [None]:
train_set['label']

In [None]:
# Extract unique labels from your dataset
all_labels = set()
for seq in train_set['label']:
    all_labels.update(seq)
for seq in valid_set['label']:
    all_labels.update(seq)

# Convert the set of all labels to a sorted list
UNIQUE_LABELS = sorted(list(all_labels))

# Update the OUTPUT_SPEC with the correct labels
OUTPUT_SPEC = OutputSpec(OutputType(True, 'categorical'), UNIQUE_LABELS)


In [None]:
# Step 2: Define Output Specification
# Customize this based on your dataset's labels
# Example for binary classification
# Set output type as text

label_list = [x for x in train_set['label']]


all_labels = set()
for seq in train_set['label']:
    all_labels.update(seq)
for seq in valid_set['label']:
    all_labels.update(seq)
    
label_list.append(all_labels)


OUTPUT_TYPE = OutputType(True, 'categorical')

# Define unique labels as needed
UNIQUE_LABELS = label_list  # Adjust these to match your dataset's labels

# Create the output specification
OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)




In [None]:
# Step 3: Load Pre-trained ProteinBERT Model
pretrained_model_generator, input_encoder = load_pretrained_model()

# Step 4: Create the Fine-tuning Model
model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC, 
    pretraining_model_manipulation_function=get_model_with_hidden_layers_as_outputs, dropout_rate=0.5)

# Step 5: Set up Callbacks for Training
training_callbacks = [
    keras.callbacks.ReduceLROnPlateau(patience=1, factor=0.25, min_lr=1e-05, verbose=1),
    keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True),
]

# Step 6: Fine-tune the Model
finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], 
    seq_len=512, batch_size=32, max_epochs_per_stage=1, lr=1e-04, begin_with_frozen_pretrained_layers=True, 
    lr_with_frozen_pretrained_layers=1e-02, n_final_epochs=1, final_seq_len=1024, final_lr=1e-05, callbacks=training_callbacks)

# Step 7: Evaluate the Model on the Test Set
results, confusion_matrix = evaluate_by_len(model_generator, input_encoder, OUTPUT_SPEC, test_set['seq'], test_set['label'], 
    start_seq_len=512, start_batch_size=32)

print('Test-set performance:')
print(results)
print('Confusion matrix:')
print(confusion_matrix)

In [None]:
# Example: Predict on new data
# Replace 'new_sequences' with your list of sequences to predict on
new_sequences = ['MVLSPADKTNVKAAW', 'GVLTQSQAELERVH']

# If your data is in a DataFrame:
# new_data = pd.DataFrame({'seq': ['MVLSPADKTNVKAAW', 'GVLTQSQAELERVH']})


In [None]:
from proteinbert import predict

# Predictions (if you have a DataFrame or list of sequences)
predictions = predict(model_generator, input_encoder, OUTPUT_SPEC, new_sequences, seq_len=512, batch_size=32)

# Output the predictions
print(predictions)


In [None]:
assert os.path.exists(TRAIN_DATA_PATH), f"{TRAIN_DATA_PATH} does not exist."
assert os.path.exists(TEST_DATA_PATH), f"{TEST_DATA_PATH} does not exist."


In [None]:
print(train_set.head())


In [None]:

print(test_set.head())



In [None]:

print(train_set.info())


In [None]:
print(f'{len(train_set)} training set records, {len(valid_set)} validation set records.')


In [None]:
print(f'Unique labels: {UNIQUE_LABELS}')


In [1]:
import os
import pandas as pd
from tensorflow import keras
from sklearn.model_selection import train_test_split
from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, finetune, evaluate_by_len
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

In [2]:
# Step 1: Load your dataset
TRAIN_DATA_PATH = 'train_protein.csv'
TEST_DATA_PATH = 'test_protein.csv'

# Load the dataset
train_set = pd.read_csv(TRAIN_DATA_PATH).dropna().drop_duplicates()
test_set = pd.read_csv(TEST_DATA_PATH).dropna().drop_duplicates()

In [3]:
# Split the training set into train and validation sets
train_set, valid_set = train_test_split(train_set, test_size=0.2, random_state=42)

print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')


2563 training set records, 641 validation set records, 1375 test set records.


In [4]:
# Extract unique labels from your dataset
all_labels = set()
for label in train_set['label']:
    all_labels.add(label)
for label in valid_set['label']:
    all_labels.add(label)

# Convert the set of all labels to a sorted list
UNIQUE_LABELS = sorted(list(all_labels))

# Update the OUTPUT_SPEC with the correct labels
OUTPUT_SPEC = OutputSpec(OutputType(False, 'categorical'), UNIQUE_LABELS)
# False is for isBinary


In [6]:
# Step 3: Load Pre-trained ProteinBERT Model
pretrained_model_generator, input_encoder = load_pretrained_model()

# Step 4: Create the Fine-tuning Model
model_generator = FinetuningModelGenerator(pretrained_model_generator, OUTPUT_SPEC, 
    pretraining_model_manipulation_function=get_model_with_hidden_layers_as_outputs, dropout_rate=0.5)

# Step 5: Set up Callbacks for Training
training_callbacks = [
    keras.callbacks.ReduceLROnPlateau(patience=1, factor=0.25, min_lr=1e-05, verbose=1),
    keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True),
]

# Step 6: Fine-tune the Model
finetune(model_generator, input_encoder, OUTPUT_SPEC, train_set['seq'], train_set['label'], valid_set['seq'], valid_set['label'], 
    seq_len=1024, batch_size=32, max_epochs_per_stage=1, lr=1e-04, begin_with_frozen_pretrained_layers=True, 
    lr_with_frozen_pretrained_layers=1e-02, n_final_epochs=1, final_seq_len=1024, final_lr=1e-05, callbacks=training_callbacks)

# Step 7: Evaluate the Model on the Test Set
results, confusion_matrix = evaluate_by_len(model_generator, input_encoder, OUTPUT_SPEC, test_set['seq'], test_set['label'], 
    start_seq_len=1024, start_batch_size=32)

print('Test-set performance:')
print(results)
print('Confusion matrix:')
print(confusion_matrix)

[2024_08_09-00:40:07] Training set: Filtered out 32 of 2563 (1.2%) records of lengths exceeding 1022.
[2024_08_09-00:40:07] Validation set: Filtered out 11 of 641 (1.7%) records of lengths exceeding 1022.
[2024_08_09-00:40:07] Training with frozen pretrained layers...
145
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 754ms/step - loss: 8.1408 - val_loss: 8.5370 - learning_rate: 2.0000e-04
[2024_08_09-00:41:13] Training the entire fine-tuned model...
146


ValueError: You called `set_weights(weights)` on layer 'functional_2' with a weight list of length 146, but the layer was expecting 145 weights.

In [7]:
# pytorch

In [10]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Load the dataset
TRAIN_DATA_PATH = 'train_protein.csv'
TEST_DATA_PATH = 'test_protein.csv'

train_set = pd.read_csv(TRAIN_DATA_PATH).dropna().drop_duplicates()
test_set = pd.read_csv(TEST_DATA_PATH).dropna().drop_duplicates()

# Split the training set into train and validation sets
train_set, valid_set = train_test_split(train_set, test_size=0.2, random_state=42)

# Encode the labels
label_encoder = LabelEncoder()
train_set['label'] = label_encoder.fit_transform(train_set['label'])
valid_set['label'] = label_encoder.transform(valid_set['label'])
test_set['label'] = label_encoder.transform(test_set['label'])


ValueError: y contains previously unseen labels: 'YJAA_ECOLI'