In [36]:
# ! pip install torcheval

In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, BertConfig, logging
import gc
import numpy as np
import pandas as pd
import warnings
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras as K

import torch
import torch.nn as nn
import torch.nn.functional as F
from torcheval.metrics.functional import binary_auroc, binary_auprc
from torchmetrics import AUROC

warnings.filterwarnings('ignore')
logging.set_verbosity_error()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from biom import load_table, Table

In [20]:
class GeneratorDataset:
    def __init__(
        self,
        table = None,
        metadata = None,
        metadata_column = None,
        shift = None,
        scale = "minmax",
        max_token_per_sample: int = 1024,
        shuffle: bool = False,
        rarefy_depth: int = 5000,
        epochs: int = 1000,
        gen_new_tables: bool = False,
        batch_size: int = 8,
        max_bp: int = 150,
        is_16S: bool = True,
        is_categorical = None,
        gen_new_table_frequency=3,
        return_sample_ids=False,
        tree_path=None,
        seed=None,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
        self.config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
        self.model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True, config=self.config).to(device)
        self.model.requires_grad = False

        if isinstance(table, str):
            table = load_table(table)
        self.table: Table = table
        self.tree_path = tree_path
        self.is_categorical: bool = is_categorical
        self.metadata_column: str = metadata_column
        self.shift = shift
        self.scale = scale
        self.metadata: pd.Series = metadata
        self.rarefy_depth: int = rarefy_depth
        self.max_token_per_sample: int = max_token_per_sample
        self.return_sample_ids: bool = return_sample_ids
        self.include_sample_weight: bool = is_categorical
        self.shuffle = shuffle
        self.epochs = epochs
        self.gen_new_tables = gen_new_tables
        self.samples_per_minibatch = batch_size
        self.batch_size = batch_size
        self.max_bp = max_bp
        self.is_16S = is_16S
        self.seed = seed
        self.gen_new_table_frequency = gen_new_table_frequency
        self.epochs_since_last_table = 0
        self.encoder_target = None
        self.encoder_dtype = None
        self.encoder_output_type = None
        self.sample_ids = None
        self.asv_ids = None
        if self.tree_path is not None:
            self.tree = to_skbio_treenode(parse_newick(open(self.tree_path).read()))
            self.postorder_pos = {n.name: i for i, n in enumerate(self.tree.postorder()) if n.is_tip()}
        print("rarefy table...")
        self.rarefied_table: Table = self.table.subsample(rarefy_depth)
        self.size = self.rarefied_table.shape[1]
        self.steps_per_epoch = self.size // self.batch_size
        self.y_data = self.metadata.loc[self._rarefied_table.ids()]
        self.on_epoch_end()

    def __len__(self):
        return self.steps_per_epoch
    
    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = start + self.batch_size
        sample_indices = self.sample_indices[start:end]
        batch_sample_ids = self.sample_ids[sample_indices]
        return self._batch_data(batch_sample_ids)
    
    def _batch_data(self, batch_sample_ids):
        num_unique_asvs, sparse_indices, obs_indices, counts = [], [], [], []
        cur_row_indx = 0
        for s_id in batch_sample_ids:
            sample_data = self.rarefied_table.data(s_id, dense=False).tocoo()
            obs_idx, sample_counts = sample_data.row, sample_data.data
            # remove zeros
            non_zero_mask = sample_counts > 0.0
            obs_idx = obs_idx[non_zero_mask]
            sample_counts = sample_counts[non_zero_mask]
            num_unique_asvs.append(len(obs_idx))
            sparse_indices.append(([[cur_row_indx, i] for i in range(len(obs_idx))])[:2])
            obs_indices.append(obs_idx[:2])
            counts.append(sample_counts[:2])
            cur_row_indx += 1
        num_unique_asvs = np.array(num_unique_asvs, dtype=np.int32)
        sparse_indices = np.vstack(sparse_indices, dtype=np.int32)
        obs_indices = np.hstack(obs_indices, dtype=np.int32)
        counts = np.hstack(counts, dtype=np.float32)[:, np.newaxis]
        
        # get list of unique observations in batch
        unique_obs, obs_indices = np.unique(obs_indices, return_inverse=True)
        obs = self.rarefied_table.ids(axis="observation")
        tokens = obs[unique_obs]
        y_true = self.y_data.loc[batch_sample_ids].to_numpy()[:, np.newaxis]
        # return (tokens, sparse_indices, obs_indices, counts), y_true # [unknown num of unique seq, N]
        return self.calc_embedding_mean(tokens, sparse_indices, obs_indices, counts), y_true
    
    def batch_embeddings(self, asv_embeddings, batch_indicies, counts, asv_indices=None):
        emb_dim = tf.shape(asv_embeddings)[-1]
        if asv_indices is not None:
            asv_embeddings = tf.gather(asv_embeddings, asv_indices)
        batch_shape = tf.reduce_max(batch_indicies[:, 0]) + 1
        max_unique = tf.reduce_max(batch_indicies[:, 1]) + 1
        batch_embeddings = tf.scatter_nd(
            batch_indicies, asv_embeddings, shape=[batch_shape, max_unique, emb_dim]
        )
        counts = tf.scatter_nd(
            batch_indicies, counts, shape=[batch_shape, max_unique, 1]
        )
        return batch_embeddings.numpy(), counts.numpy()

    def calc_embedding_mean(self, tokens, sparse_indices, obs_indices, counts):
        '''
        input: batch_data output
        returns: [B, A, E]
        '''
        # tokens = tokens[:2]
        inputs = [self.tokenizer(token, return_tensors = 'pt')["input_ids"].to(device) for token in tokens]
        hidden_states = [self.model(input) for input in inputs] # shape: [B x A, N, E]
        embedding_mean = [torch.mean(byte_pair, dim=1) for byte_pair, class_token in hidden_states] # embedding with mean pooling
        embeddings, counts = self.batch_embeddings(torch.concat(embedding_mean, dim=0).detach().numpy(), sparse_indices, counts, obs_indices) # shape: [num of unique sequences, E]
        return embeddings, counts

        
    def sort_using_counts(self, tensor, counts):
        sorted_indices = tf.argsort(tf.squeeze(counts, axis=-1), axis=1, direction="DESCENDING")
        sorted_tensor = tf.gather(tensor, sorted_indices, axis=1, batch_dims=1)
        sorted_counts = tf.gather(counts, sorted_indices, axis=1, batch_dims=1)
        return sorted_tensor, sorted_counts

    def on_epoch_end(self):
        if self.gen_new_tables and self.epochs_since_last_table > self.gen_new_table_frequency:
            print("resampling dataset...")
            self.rarefied_table = self.table.subsample(self.rarefy_depth)
            self.epochs_since_last_table = 0
        if self.shuffle:
            np.random.shuffle(self.sample_indices)
        self.epochs_since_last_table += 1
    
    @property
    def rarefied_table(self):
        return self._rarefied_table
    
    @rarefied_table.setter
    def rarefied_table(self, table: Table):
        self._rarefied_table = table
        print("removing empty sample/obs from table")
        self._rarefied_table.remove_empty()
        if self.tree_path is not None:
            def sort_obs(obs):
                post_pos = [self.postorder_pos[ob] for ob in obs]
                sorted_indices = np.argsort(post_pos)
                return obs[sorted_indices]
            self._rarefied_table = self._rarefied_table.sort(sort_obs, axis="observation")
        self.sample_ids = self._rarefied_table.ids()
        self.asv_ids = self._rarefied_table.ids(axis="observation")
        self.sample_indices = np.arange(len(self.sample_ids))
        print("creating encoder target...")
        self.encoder_target = self._create_encoder_target()
        print("encoder target created")

    def _create_encoder_target(self) -> None:
        return None
    
    def _encoder_output(self, sample_ids):
        return None
    
    @property
    def table(self) -> Table:
        return self._table
    @table.setter
    def table(self, table):
        self._table = table
    @property
    def metadata(self) -> pd.Series:
        return self._metadata
    @metadata.setter
    def metadata(self, metadata):
        if metadata is None:
            return
        if isinstance(metadata, str):
            metadata = pd.read_csv(metadata, sep="\t", index_col=0, dtype={0: str})
        if self.metadata_column not in metadata.columns:
            raise Exception(f"Invalid metadata column {self.metadata_column}")
        print("aligning table with metadata")
        samp_ids = np.intersect1d(self.table.ids(axis="sample"), metadata.index)
        self.table.filter(samp_ids, axis="sample", inplace=True)
        self.table.remove_empty()
        metadata = metadata.loc[self.table.ids(), self.metadata_column]
        print(f"aligned table shape: {self.table.shape}")
        print(f"aligned metadata shape: {metadata.shape}")
        metadata = metadata.astype(np.int32)
        self._metadata = metadata.reindex(self.table.ids())
        print("done preprocessing metadata")

