In [1]:
import torch.nn
import torch

In [4]:
import pandas as pd

# Load the dataset from a Feather file. The path is relative, so it resolves
# against the notebook's working directory.
# NOTE(review): the file name suggests this is an already-filtered dataset —
# confirm how and where it was produced.
dataset = pd.read_feather("../data/filtered_dataset.ftr")

In [8]:
# Extract the 'sequence' column as a plain Python list.
# NOTE(review): presumably one DNA string per row — confirm against the dataset schema.
inputs = dataset['sequence'].tolist()

In [17]:
from transformers import AutoModel, AutoTokenizer
from typing import List

class LinearClassificationModule(torch.nn.Module):
    """Classification head consisting of a single linear layer."""

    def __init__(self, input_size, output_size):
        """Build the head.

        Args:
            input_size: Number of features in each input vector.
            output_size: Number of values produced for each input vector.
        """
        super().__init__()
        self.net = torch.nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        scores = self.net(x)
        return scores

class TwoLayerClassificationModule(torch.nn.Module):
    """Classification head: linear layer -> SiLU activation -> linear layer."""

    def __init__(self, input_size, hidden_size, output_size):
        """Build the head.

        Args:
            input_size: Number of features in each input vector.
            hidden_size: Width of the intermediate (hidden) layer.
            output_size: Number of values produced for each input vector.
        """
        super().__init__()

        # Layers are created in forward order so that parameter initialisation
        # consumes the random generator in the same sequence as the layers run.
        hidden_layer = torch.nn.Linear(input_size, hidden_size)
        activation = torch.nn.SiLU()
        output_layer = torch.nn.Linear(hidden_size, output_size)

        self.net = torch.nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Run ``x`` through the two-layer network and return its output."""
        scores = self.net(x)
        return scores


class IdentityDNAFeatureTransformer(torch.nn.Module):
    """Feature transformer that performs no transformation at all."""

    def __init__(self):
        """Create the transformer; it holds no parameters or sub-modules."""
        super().__init__()

    def forward(self, x):
        """Return the input object itself, untouched."""
        passthrough = x
        return passthrough


class Net(torch.nn.Module):
    """DNA sequence classifier: pretrained encoder -> feature transformer -> classification head.

    Raw sequence strings are tokenized and encoded by a pretrained model, the
    per-token hidden states are mean-pooled into one vector per sequence, that
    vector is passed through ``dna_feature_transformer`` and finally through
    ``classification_module``.
    """

    def __init__(self, base_model="zhihan1996/DNABERT-S", dna_feature_transformer=None, classification_module=None):
        """Build the network.

        Args:
            base_model: Name or path handed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
            dna_feature_transformer: Optional module applied to the pooled
                embeddings. Defaults to ``IdentityDNAFeatureTransformer``.
            classification_module: Module mapping the transformed embeddings to
                predictions. Required.

        Raises:
            ValueError: If ``classification_module`` is ``None``.
        """
        super(Net, self).__init__()

        # Validate first so a misconfigured call fails immediately instead of
        # only after the pretrained tokenizer and model have been loaded.
        if classification_module is None:
            raise ValueError("classification_module must be provided! ")

        # NOTE(review): trust_remote_code=True lets the model repository run its
        # own Python code on load; only use with base models you trust.
        self.dna_tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        self.dna_encoder_model = AutoModel.from_pretrained(base_model, trust_remote_code=True)

        if dna_feature_transformer is not None:
            self.dna_feature_transformer = dna_feature_transformer
        else:
            self.dna_feature_transformer = IdentityDNAFeatureTransformer()

        self.classification_module = classification_module

    def encode_dna(self, dna_sequences: List[str]):
        """Encode sequences into one mean-pooled embedding per sequence.

        Args:
            dna_sequences: Batch of sequence strings to tokenize and encode.

        Returns:
            Tensor with one pooled embedding row per input sequence.
        """
        encoded = self.dna_tokenizer(dna_sequences, return_tensors="pt", padding=True)

        # The tokenizer produces CPU tensors; move them next to the encoder
        # weights so the model also works after ``net.to(device)``.
        first_param = next(self.dna_encoder_model.parameters(), None)
        if first_param is None:
            model_inputs = dict(encoded.items())
        else:
            model_inputs = {
                name: tensor.to(first_param.device) for name, tensor in encoded.items()
            }

        outputs = self.dna_encoder_model(**model_inputs)
        # Assumes outputs[0] holds the per-token hidden states laid out as
        # (batch, tokens, hidden) — pooling over dim=1 relies on that layout.
        hidden_states = outputs[0]

        attention_mask = model_inputs.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to a plain mean over all positions.
            return hidden_states.mean(dim=1)

        # Because of padding=True, shorter sequences carry padding positions.
        # Weight by the attention mask so padding does not dilute the mean and
        # an embedding does not depend on the longest sequence in the batch.
        weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        token_counts = weights.sum(dim=1).clamp(min=1)  # guard against division by zero
        embeddings = (hidden_states * weights).sum(dim=1) / token_counts
        return embeddings

    def forward(self, sequences):
        """Return predictions for a batch of sequence strings.

        Args:
            sequences: Batch of sequence strings, forwarded to ``encode_dna``.

        Returns:
            Output of ``classification_module`` on the transformed embeddings.
        """
        embeddings = self.encode_dna(sequences)
        transformed_embeddings = self.dna_feature_transformer(embeddings)
        predictions = self.classification_module(transformed_embeddings)

        return predictions


In [19]:
# Two candidate heads. Only the two-layer head is wired into `net` below; the
# single-layer `cls_module` is built but not used by it.
# NOTE(review): 768 is presumably the encoder's embedding width and 4 the
# number of target classes — confirm against the base model and the labels.
cls_module = LinearClassificationModule(input_size=768, output_size=4)
two_layer_cls_module = TwoLayerClassificationModule(
    input_size=768,
    hidden_size=256,
    output_size=4,
)
net = Net(classification_module=two_layer_cls_module)



In [20]:
# Smoke test: run the model on the first ten sequences and display the result.
# The displayed tensor has one row per sequence and carries a grad_fn, i.e.
# this call runs with autograd tracking enabled.
net(inputs[:10])

tensor([[ 0.0404, -0.0223, -0.0001,  0.0614],
        [ 0.0256, -0.0333,  0.0324,  0.0752],
        [ 0.0402, -0.0365,  0.0054,  0.0722],
        [ 0.0440, -0.0327,  0.0202,  0.0353],
        [ 0.0515, -0.0527,  0.0188,  0.0783],
        [ 0.0470, -0.0433,  0.0090,  0.0420],
        [ 0.0185, -0.0294,  0.0264,  0.0587],
        [ 0.0360, -0.0324,  0.0335,  0.0836],
        [ 0.0328, -0.0277,  0.0102,  0.0621],
        [ 0.0632, -0.0174,  0.0143,  0.0607]], grad_fn=<AddmmBackward0>)