In [21]:
gen = GeneratorDataset("data/input/merged_biom_table.biom", "data/input/training_metadata.tsv", "has_covid")

aligning table with metadata
aligned table shape: (26778, 269)
aligned metadata shape: (269,)
done preprocessing metadata
rarefy table...
removing empty sample/obs from table
creating encoder target...
encoder target created


In [22]:
gen[0]

((array([[[-0.10416535,  0.07718358,  0.07524616, ...,  0.03523894,
            0.06350107,  0.11642157],
          [-0.06947182,  0.07317469,  0.06933934, ...,  0.04492776,
            0.09360398,  0.08222608]],
  
         [[-0.05706431,  0.04534246,  0.09490667, ..., -0.01283019,
            0.04796402,  0.07663572],
          [-0.07141566,  0.05508459,  0.02317847, ...,  0.02454554,
            0.05045987,  0.14504033]],
  
         [[-0.08008792,  0.13445865,  0.16800281, ...,  0.03898952,
            0.07082699,  0.09655195],
          [-0.08790375,  0.12666956,  0.17301773, ...,  0.02596849,
            0.06304552,  0.10835315]],
  
         ...,
  
         [[-0.04361127,  0.08778518,  0.0632642 , ..., -0.02247675,
            0.10746151,  0.03641083],
          [-0.05611859,  0.0719678 ,  0.05996731, ..., -0.01507601,
            0.10763866,  0.02790567]],
  
         [[-0.09910258,  0.09380104,  0.12231942, ..., -0.04979033,
            0.05987035,  0.09425615],
          [-0

In [None]:
def get_embedding():
    '''
    output: [B, A, E]
    '''
    # set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
        
    # initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
    config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")
    model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True, config=config).to(device)
    model.requires_grad = False
    
    def batch_embeddings(asv_embeddings, batch_indicies, counts, asv_indices=None):
        emb_dim = tf.shape(asv_embeddings)[-1]
        if asv_indices is not None:
            asv_embeddings = tf.gather(asv_embeddings, asv_indices)
        batch_shape = tf.reduce_max(batch_indicies[:, 0]) + 1
        max_unique = tf.reduce_max(batch_indicies[:, 1]) + 1
        batch_embeddings = tf.scatter_nd(
            batch_indicies, asv_embeddings, shape=[batch_shape, max_unique, emb_dim]
        )
        counts = tf.scatter_nd(
            batch_indicies, counts, shape=[batch_shape, max_unique, 1]
        )
        return batch_embeddings.numpy(), counts.numpy()

    def calc_embedding_mean(tokens, sparse_indices, obs_indices, counts):
        '''
        input: batch_data output
        returns: [B, A, E]
        '''
        # tokens = tokens[:2]
        inputs = [tokenizer(token, return_tensors = 'pt')["input_ids"].to(device) for token in tokens]
        print(inputs[1].shape)
        # inputs = torch.concat(inputs, dim=0).to(device) # [number of unique sequences, N] where N is byte pairs
        hidden_states = [model(input) for input in inputs] # shape: [B x A, N, E]
        embedding_mean = [torch.mean(byte_pair, dim=1) for byte_pair, class_token in hidden_states] # embedding with mean pooling
        embeddings, counts = batch_embeddings(torch.concat(embedding_mean, dim=0).detach().numpy(), sparse_indices, counts, obs_indices) # shape: [num of unique sequences, E]
        return embeddings, counts
    
    return calc_embedding_mean

In [19]:
feature_extractor = get_embedding()
x, y = gen[1]
inputs = feature_extractor(*x)
print(inputs[0])

torch.Size([1, 36])
tensor([[[-0.0403,  0.0607,  0.1705,  ...,  0.0023,  0.1198,  0.0257],
         [-0.0627,  0.0756,  0.0871,  ...,  0.0529,  0.0544,  0.0965]],

        [[-0.0791,  0.0672,  0.0910,  ...,  0.0481,  0.0366,  0.1081],
         [-0.0714,  0.0551,  0.0232,  ...,  0.0245,  0.0505,  0.1450]],

        [[-0.0532,  0.0889,  0.0643,  ..., -0.0142,  0.1133,  0.0303],
         [-0.0561,  0.0720,  0.0600,  ..., -0.0151,  0.1076,  0.0279]],

        ...,

        [[-0.0561,  0.0720,  0.0600,  ..., -0.0151,  0.1076,  0.0279],
         [-0.0791,  0.0672,  0.0910,  ...,  0.0481,  0.0366,  0.1081]],

        [[-0.0727,  0.0544,  0.1143,  ..., -0.0174,  0.1117,  0.1183],
         [-0.0626,  0.0481,  0.0178,  ..., -0.0067,  0.1115,  0.1134]],

        [[-0.0561,  0.0720,  0.0600,  ..., -0.0151,  0.1076,  0.0279],
         [-0.0791,  0.0672,  0.0910,  ...,  0.0481,  0.0366,  0.1081]]])


In [79]:
# lt, rt = inputs[1]
# print("token shape:", gen[1][0][0].shape)
# print("lt:", lt.shape)
# print("rt:", rt.shape)
print(inputs[1].shape)

torch.Size([1, 768])


In [None]:
class TransformerEncodingBlock(nn.Module):
    def __init__(self, num_attention_heads=4, dropout_rate=0.1, **kwargs): # change num_attention_heads to whatever dnabert 2 uses
        super().__init__(**kwargs)
        self.num_attention_heads = num_attention_heads

        emb_dim = 768
        key_dim = emb_dim // self.num_attention_heads
        self.attention_layer = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=self.num_attention_heads)
        self.ffi = nn.Linear(768, 3072)
        self.ffo = nn.Linear(3072, 768)
        self.activation = nn.GELU()
        self.rezero_alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.ff_dropout_rate = nn.Dropout(p=self.dropout_rate)
        self.attention_dropout = nn.Dropout(p=self.dropout_rate)

    def forward(self, input, attention_mask=None, training=False):
        """
        input: [B, A, N]
        attention_mask: [B, A, 1]
        """
        # attention
        attention_mask = torch.matmul(attention_mask, attention_mask.transpose(-1, -2))
        attention_output, _ = self.attention_layer(input, input, input, attn_mask=attention_mask)
        
        # rezero skip connection
        if training:
            attention_output = self.attention_dropout(attention_output)
        attention_output = input + self.rezero_alpha * attention_output

        # feed forward
        ffi_output = self.ffi(attention_output)
        ffi_activation = self.activation(ffi_output)
        ffo_output = self.ffo(ffi_activation)
        ffo_activation = self.activation(ffo_output)

        # dropout
        if training:
            ffo_activation = self.ff_dropout_rate(ffo_activation)

        return attention_output + self.rezero_alpha * ffo_activation


class TransformerEncoder(nn.Module):
    def __init__(self, num_attention_layers=4, num_attention_heads=4, dropout_rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.num_attention_layers = num_attention_layers
        self.num_attention_heads = num_attention_heads
        self.dropout_rate = dropout_rate

        self.encoding_layers = nn.ModuleList([
            TransformerEncodingBlock(self.num_attention_heads, self.dropout_rate) for _ in range(self.num_attention_layers)
        ])
        self.rezero_alpha = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, inputs, attention_mask=None, training=False):
        """
        """
        output = inputs
        for i in range(self.num_attention_layers):
            output = self.encoding_layers[i](output, attention_mask=attention_mask, training=training)
        
        return inputs * self.rezero_alpha + output


In [None]:
class Classifier(nn.Module):
    def __init__(self, feature_extractor, num_classes=2, **kwargs):
        super(Classifier, self).__init__()

        self.feature_extractor = feature_extractor
        self.feature_extractor.requires_grad = False

        self.encoder = TransformerEncoder()
        self.pooling = lambda x: torch.mean(x, dim=-1)
        # self.tokenizer
        self.dense_ff = nn.Linear(in_features=516, out_features=1)
        self.loss_fn = nn.BCEWithLogitsLoss()
        
        # Initialize AUC metric
        self.auc_metric = AUROC(task="binary", num_classes=num_classes)

    def forward(self, inputs, mask=None):
        """
        inputs: [B, A, N] - Batch, ASV, Nucleotides
        """
        features = self.feature_extractor(inputs)  # [B, A, E]

        encoding_output = self.encoder(features, attention_mask=mask)  # [B, A, N]
        pooling_output = self.pooling(encoding_output)  # [B, N]
        logits = self.dense_ff(pooling_output)  
        return logits #  classification

In [26]:
X_test = pd.read_csv('data/input/samples_X_test.csv').drop(columns=['sample_name', 'study_sample_type'])[:2]
y_test = pd.read_csv('data/input/samples_y_test.csv')[:2]

In [78]:
y_test_tensor = torch.tensor(y_test.to_numpy(), dtype=torch.float32)
y_test_tensor = y_test_tensor.view(-1, 1)
y_test_tensor

tensor([[0.],
        [1.]])

In [82]:
model = Classifier(get_embedding).to(device)
criterion = nn.BCEWithLogitsLoss()
auc_metric = AUROC(task="binary")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 1

for epoch in tqdm(range(epochs)):
    model.train()

    # forward pass
    output = model(X_test)

    # compute loss and update metric
    loss = criterion(output, y_test_tensor)
    auc_metric.update(output, y_test_tensor)

    # optimizer zero grad
    optimizer.zero_grad()
    
    # loss backward
    loss.backward()

    # optimizer step
    optimizer.step()

    print(f"Epoch {epoch+1}, loss: {loss.item()}")

  0%|          | 0/1 [00:00<?, ?it/s]
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 195/195 [00:11<00:00, 17.31it/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Epoch 1, loss: 0.6927562952041626





In [76]:
final_auc = auc_metric.compute()

print(f"Final AUC: {final_auc.item()}")


Final AUC: 0.8333333134651184
