<a href="https://colab.research.google.com/github/shuaishuaigu/Igrna-abe/blob/main/SCG-GELP-Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**The SCG-GELP strategy consists of the following steps:**

- Setting parameters
- SCG-Transformer model to generate synonymous codon sequences
- Biological features and transfer learning (DNABERT-2) feature extraction
- Prediction of highly soluble expressed gene sequences using GELP models (SVM, LR, MLP, CNN-N-NF)
- Screening results


### Colab environment settings

**Note: Set the Colab runtime hardware accelerator to T4 GPU (Runtime → Change runtime type).**

In [2]:
# Print the versions of the GPU/ML packages currently installed in the runtime,
# so they can be compared with the triton version pinned further down.
import triton
print(triton.__version__)

import transformers
print(transformers.__version__)

import torch
print(torch.__version__)

2.0.0
4.48.3
2.0.1+cu117


In [2]:
# Remove the preinstalled triton so the pinned version can be installed below.
# `%pip` (rather than `!pip`) guarantees the command targets the running kernel's environment.
%pip uninstall --yes triton

Found existing installation: triton 3.1.0
Uninstalling triton-3.1.0:
  Successfully uninstalled triton-3.1.0


In [4]:
# Restart the session after the installation is complete
# `%pip` (rather than `!pip`) guarantees the package lands in the running kernel's environment.
%pip install triton==2.0.0

Collecting triton==2.0.0
  Downloading triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.0 kB)
Collecting lit (from triton==2.0.0)
  Downloading lit-18.1.8-py3-none-any.whl.metadata (2.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->triton==2.0.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->triton==2.0.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->triton==2.0.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->triton==2.0.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->triton==2.0.0)
  Downloading nvidia_cublas_cu1

In [None]:
# Re-check the package versions after the triton reinstall (run this once the
# session has been restarted, as requested in the install cell above).
import triton
print(triton.__version__)

import transformers
print(transformers.__version__)

2.0.0
4.41.2


In [3]:
# `%pip` (rather than `!pip`) guarantees the package lands in the running kernel's environment.
%pip install einops



In [4]:
# Clone the SCG-GELP resources to the absolute path the later cells read from
# (`resource_dir`); skip the clone when the directory already exists so the
# cell can be re-run without failing.
!test -d /content/SCG-GELP || git clone https://github.com/yuddecho/SCG-GELP.git /content/SCG-GELP

Cloning into 'SCG-GELP'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 31 (delta 0), reused 3 (delta 0), pack-reused 27 (from 1)[K
Receiving objects: 100% (31/31), 569.85 MiB | 21.09 MiB/s, done.
Updating files: 100% (24/24), done.


In [5]:
# Clone the fine-tuned DNABERT-2 checkpoint to the absolute path the later cells
# read from (`DNABERT_2_checkpoint`); skip the clone when the directory already
# exists so the cell can be re-run without failing.
!test -d /content/DNABERT-2-Finetune-1400 || git clone https://github.com/yuddecho/DNABERT-2-Finetune-1400.git /content/DNABERT-2-Finetune-1400

Cloning into 'DNABERT-2-Finetune-1400'...
remote: Enumerating objects: 43, done.[K
remote: Counting objects: 100% (14/14), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 43 (delta 3), reused 11 (delta 2), pack-reused 29 (from 1)[K
Receiving objects: 100% (43/43), 414.29 MiB | 35.51 MiB/s, done.
Resolving deltas: 100% (4/4), done.
Updating files: 100% (28/28), done.


In [39]:
# merge model weight files
# The large weight files are stored in the repositories as numbered shards; the
# three lists below are parallel (same index = same file) and are zipped by the
# loop at the end of this cell.
resource_dir = '/content/SCG-GELP'
DNABERT_2_checkpoint = '/content/DNABERT-2-Finetune-1400'

# Directory that contains each sharded weight file.
model_weight_source_dir = [DNABERT_2_checkpoint, resource_dir, resource_dir]
# Number of shards each weight file was split into.
source_num_files = [10, 10, 4]
# Name of the merged file to produce for each entry.
source_big_bin_files = ['pytorch_model.bin', 'nesg-cnn-natural-number-512-512-512-1-16-2-False.pt', 'tf_0.0001_512_8_3_3_512_0.1_1024_6.pt']

def merge_pt(model_path, target_file, num_files):
    """
    Re-assemble a weight file that was split into numbered shards.

    For a ``*.pt`` target the shards are read from
    ``{model_path}/{stem}/{stem}_{i}.pt``; for any other target (the last four
    characters are treated as the extension, e.g. ``.bin``) they are read from
    ``{model_path}/pytorch_model_files/{stem}_{i}{ext}``. Shards ``0 .. num_files-1``
    are concatenated in order into ``{model_path}/{stem}{ext}``.

    Args:
        model_path: directory that holds the shard folder and receives the merged file.
        target_file: file name of the merged weight file.
        num_files: number of shards to concatenate.

    Raises:
        FileNotFoundError: if any shard is missing (checked before anything is written,
            so a failed merge never leaves a truncated weight file behind).
    """
    # Local imports: this cell runs before the notebook's main import cell.
    import os
    import shutil

    print(model_path, target_file, num_files)

    if target_file[-3:] == '.pt':
        tag = target_file[-3:]
        file_name = target_file[:-3]
        model_file_dir = file_name
    else:
        tag = target_file[-4:]
        file_name = target_file[:-4]
        model_file_dir = 'pytorch_model_files'

    part_paths = [f'{model_path}/{model_file_dir}/{file_name}_{i}{tag}' for i in range(num_files)]

    missing = [part_path for part_path in part_paths if not os.path.isfile(part_path)]
    if missing:
        raise FileNotFoundError(f'Missing weight shard(s): {missing}')

    # Stream shard by shard instead of accumulating one huge bytes object:
    # repeated `bytes +=` copies everything read so far on every iteration.
    merged_path = f'{model_path}/{file_name}{tag}'
    with open(merged_path, 'wb') as outfile:
        for part_path in part_paths:
            with open(part_path, 'rb') as infile:
                shutil.copyfileobj(infile, outfile)

    print(merged_path)

# Rebuild every sharded weight file listed in the parallel configuration lists above.
for path, file, nums in zip(model_weight_source_dir, source_big_bin_files, source_num_files):
    merge_pt(path, file, nums)

/content/DNABERT-2-Finetune-1400 pytorch_model.bin 10
/content/DNABERT-2-Finetune-1400/pytorch_model.bin
/content/SCG-GELP nesg-cnn-natural-number-512-512-512-1-16-2-False.pt 10
/content/SCG-GELP/nesg-cnn-natural-number-512-512-512-1-16-2-False.pt
/content/SCG-GELP tf_0.0001_512_8_3_3_512_0.1_1024_6.pt 4
/content/SCG-GELP/tf_0.0001_512_8_3_3_512_0.1_1024_6.pt


## Setting parameters

In [106]:
# for SCG
import os
import math
import random
import datetime

import numpy as np
import collections
import json
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import log_softmax

from torch import Tensor

from timeit import default_timer as timer
from tqdm.notebook import tqdm

# for biological features
import csv
from collections import Counter
import itertools
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA

# for DNABERT-2
from torch.utils.data import Dataset
import csv
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

In [132]:
# # Get GitHub project resources
# !git clone git@github.com:yuddecho/SCG-MELP.git

# resource directory
# Alternative relative locations, kept commented out for reference:
# resource_dir =  './resource'
# DNABERT_2_checkpoint = './resource/DNABERT-2-Finetune-1400'
# Absolute locations of the two cloned repositories (same values as in the
# weight-merging cell).
resource_dir = '/content/SCG-GELP'
DNABERT_2_checkpoint = '/content/DNABERT-2-Finetune-1400'

# Args
# protein_name is used as a prefix for sequence names and result files.
protein_name = 'SSS'
# Amino-acid sequence handed to the SCG-Transformer; '*' denotes the stop codon.
# NOTE(review): this is a 12-residue sequence, while a much longer one is kept
# commented out on the next line — presumably a short demo input; confirm.
protein_seq = 'MRIINVKDYEEM*'
# protein_seq = 'MRIINVKDYEEMSRKAADLIAAQIILNPKSSLGLAHGKSPIGFYERLVELNRNGVIDFSHVTTINLDEYYGLDPTHDQSYRYFMNKHLFSRVNINMANTHLPDGKAKDIDAECIRYDDLIESVGGIDLQLLGIGHNGHIFFNEPSDEFIPGTHCVSLSQSTINANSLMYFSRDEVQRKAITMGIKAIMQARYVQLIASGEDQIEILKKALFGPITPQVPASILQLHKDLTVITPLDI*'
# dna_wt = 'ATGCGTATTATTAACGTTAAAGACTACGAAGAAATG'
dna_wt = 'ATGCGTATCATTAACGTGAAAGACTACGAGGAAATGagcCGTAAGGCGGCGGATCTGATTGCGGCGCAGATCATTCTGAACCCGAAAAGCAGTCTGGGTCTGGCGCatGGCaaaAGCCCGATTGGTTTTTATGAGCGTCTGGTTGAACTGAACCGTAACGGCGTGATCGACTTCAGCCACGTTACCACCATTAACCTGGATGAGTACTATGGTCTGGACCCGACCCACGATCAGAGCTACCGTTATTTCATGAACAAGCACCTGTTTAGCCGTGTGAACATCAACATGGCGAACACCCACCTGCCGGATGGCaaGGCGAAAGACATTGATGCGGAGTGCATTCGTTACGACGATCTGATCGAAAGCGTTGGTGGCATTGACCTGCAACTGCTGGGTATCGGCCACAACGGTCACATTtttTTCAACGAGCCGAGCGATGAATTTATCCCGGGTACCCACTGCGTTAGCCTGAGCCAAAGCACCATTAACGCGAACAGCCTGATGTATTTTAGCCGTGACGAAGTGcaaCGTAAGGCGATCACCATGGGCATCAAAGCGATTATGCAAGCGCGTTATGTTCAGCTGATCGCGAGCGGCGAGGATCAAATTGAAATTCTGAAGAAAGCGCTGTTTGGTCCGATCACCCCGCAGGTGCCGGCGAGCATTCTGCAACTGCACAAGGACCTGACCGTTATCACCCCGCTGGATATTTGA'
dna_gs = 'ATGCGTATCATTAACGTGAAAGACTACGAGGAAATGAGCCGTAAGGCGGCGGATCTGATTGCGGCGCAGATCATTCTGAACCCGAAAAGCGTGCTGGGTCTGGCGACCGGCAGCAGCCCGATTGGTACCTATGAGCGTCTGGTTGAACTGAACCGTAACGGCGTGATCGACTTCAGCCACGTTACCACCATTAACCTGGATGAGTACTATGGTCTGGACCCGACCCACGATCAGAGCTACCGTTATTTCATGAACAAGCACCTGTTTAGCCGTGTGAACATCAACATGGCGAACACCCACCTGCCGGATGGCAAGGCGAAAGACATTGATGCGGAGTGCCGTCGTTACGACGATCTGATCGAAAGCGTTGGTGGCATTGACCTGCAACTGCTGGGTATCGGCCACAACGGTCACATTGGCTTCAACGAGCCGAGCGATGAATTTATCCCGGGTACCCACTGCGTTAGCCTGAGCGAGAGCACCATTAACGCGAACAGCCGTTTCTTTAAAAGCCGTGACGAAGTGCCGCGTAAGGCGATCACCATGGGCATCAAAGCGATTATGCAAGCGCGTAAGGTTCTGCTGATCGCGAGCGGCGAGGATAAGAAAGAAATTCTGAAGAAAGCGCTGTTTGGTCCGATCACCCCGCAGGTGCCGGCGAGCATTCTGCAACTGCACAAGGACCTGACCGTTATCACCCCGCTGGATATTTGA'
# dna_gs = 'ATGCGTATTATTAACGTTAAAGACTACGAAGAAATG'

# The sequences above may contain lower-case bases; normalise them to upper case.
dna_wt = dna_wt.upper()
dna_gs = dna_gs.upper()

# Reference sequences keyed as '<protein_name>-WT' and '<protein_name>-GS'.
dnas = {f'{protein_name}-WT': dna_wt, f'{protein_name}-GS': dna_gs}

# Switches for the individual pipeline stages; set a flag to False to turn that
# stage off (presumably read by later cells — the consumers are not shown here).
exec_func = {
    'Exec SCG Model': True,
    'Exec DNA Feature': True,
    'Exec DNABERT-2 Model': True,
    'Exec Sklearn model': True,
    'Exec CNN-N-NF model': True
}

# Beam-search schedule: the three lists are parallel, one entry per beam-search
# run (decoder batch size, top-k candidates per beam, beams kept per step).
beam_batch_sizes = [4]
beam_widths = [5]
beam_sizes = [8]
# beam_batch_sizes, beam_widths = [1, 2, 4, 8, 16, 32], [2, 2, 3, 4, 5, 5]
# beam_sizes = [val * 20 for val in beam_batch_sizes]

# Sorting by the predictions of the first model, [0, 3]
model_id = 0

# Directory that collects all result files; created on first use.
res_dir = 'scg-gelp-res'
if not os.path.exists(res_dir):
    os.makedirs(res_dir)

# Smoke-test switch for the SCG-Transformer test cell; off by default.
is_test = False

## SCG-Transformer model to generate synonymous codon sequences

In [133]:
def seed_everything(seed=42):
    """
    Set random seeds for Python, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes);
    # the hash seed of the already-running kernel cannot be changed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different algorithms from run to run; turn it
    # off so that the deterministic flag above is actually effective.
    torch.backends.cudnn.benchmark = False

seed_everything()

In [134]:
class Vocab:
    """
    Dictionary of amino acid and nucleotide sequence encoding.

    Maps tokens (single amino-acid letters or 3-nucleotide codons) to integer ids
    and back. Special tokens are inserted first, in the order given, so with the
    defaults: UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3.

    Note on key types: a vocabulary built with `init` stores integer keys in
    `id2token`, whereas one restored with `load` stores string keys (JSON object
    keys are always strings). `ids_to_tokens` accepts both layouts.
    """

    def __init__(self):
        """Create an empty vocabulary; fill it with `init` or `load`."""
        self.token2id = {}
        self.id2token = {}
        self.count = 0  # For assigning unique identifiers
        self.unk_token = '<unk>'

    def init(self, seqs, token_type, special_tokens=('<unk>', '<pad>', '<bos>', '<eos>')):
        """
        Rebuild the vocabulary from raw sequences.

        Args:
            seqs: iterable of sequence strings.
            token_type: 'dna' (split into consecutive 3-character codons) or
                'protein' (split into single characters).
            special_tokens: tokens that receive the first ids, in order.

        Raises:
            ValueError: if `token_type` is neither 'dna' nor 'protein'.
        """
        # Tokenise first, so an invalid token_type leaves the current state untouched.
        if token_type == 'dna':
            tokens = [seq[i: i + 3] for seq in seqs for i in range(0, len(seq), 3)]
        elif token_type == 'protein':
            tokens = [token for seq in seqs for token in seq]
        else:
            raise ValueError(f"token_type must be 'dna' or 'protein', got {token_type!r}")

        self.token2id = {}
        self.id2token = {}
        self.count = 0

        # Special tokens get the lowest ids.
        if special_tokens:
            for token in special_tokens:
                self.add_token(token)

        # Remaining tokens are numbered by decreasing frequency
        # (ties keep first-seen order because the sort is stable).
        counter = collections.Counter(tokens)
        token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        for token, _ in token_freqs:
            self.add_token(token)

    def add_token(self, token):
        """Register `token` under the next free id; known tokens are left unchanged."""
        if token not in self.token2id:
            self.token2id[token] = self.count
            self.id2token[self.count] = token
            self.count += 1

    def tokens_to_ids(self, tokens):
        """Return the id of every token; unknown tokens map to the id of `unk_token` (0 if absent)."""
        unk_id = self.token2id.get(self.unk_token, 0)
        return [self.token2id.get(token, unk_id) for token in tokens]

    def ids_to_tokens(self, ids):
        """Return the token of every id; unknown ids map to `unk_token`."""
        tokens = []
        for _id in ids:
            # Try the id as given (int keys after `init`), then its string form
            # (string keys after `load`).
            token = self.id2token.get(_id)
            if token is None:
                token = self.id2token.get(str(_id), self.unk_token)
            tokens.append(token)
        return tokens

    def save(self, vocab_path):
        """Write both mappings to `vocab_path` as JSON (integer keys become strings)."""
        res = {
            'token2id': self.token2id,
            'id2token': self.id2token,
        }

        with open(vocab_path, 'w', encoding='utf-8') as json_file:
            json.dump(res, json_file)

    def load(self, vocab_path):
        """Replace the vocabulary with the mappings stored in the JSON file `vocab_path`."""
        with open(vocab_path, 'r') as json_file:
            loaded_data = json.load(json_file)

        self.token2id = loaded_data['token2id']
        self.id2token = loaded_data['id2token']
        # Continue numbering after the highest loaded id so that a later
        # `add_token` cannot hand out an id that is already in use.
        self.count = max(self.token2id.values(), default=-1) + 1

    def __len__(self):
        return len(self.token2id)

    def __str__(self):
        return str(list(self.token2id.items()))

In [135]:
def get_vocab(protein_vocab_path, dna_vocab_path, log_info=True):
    """
    Load the protein and DNA vocabularies from their JSON files.

    Args:
        protein_vocab_path: path of the protein vocabulary file.
        dna_vocab_path: path of the DNA vocabulary file.
        log_info: when true, print both paths, the vocabulary sizes and contents.

    Returns:
        Tuple `(protein_vocab, dna_vocab)`.

    Raises:
        ValueError: if either file does not exist.
    """
    for vocab_path in (protein_vocab_path, dna_vocab_path):
        if not os.path.exists(vocab_path):
            raise ValueError('No file')

    protein_vocab = Vocab()
    dna_vocab = Vocab()
    protein_vocab.load(protein_vocab_path)
    dna_vocab.load(dna_vocab_path)

    if not log_info:
        return protein_vocab, dna_vocab

    print(protein_vocab_path)
    print(dna_vocab_path)

    print(f'Vocab: {len(protein_vocab)}, {len(dna_vocab)}')
    print(protein_vocab, dna_vocab)

    return protein_vocab, dna_vocab

In [136]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding followed by dropout.

    The table is precomputed for `max_len` positions and stored in the buffer
    `pe` with shape (1, max_len, d_model); being a buffer, it is saved with the
    model but not trained.
    """

    def __init__(self, d_model, dropout, max_len=5000, device='cpu'):
        """
        Args:
            d_model: embedding dimension (even or odd).
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions to precompute.
            device: device on which the table is allocated.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Initialize shape to pe (positional encoding) of (max_len, d_model)
        pe = torch.zeros(max_len, d_model).to(device)

        # Column vector of positions [[0], [1], [2], ...]
        position = torch.arange(0, max_len).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed through exp/log
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )

        angles = position * div_term

        # PE(pos, 2i)
        pe[:, 0::2] = torch.sin(angles)

        # PE(pos, 2i+1); for an odd d_model there is one odd column fewer than
        # even columns, so the last frequency is dropped for the cosine part.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a leading batch dimension so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)

        # Not a trainable parameter, but it should be saved with the model.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x is the inputs after embedding, e.g. (1,7, 128), batch size is 1, 7 words, word dimension is 128
        """
        # Add the encodings of the first x.size(1) positions.
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [137]:
class Seq2SeqTransformer(nn.Module):
    """
    SCG-Transformer Network

    Encoder-decoder transformer: source and target tokens are embedded, combined
    with a shared positional encoding, passed through `nn.Transformer`
    (batch-first layout) and projected onto the target vocabulary.
    """

    def __init__(self,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 emb_size: int,
                 nhead: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,
                 max_length: int = 5000,
                 pad_idx: int = 1):
        super().__init__()

        # Token embeddings; the row at `pad_idx` is the padding embedding.
        self.src_embedding = nn.Embedding(src_vocab_size, emb_size, padding_idx=pad_idx)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, emb_size, padding_idx=pad_idx)

        # One positional encoding shared by the source and the target side.
        self.positional_encoding = PositionalEncoding(emb_size, dropout, max_len=max_length)

        # Transformer core.
        transformer_config = dict(d_model=emb_size,
                                  nhead=nhead,
                                  num_encoder_layers=num_encoder_layers,
                                  num_decoder_layers=num_decoder_layers,
                                  dim_feedforward=dim_feedforward,
                                  dropout=dropout,
                                  batch_first=True)
        self.transformer = nn.Transformer(**transformer_config)

        # Projection from the hidden size to target-vocabulary scores.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def _embed_src(self, src: Tensor):
        """Embed source ids and add positional information."""
        return self.positional_encoding(self.src_embedding(src))

    def _embed_tgt(self, tgt: Tensor):
        """Embed target ids and add positional information."""
        return self.positional_encoding(self.tgt_embedding(tgt))

    def forward(self,
                src: Tensor,
                tgt: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary scores."""
        # Embed the source first, then the target (same order as the layers are used).
        src_emb = self._embed_src(src)
        tgt_emb = self._embed_tgt(tgt)

        hidden = self.transformer(src_emb, tgt_emb,
                                  src_mask=src_mask,
                                  tgt_mask=tgt_mask,
                                  memory_mask=None,
                                  src_key_padding_mask=src_padding_mask,
                                  tgt_key_padding_mask=tgt_padding_mask,
                                  memory_key_padding_mask=memory_key_padding_mask)

        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Return the encoder output (memory) for the source ids."""
        return self.transformer.encoder(self._embed_src(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Return the decoder hidden states for the target ids given the encoder memory."""
        return self.transformer.decoder(self._embed_tgt(tgt), memory, tgt_mask)

In [138]:
def generate_square_subsequent_mask(sz, device):
    """
    Build the (sz, sz) additive causal mask for the decoder.

    Entry [i, j] is 0.0 where position i may attend to position j (j <= i) and
    -inf where it may not (j > i).
    """
    visible = torch.tril(torch.ones((sz, sz), device=device)).bool()
    causal_mask = torch.zeros((sz, sz), device=device).masked_fill(~visible, float('-inf'))
    return causal_mask


def create_mask(src, tgt, pad_idx, device):
    """
    Build the attention and padding masks for a source/target batch.

    Returns:
        Tuple `(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`:
        an all-False boolean source mask (no position is hidden), the causal
        target mask, and boolean masks that are True where `src` / `tgt`
        equal `pad_idx`.
    """
    src_len, tgt_len = src.shape[-1], tgt.shape[-1]

    causal_tgt_mask = generate_square_subsequent_mask(tgt_len, device)
    open_src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)

    return open_src_mask, causal_tgt_mask, src == pad_idx, tgt == pad_idx


In [139]:
# Codon table: maps each of the 64 DNA codons (upper-case, alphabet ACGT) to the
# one-letter code of the amino acid it encodes; the three stop codons
# (TAA, TAG, TGA) map to '*'.
codon_table = {
    'ATA': 'I', 'ATC': 'I', 'ATT': 'I', 'ATG': 'M',
    'ACA': 'T', 'ACC': 'T', 'ACG': 'T', 'ACT': 'T',
    'AAC': 'N', 'AAT': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGC': 'S', 'AGT': 'S', 'AGA': 'R', 'AGG': 'R',
    'CTA': 'L', 'CTC': 'L', 'CTG': 'L', 'CTT': 'L',
    'CCA': 'P', 'CCC': 'P', 'CCG': 'P', 'CCT': 'P',
    'CAC': 'H', 'CAT': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGA': 'R', 'CGC': 'R', 'CGG': 'R', 'CGT': 'R',
    'GTA': 'V', 'GTC': 'V', 'GTG': 'V', 'GTT': 'V',
    'GCA': 'A', 'GCC': 'A', 'GCG': 'A', 'GCT': 'A',
    'GAC': 'D', 'GAT': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGA': 'G', 'GGC': 'G', 'GGG': 'G', 'GGT': 'G',
    'TCA': 'S', 'TCC': 'S', 'TCG': 'S', 'TCT': 'S',
    'TTC': 'F', 'TTT': 'F', 'TTA': 'L', 'TTG': 'L',
    'TAC': 'Y', 'TAT': 'Y', 'TAA': '*', 'TAG': '*',
    'TGC': 'C', 'TGT': 'C', 'TGA': '*', 'TGG': 'W',
}

def translate_dna_to_protein(dna_sequence):
    """
    Translate a DNA sequence into a protein sequence, codon by codon.

    The sequence is read in consecutive 3-character steps from position 0 and
    matched case-insensitively against `codon_table`. Anything that is not in
    the table (including a trailing incomplete codon) is written as 'X'.
    """
    codons = (dna_sequence[start:start + 3].upper()
              for start in range(0, len(dna_sequence), 3))
    return ''.join(codon_table.get(codon, 'X') for codon in codons)

In [140]:
def translate_single_beam_search_decode(model, src, src_vocab, tgt_vocab, batch_size, beam_width, beam_size, device):
    """
    Generate codon sequences for one protein with a constrained beam search.

    At decoding step i only codons that translate (via `codon_table`) to the
    i-th source token are accepted, and `<eos>` is only accepted where the
    source has `<eos>`, so every returned sequence encodes `src`.

    Args:
        model: network exposing `encode(src, src_mask)`, `decode(tgt, memory, tgt_mask)`
            and `generator(hidden)`.
        src: protein sequence (iterable of single-character tokens).
        src_vocab: vocabulary for the protein side (`tokens_to_ids`).
        tgt_vocab: vocabulary for the codon side (`token2id`, `ids_to_tokens`).
        batch_size: number of beams decoded per forward pass.
        beam_width: number of top-scoring candidates examined per beam and step.
        beam_size: maximum number of beams kept after each step.
        device: torch device used for all tensors.

    Returns:
        List of decoded codon strings with `<bos>`/`<eos>` removed; empty when no
        candidate satisfied the constraint at some step.
    """
    special_tokens = ('<unk>', '<bos>', '<eos>', '<pad>')

    model.eval()
    model = model.to(device)

    src_tokens = ['<bos>'] + list(src) + ['<eos>']
    src_len = len(src_tokens)

    # Pure inference: without no_grad every decoding step would keep its autograd graph alive.
    with torch.no_grad():
        src = torch.tensor(src_vocab.tokens_to_ids(src_tokens)).unsqueeze(0).to(device)

        # All-False mask: the encoder may attend to every source position.
        src_mask = torch.zeros((src_len, src_len), device=device).type(torch.bool)

        memory = model.encode(src, src_mask)

        tgt_bos, tgt_eos = tgt_vocab.token2id['<bos>'], tgt_vocab.token2id['<eos>']
        tgt = torch.ones(1, 1).fill_(tgt_bos).type(torch.long).to(device)

        # Each beam is (token ids so far with shape (1, length), cumulative score, finished flag).
        beam_queue = [(tgt, 0, False)]

        for i in tqdm(range(1, src_len)):
            next_beam_queue = []
            # Token-id tuples already queued in this step. Tensors hash by identity,
            # so the ids themselves are used as the de-duplication key.
            seen_tgt = set()

            for bq_start_i in range(0, len(beam_queue), batch_size):
                # Slicing past the end is safe, so no clamping of the end index is needed.
                sub_beam_queue = beam_queue[bq_start_i: bq_start_i + batch_size]

                tgt_bs = torch.cat([beam[0] for beam in sub_beam_queue], dim=0).to(device)
                score_bs = [beam[1] for beam in sub_beam_queue]

                # Boolean causal mask (True = blocked), sized from the batch itself.
                tgt_mask = (generate_square_subsequent_mask(tgt_bs.size(-1), device)
                            .type(torch.bool))

                memory_bs = memory.repeat(len(sub_beam_queue), 1, 1).to(device)

                out = model.decode(tgt_bs, memory_bs, tgt_mask)

                # Only the last position is needed to predict the next token.
                probs = model.generator(out[:, -1])

                # Never ask top-k for more candidates than the vocabulary holds.
                top_k = min(beam_width, probs.size(-1))

                for tgt, prob, score in zip(tgt_bs, probs, score_bs):
                    topk_probs, topk_indices = torch.topk(prob, k=top_k, dim=-1)

                    for k in range(top_k):
                        end_char_id = topk_indices[k].item()

                        # candidate id -> codon -> amino acid
                        if end_char_id != tgt_eos:
                            nucle_char = tgt_vocab.ids_to_tokens([end_char_id])[0]
                            if nucle_char in special_tokens:
                                acid_char = nucle_char
                            else:
                                # Tokens that are not valid codons yield None and are rejected below.
                                acid_char = codon_table.get(nucle_char)
                            next_state = False
                        else:
                            acid_char = '<eos>'
                            next_state = True

                        # Reject candidates that do not encode the expected source token.
                        if acid_char != src_tokens[i]:
                            continue

                        next_tgt = torch.cat([tgt, topk_indices[k].unsqueeze(0)], dim=-1)
                        next_tgt = next_tgt.unsqueeze(0)

                        tgt_key = tuple(next_tgt[0].tolist())
                        if tgt_key in seen_tgt:
                            continue
                        seen_tgt.add(tgt_key)

                        # NOTE(review): the generator output is accumulated as-is; if it yields
                        # unnormalised logits (as a plain linear layer would) the score is a sum
                        # of logits rather than of log-probabilities — confirm this is intended.
                        next_score = score + topk_probs[k].item()

                        # Keep only the `beam_size` best candidates.
                        next_beam_queue.append((next_tgt, next_score, next_state))
                        if len(next_beam_queue) > beam_size:
                            next_beam_queue = sorted(next_beam_queue, key=lambda x: x[1], reverse=True)
                            next_beam_queue = next_beam_queue[:beam_size]

            # Update the current beam queue
            beam_queue = next_beam_queue

            # Stop once every beam has emitted <eos> or reached the source length.
            if all(state or tgt.size(-1) >= src_len for (tgt, _score, state) in beam_queue):
                break

    res = []
    for (tgt, _score, _state) in beam_queue:
        # ids -> tokens -> string, without the start/stop markers
        tokens = tgt_vocab.ids_to_tokens(tgt.tolist()[0])
        res.append(''.join(tokens).replace("<bos>", "").replace("<eos>", ""))

    return res

In [141]:
def SCG_Transformer_predict(protein_seq, beam_batch_sizes, beam_widths, beam_sizes):
    """
    Generate synonymous codon sequences for `protein_seq` with the trained SCG-Transformer.

    The vocabularies and the checkpoint are read from `resource_dir`. One beam
    search is run per (batch size, width, size) triple taken position-wise from
    the three argument lists, and the results of all runs are merged.

    Args:
        protein_seq: amino-acid sequence to encode.
        beam_batch_sizes: decoder batch size of each beam-search run.
        beam_widths: top-k candidates examined per beam in each run.
        beam_sizes: number of beams kept per step in each run.

    Returns:
        List of unique DNA strings in the order they were first produced.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # 1 file
    protein_vocab_path = f'{resource_dir}/ref_vocab_protein.json'
    dna_vocab_path = f'{resource_dir}/ref_vocab_dna.json'

    ref_dataset_tag = '6'

    PAD_IDX = 1  # UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

    # Hyper-parameters: they define the network and, through `tag`, the name of
    # the checkpoint file, so they must stay in sync with the stored weights.
    lr = 0.0001
    d_model = 512
    nhead = 8
    num_encoder_layers = 3
    num_decoder_layers = 3
    dim_feedforward = 512
    dropout = 0.1
    max_length = 1024

    tag = f'tf_{lr}_{d_model}_{nhead}_{num_encoder_layers}_{num_decoder_layers}_{dim_feedforward}_{dropout}_{max_length}_{ref_dataset_tag}'

    checkpoint_path = f'{resource_dir}/{tag}.pt'
    print(f'modle: {checkpoint_path}')

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f'SCG-Transformer checkpoint not found: {checkpoint_path}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    # 2 vocab
    log_info = False
    protein_vocab, dna_vocab = get_vocab(protein_vocab_path, dna_vocab_path, log_info=log_info)

    # 4 model
    model = Seq2SeqTransformer(len(protein_vocab), len(dna_vocab),
                               emb_size=d_model, nhead=nhead,
                               num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
                               dim_feedforward=dim_feedforward, dropout=dropout,
                               max_length=max_length, pad_idx=PAD_IDX)

    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    curr_epoch = checkpoint['epoch'] + 1
    last_loss = checkpoint['loss']

    print(f'resume train: epoch {curr_epoch}, loss {last_loss}')

    # to device
    model = model.to(device)
    # Number of trainable parameters.
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    # main
    # A dict is used as an ordered set: duplicates are dropped while the order of
    # first appearance is kept, which makes the returned list reproducible.
    unique_dnas = {}

    print('beam search args:', beam_batch_sizes, beam_widths, beam_sizes)
    for beam_batch_size, beam_width, beam_size in zip(beam_batch_sizes, beam_widths, beam_sizes):
        print('args:', beam_batch_size, beam_width, beam_size)
        generated = translate_single_beam_search_decode(model, protein_seq, protein_vocab, dna_vocab, beam_batch_size, beam_width, beam_size, device)

        for dna in generated:
            unique_dnas.setdefault(dna, None)

    dnas = list(unique_dnas)

    print('SCG return dna num:', len(dnas))

    return dnas

# raise

In [142]:
# Turn on the SCG-Transformer smoke test in the next cell (the parameter cell
# sets this flag to False by default).
is_test = True

In [143]:
if is_test:
    # Test SCG-Transformer model
    # NOTE(review): this rebinds the notebook-level name `dnas`, which the
    # parameter cell set to the {name: sequence} dict of the WT/GS sequences —
    # confirm that later cells do not expect that dict after this cell has run.
    dnas = SCG_Transformer_predict(protein_seq, beam_batch_sizes, beam_widths, beam_sizes) # protein_seq
    print(dnas)

    # for DNABERT-2
    # Write the generated sequences as CSV (columns: name, nucle-seq, label);
    # every sequence is written with label -1.
    dna_file = f'{res_dir}/{protein_name}_scg_dnas_test.csv'

    with open(dna_file, 'w', encoding='utf-8') as w:
        w.write('name,nucle-seq,label\n')
        for i, dna in enumerate(dnas):
            w.write(f'scg_{i},{dna},-1\n')

#     raise

modle: /content/SCG-GELP/tf_0.0001_512_8_3_3_512_0.1_1024_6.pt
device: cuda
resume train: epoch 47, loss 0.9707116135501553
12707396
None
beam search args: [4] [5] [8]
args: 4 5 8


  0%|          | 0/14 [00:00<?, ?it/s]

SCG return dna num: 8
['ATGAGAATAATAAACGTAAAAGATTATGAAGAAATGTAA', 'ATGAGAATAATAAACGTGAAAGATTATGAGGAGATGTAA', 'ATGAGAATAATAAACGTGAAGGATTATGAGGAGATGTGA', 'ATGAGAATAATAAACGTGAAGGATTATGAGGAGATGTAA', 'ATGAGAATAATAAACGTGAAGGATTATGAGGAGATGTAG', 'ATGAGAATAATAAACGTGAAGGATTATGAAGAGATGTAA', 'ATGAGAATAATAAACGTGAAAGATTATGAGGAGATGTGA', 'ATGAGAATAATAAACGTGAAAGATTATGAAGAAATGTAA']


## Biological features and transfer learning (DNABERT-2) feature extraction

### Biological features extraction

In [None]:
#
def KmerArray(seq, k, is_sort=False):
    """
    Return all overlapping length-`k` windows of `seq`, from left to right.

    With `is_sort=True` the characters inside every window are sorted, so windows
    made of the same characters become identical. A sequence shorter than `k`
    gives an empty list.
    """
    windows = (seq[start:start + k] for start in range(len(seq) - k + 1))
    if is_sort:
        return [''.join(sorted(window)) for window in windows]
    return list(windows)


def Kmer(sequence):
    """
    3-mer composition of `sequence`.

    return 1 x 64: the relative frequency of each of the 64 trinucleotides
    (ordered as itertools.product over 'ACGT') among all overlapping 3-mers.
    """
    k = 3
    alphabet = 'ACGT'

    # The 64 possible trinucleotides, in feature order.
    trinucleotides = [''.join(combo) for combo in itertools.product(alphabet, repeat=k)]

    # Overlapping windows of the sequence and their counts.
    windows = KmerArray(sequence, k)
    tally = Counter(windows)

    # Normalise by the number of windows.
    total = float(len(windows))
    frequencies = []
    for trinucleotide in trinucleotides:
        frequencies.append(tally.get(trinucleotide, 0) / total)

    return frequencies


def RCKmer(sequence):
    """
    Reverse-complement 3-mer composition of `sequence`.

    return 1 x 32

    A 3-mer and its reverse complement are counted as one feature: every window
    whose reverse complement is alphabetically smaller is replaced by that
    reverse complement (this merges half of the 64 trinucleotides into the
    other half) before the relative frequencies are computed.
    """
    k = 3
    alphabet = 'ACGT'
    complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}

    def reverse_complement(kmer):
        # Reverse the k-mer and complement every base.
        return ''.join(complement[base] for base in reversed(kmer))

    # The 64 possible trinucleotides.
    all_kmers = [''.join(combo) for combo in itertools.product(alphabet, repeat=k)]

    # Representative of each {k-mer, reverse complement} pair: the smaller of the two.
    representatives = sorted({min(kmer, reverse_complement(kmer)) for kmer in all_kmers})

    # Lookup from the non-representative member of a pair to its representative.
    to_representative = {}
    for kmer in representatives:
        partner = reverse_complement(kmer)
        if partner != kmer:
            to_representative[partner] = kmer

    # Overlapping windows, with half of the trinucleotides replaced by their representative.
    windows = [to_representative.get(window, window) for window in KmerArray(sequence, k)]

    # Count and normalise by the number of windows.
    tally = Counter(windows)
    total = float(len(windows))
    return [tally.get(kmer, 0) / total for kmer in sorted(to_representative.values())]


def Mismatch(sequence):
    """
    Mismatch 3-mer profile of `sequence`.

    return 1 x 64: for each of the 64 trinucleotides, the number of overlapping
    windows of the sequence that differ from it in at most one position
    (raw counts, not normalised).
    """
    k = 3
    max_mismatch = 1
    alphabet = 'ACGT'

    # The 64 possible trinucleotides, each starting with a count of zero.
    trinucleotides = [''.join(combo) for combo in itertools.product(alphabet, repeat=k)]
    hits = dict.fromkeys(trinucleotides, 0)

    # Compare every window with every trinucleotide position by position.
    for window in KmerArray(sequence, k):
        for trinucleotide in trinucleotides:
            distance = sum(a != b for a, b in zip(window, trinucleotide))
            if distance <= max_mismatch:
                hits[trinucleotide] += 1

    return [hits[trinucleotide] for trinucleotide in sorted(hits)]


def CKSNAP(sequence):
    """
    Composition of k-spaced nucleotide pairs.

    return 1 x 64: for every gap of 0..3 nucleotides between the two members of
    a pair, the relative frequency of each of the 16 dinucleotides (16 x 4
    values, ordered by gap). The sequence must have at least 5 characters.
    """
    kspace = 3
    assert len(sequence) >= kspace + 2

    alphabet = 'ACGT'

    # The 16 possible nucleotide pairs, in feature order.
    pairs = [''.join(combo) for combo in itertools.product(alphabet, repeat=2)]

    features = []
    for gap in range(kspace + 1):
        n_pairs = len(sequence) - gap - 1
        tally = dict.fromkeys(pairs, 0)

        # Pair every nucleotide with the one `gap + 1` positions further right.
        for left, right in zip(sequence, sequence[gap + 1:]):
            tally[left + right] += 1

        features.extend(tally[pair] / n_pairs for pair in pairs)

    return features


def PseEIIP(sequence):
    """
    Trinucleotide frequencies weighted by electron-ion interaction
    pseudopotentials (EIIP).

    return 1 x 64: for each of the 64 trinucleotides, its relative frequency
    among all overlapping 3-mers multiplied by the sum of the EIIP values of
    its three nucleotides.
    """
    alphabet = 'ACGT'

    eiip = {
        'A': 0.1260,
        'C': 0.1340,
        'G': 0.0806,
        'T': 0.1335,
    }

    # The 64 possible trinucleotides and the EIIP sum of each.
    trinucleotides = [''.join(combo) for combo in itertools.product(alphabet, repeat=3)]
    weights = {}
    for trinucleotide in trinucleotides:
        weights[trinucleotide] = sum([eiip[base] for base in trinucleotide])

    # Overlapping windows of the sequence and their counts.
    windows = KmerArray(sequence, 3)
    tally = Counter(windows)

    # Weighted, normalised frequencies.
    total = float(len(windows))
    features = []
    for trinucleotide in trinucleotides:
        features.append(weights[trinucleotide] * tally.get(trinucleotide, 0) / total)

    return features


def PCP(sequence, kmer_n, dna_property_dict):
    """
    Physicochemical-property weighted k-mer frequencies.

    Args:
        sequence: DNA string over the alphabet 'ACGT'.
        kmer_n: k-mer length.
        dna_property_dict: maps a property name to one value per k-mer, in the
            order of itertools.product over 'ACGT'.

    Returns:
        List holding, for every property in dictionary order, the relative
        frequency of each k-mer multiplied by that property's value for it.
    """
    alphabet = 'ACGT'

    # Count how often each k-mer occurs among the overlapping windows.
    kmers = [''.join(combo) for combo in itertools.product(alphabet, repeat=kmer_n)]
    tally = dict.fromkeys(kmers, 0)
    for start in range(len(sequence) - kmer_n + 1):
        tally[sequence[start:start + kmer_n]] += 1

    total = sum(tally.values())
    frequencies = [tally[kmer] / total for kmer in kmers]

    # Multiply the frequencies with the preset values of every property.
    features = []
    for property_values in dna_property_dict.values():
        features.extend(np.multiply(frequencies, property_values))

    return features


def DPCP(sequence):
    """
    Dinucleotide physicochemical-property features of `sequence`.

    return 1 x 96(16*6): the relative frequency of each of the 16 dinucleotides
    multiplied by its value for each of the six properties listed below
    (Twist, Tilt, Roll, Shift, Slide, Rise), concatenated in that order.
    The computation itself is delegated to `PCP` with k-mer length 2.
    """
    # One value per dinucleotide for every property; each list has 16 entries.
    # presumably ordered like itertools.product('ACGT', repeat=2) (AA, AC, ..., TT),
    # which is the order PCP uses for its frequencies — verify against the data source.
    dna_property_dict = {
        'Twist': [
            0.063, 1.502, 0.783, 1.071, -1.376, 0.063, -1.664, 0.783, -0.081,
            -0.081, 0.063, 1.502, -1.233, -0.081, -1.376, 0.063
        ],
        'Tilt': [
            0.502, 0.502, 0.359, 0.215, -1.364, 1.077, -1.22, 0.359, 0.502,
            0.215, 1.077, 0.502, -2.368, 0.502, -1.364, 0.502
        ],
        'Roll': [
            0.09, 1.19, -0.28, 0.83, -1.01, -0.28, -1.38, -0.28, 0.09, 2.3,
            -0.28, 1.19, -1.38, 0.09, -1.01, 0.09
        ],
        'Shift': [
            1.587, 0.126, 0.679, -1.019, -0.861, 0.56, -0.822, 0.679, 0.126,
            -0.348, 0.56, 0.126, -2.243, 0.126, -0.861, 1.587
        ],
        'Slide': [
            0.111, 1.289, -0.241, 2.513, -0.623, -0.822, -0.287, -0.241,
            -0.394, 0.646, -0.822, 1.289, -1.511, -0.394, -0.623, 0.111
        ],
        'Rise': [
            -0.109, 1.044, -0.623, 1.171, -1.254, 0.242, -1.389, -0.623, 0.711,
            1.585, 0.242, 1.044, -1.389, 0.711, -1.254, -0.109
        ]
    }

    # Dinucleotides: k-mer length 2.
    kmer_n = 2
    return PCP(sequence, kmer_n, dna_property_dict)


def TPCP(sequence):
    """
    Trinucleotide physicochemical-property features.

    Passes the twelve trinucleotide property scales below to PCP with
    kmer_n = 3 and returns its result: a flat list of 768 values
    (12 properties x 64 trinucleotides), one 64-value block per property
    in the order the properties are listed here.
    """
    # Each list holds 64 values. PCP pairs them positionally with the
    # trinucleotides produced by itertools.product('ACGT', repeat=3),
    # i.e. AAA, AAC, AAG, AAT, ACA, ..., TTT.
    dna_property_dict = {
        'Bendability (DNAse)': [
            -2.087, -1.509, -0.506, -2.126, 0.111, -0.121, -0.121, -1.354,
            0.381, 0.304, -0.313, -1.354, 1.615, -0.737, 1.229, -2.126, 0.265,
            0.496, 1.576, 1.229, -1.856, 0.072, -0.969, -0.313, 0.111, -0.468,
            -0.969, -0.121, 0.882, 0.419, 1.576, -0.506, -0.159, 0.034, 0.419,
            -0.737, 0.766, 1.036, -0.468, 0.304, 0.265, 1.036, 0.072, -0.121,
            0.342, 0.034, 0.496, -1.509, 0.689, 0.342, 0.882, 1.615, 1.73,
            0.265, 0.111, 0.381, 1.73, 0.766, -1.856, 0.111, 0.689, -0.159,
            0.265, -2.087
        ],
        'Bendability (consensus)': [
            -2.745, -1.354, -0.257, -2.585, 0.171, 0.064, 0.064, -0.685, -0.15,
            0.92, -0.07, -0.685, 0.572, -0.391, 1.348, -2.585, -0.231, 0.786,
            0.92, 1.348, -1.14, 0.358, -0.712, -0.07, 1.0, 0.385, -0.712,
            0.064, -0.097, 0.438, 0.92, -0.257, -0.605, 0.171, 0.438, -0.391,
            0.839, 2.097, 0.385, 0.92, -0.097, 2.097, 0.358, 0.064, -0.07,
            0.171, 0.786, -1.354, -0.284, -0.07, -0.097, 0.572, 1.348, -0.097,
            1.0, -0.15, 1.348, 0.839, -1.14, 0.171, -0.284, -0.605, -0.231,
            -2.745
        ],
        'Trinucleotide GC Content': [
            -1.732, -0.577, -0.577, -1.732, -0.577, 0.577, 0.577, -0.577,
            -0.577, 0.577, 0.577, -0.577, -1.732, -0.577, -0.577, -1.732,
            -0.577, 0.577, 0.577, -0.577, 0.577, 1.732, 1.732, 0.577, 0.577,
            1.732, 1.732, 0.577, -0.577, 0.577, 0.577, -0.577, -0.577, 0.577,
            0.577, -0.577, 0.577, 1.732, 1.732, 0.577, 0.577, 1.732, 1.732,
            0.577, -0.577, 0.577, 0.577, -0.577, -1.732, -0.577, -0.577,
            -1.732, -0.577, 0.577, 0.577, -0.577, -0.577, 0.577, 0.577, -0.577,
            -1.732, -0.577, -0.577, -1.732
        ],
        'Nucleosome positioning': [
            -2.349, -0.561, 0.155, -1.991, 0.155, 0.274, 0.274, 0.453, -0.74,
            1.287, 0.274, 0.453, -0.978, 0.214, 0.87, -1.991, -0.74, 0.81,
            -0.322, 0.87, 0.274, 0.572, -0.084, 0.274, 1.645, 1.287, -0.084,
            0.274, -1.276, 0.274, -0.322, 0.155, -0.918, 0.274, 0.274, 0.214,
            0.572, 2.479, 1.287, 1.287, -0.501, 2.479, 0.572, 0.274, -0.561,
            0.274, 0.81, -0.561, -1.395, -0.561, -1.276, -0.978, 0.274, -0.501,
            1.645, -0.74, 0.274, 0.572, 0.274, 0.155, -1.395, -0.918, -0.74,
            -2.349
        ],
        'Consensus_roll': [
            -2.744, -1.363, -0.26, -2.591, 0.164, 0.071, 0.065, -0.676, -0.158,
            0.911, -0.07, -0.676, 0.584, -0.397, 1.358, -2.591, -0.226, 0.773,
            0.92, 1.358, -1.139, 0.345, -0.705, -0.07, 1.012, 0.379, -0.705,
            0.065, -0.097, 0.427, 0.92, -0.26, -0.6, 0.178, 0.427, -0.397,
            0.842, 2.089, 0.379, 0.911, -0.103, 2.089, 0.345, 0.071, -0.062,
            0.178, 0.773, -1.363, -0.275, -0.062, -0.097, 0.584, 1.348, -0.103,
            1.012, -0.158, 1.348, 0.842, -1.139, 0.164, -0.275, -0.6, -0.226,
            -2.744
        ],
        # NOTE(review): these 64 values are identical to 'Consensus_roll'
        # above -- confirm against the table they were taken from.
        'Consensus-Rigid': [
            -2.744, -1.363, -0.26, -2.591, 0.164, 0.071, 0.065, -0.676, -0.158,
            0.911, -0.07, -0.676, 0.584, -0.397, 1.358, -2.591, -0.226, 0.773,
            0.92, 1.358, -1.139, 0.345, -0.705, -0.07, 1.012, 0.379, -0.705,
            0.065, -0.097, 0.427, 0.92, -0.26, -0.6, 0.178, 0.427, -0.397,
            0.842, 2.089, 0.379, 0.911, -0.103, 2.089, 0.345, 0.071, -0.062,
            0.178, 0.773, -1.363, -0.275, -0.062, -0.097, 0.584, 1.348, -0.103,
            1.012, -0.158, 1.348, 0.842, -1.139, 0.164, -0.275, -0.6, -0.226,
            -2.744
        ],
        'Dnase I': [
            2.274, 1.105, 0.193, 2.141, -0.153, -0.078, -0.074, 0.536, 0.109,
            -0.753, 0.039, 0.536, -0.491, 0.307, -1.112, 2.141, 0.166, -0.646,
            -0.762, -1.112, 0.917, -0.3, 0.558, 0.039, -0.834, -0.326, 0.558,
            -0.074, 0.062, -0.365, -0.762, 0.193, 0.474, -0.165, -0.365, 0.307,
            -0.702, -1.687, -0.326, -0.753, 0.066, -1.687, -0.3, -0.078, 0.031,
            -0.165, -0.646, 1.105, 0.206, 0.031, 0.062, -0.491, -1.103, 0.066,
            -0.834, 0.109, 4.522, -0.702, 0.917, -0.153, 0.206, 0.474, 0.166,
            -2.615
        ],
        'Dnase I-Rigid': [
            2.118, 1.516, 0.493, 2.158, -0.123, 0.107, 0.107, 1.357, -0.389,
            -0.313, 0.3, 1.357, -1.585, 0.727, -1.215, 2.158, -0.275, -0.503,
            -1.549, -1.215, 1.876, -0.084, 0.962, 0.3, -0.123, 0.455, 0.962,
            0.107, -0.88, -0.427, -1.549, 0.493, 0.146, -0.046, -0.427, 0.727,
            -0.767, -1.029, 0.455, -0.313, -0.275, -1.029, -0.084, 0.107,
            -0.351, -0.046, -0.503, 1.516, -0.692, -0.351, -0.88, -1.585,
            -1.696, -0.275, -0.123, -0.389, -1.696, -0.767, 1.876, -0.123,
            -0.692, 0.146, -0.275, 2.118
        ],
        'MW-Daltons': [
            -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0,
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0,
            -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            -1.0, -1.0, -1.0, -1.0
        ],
        # NOTE(review): these 64 values are identical to 'MW-Daltons' above,
        # so the two blocks of the output carry the same numbers.
        'MW-kg': [
            -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0,
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0,
            -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            -1.0, -1.0, -1.0, -1.0
        ],
        'Nucleosome': [
            -2.342, -0.555, 0.169, -2.004, 0.169, 0.266, 0.266, 0.459, -0.748,
            1.28, 0.266, 0.459, -0.99, 0.217, 0.893, -2.004, -0.748, 0.797,
            -0.314, 0.893, 0.266, 0.555, -0.072, 0.266, 1.666, 1.28, -0.072,
            0.266, -1.28, 0.266, -0.314, 0.169, -0.893, 0.266, 0.266, 0.217,
            0.555, 2.487, 1.28, 1.28, -0.507, 2.487, 0.555, 0.266, -0.555,
            0.266, 0.797, -0.555, -1.376, -0.555, -1.28, -0.99, 0.266, -0.507,
            1.666, -0.748, 0.266, 0.555, 0.266, 0.169, -1.376, -0.893, -0.748,
            -2.342
        ],
        'Nucleosome-Rigid': [
            2.386, 0.548, -0.179, 2.032, -0.179, -0.275, -0.275, -0.466, 0.743,
            -1.272, -0.275, -0.466, 0.988, -0.227, -0.894, 2.032, 0.743, -0.8,
            0.304, -0.894, -0.275, -0.562, 0.062, -0.275, -1.646, -1.272,
            0.062, -0.275, 1.285, -0.275, 0.304, -0.179, 0.89, -0.275, -0.275,
            -0.227, -0.562, -2.433, -1.272, -1.272, 0.499, -2.433, -0.562,
            -0.275, 0.548, -0.275, -0.8, 0.548, 1.384, 0.548, 1.285, 0.988,
            -0.275, 0.499, -1.646, 0.743, -0.275, -0.562, -0.275, -0.179,
            1.384, 0.89, 0.743, 2.386
        ]
    }

    kmer_n = 3
    return PCP(sequence, kmer_n, dna_property_dict)


def MMI(sequence):
    """
    Mutual-information style features over order-insensitive 2-mers and 3-mers.

    Single-base, 2-mer and 3-mer counts are normalised by the number of
    windows of that size, then combined with the log-ratio terms defined
    by the nested helpers below.

    Args:
        sequence: nucleotide string. KmerArray (defined elsewhere in this
            file) is called with ``is_sort=True``; presumably it returns
            the overlapping k-mers with their letters sorted -- verify
            against its definition.

    Returns:
        List of 30 numbers: 10 values for the letter-sorted 2-mers
        (AA, AC, AG, AT, CC, CG, CT, GG, GT, TT) followed by 20 values for
        the letter-sorted 3-mers (AAA, AAC, ..., TTT), always in this
        lexicographic order.

    Raises:
        ZeroDivisionError: for sequences shorter than three characters.
    """
    NN = 'ACGT'

    # Enumerate the letter-sorted k-mers in a fixed lexicographic order.
    # The former list(set(...)) followed string-hash order, which is
    # randomised per interpreter process, so feature columns could be
    # permuted between sessions.
    nn_2_name = [
        ''.join(item) for item in itertools.combinations_with_replacement(NN, 2)
    ]
    nn_3_name = [
        ''.join(item) for item in itertools.combinations_with_replacement(NN, 3)
    ]

    nn_2_arr = KmerArray(sequence, 2, is_sort=True)
    nn_3_arr = KmerArray(sequence, 3, is_sort=True)

    # Count single bases, 2-mers and 3-mers.
    count_1 = Counter(sequence)
    count_2 = Counter(nn_2_arr)
    count_3 = Counter(nn_3_arr)

    # Make sure every expected key exists, with a zero count when unseen.
    for nn, co in zip([NN, nn_2_name, nn_3_name], [count_1, count_2, count_3]):
        for nn_ in nn:
            if nn_ not in co:
                co[nn_] = 0

    # Normalise: a sequence of length L has L - idx windows of size idx + 1.
    seq_len = len(sequence)
    for idx, c_d in enumerate([count_1, count_2, count_3]):
        for k in c_d.keys():
            c_d[k] /= float(seq_len - idx)

    # Log-ratio terms; each returns 0 when a required frequency is zero.
    def val_1(k1, k2):
        if (count_1[k1] * count_1[k2]) != 0 and count_2[k1 + k2] != 0:
            return count_2[k1 + k2] * math.log(count_2[k1 + k2] / (count_1[k1] * count_1[k2]))

        return 0

    def val_2(k1, k2):
        if count_2[k1 + k2] != 0 and count_1[k2] != 0:
            return (count_2[k1 + k2] / count_1[k2]) * math.log(count_2[k1 + k2] / count_1[k2])

        return 0

    def val_3(k1, k2, k3):
        if count_3[k1 + k2 + k3] != 0 and count_2[k2 + k3] != 0:
            return (count_3[k1 + k2 + k3] / count_2[k2 + k3]) * math.log(count_3[k1 + k2 + k3] / count_2[k2 + k3])

        return 0

    # Every name was zero-filled above, so no membership test is needed.
    res = [val_1(k[0], k[1]) for k in nn_2_name]
    res.extend(
        val_1(k[0], k[1]) + val_2(k[0], k[2]) - val_3(k[0], k[1], k[2])
        for k in nn_3_name)

    return res


def Z_curve_9bit(sequence):
    """
    Phase-specific mononucleotide Z-curve features.

    Bases are tallied separately for the three positions of a reading
    frame that starts at the first character. For each position three
    components are emitted, all divided by the full sequence length:
    x = (A + G) - (C + T), y = (A + C) - (G + T), z = (A + T) - (G + C).

    Returns:
        List of 9 floats ordered position 1 (x, y, z), position 2, position 3.
    """
    seq_len = len(sequence)

    # One tally per frame position: index 0 -> position 1, and so on.
    phase_tallies = [{}, {}, {}]
    for offset, base in enumerate(sequence):
        tally = phase_tallies[offset % 3]
        tally[base] = tally.get(base, 0) + 1

    features = []
    for tally in phase_tallies:
        a, c, g, t = (tally.get(base, 0) for base in 'ACGT')
        features.append((a + g - c - t) / seq_len)
        features.append((a + c - g - t) / seq_len)
        features.append((a + t - g - c) / seq_len)

    return features


def Z_curve_12bit(sequence):
    """
    Phase-independent dinucleotide Z-curve features.

    Overlapping dinucleotides are counted; for each first base (A, C, G, T)
    the x, y, z components are formed over the second base and divided by
    the number of dinucleotide windows (len(sequence) - 1):
    x = (A + G) - (C + T), y = (A + C) - (G + T), z = (A + T) - (G + C).

    Returns:
        List of 12 floats: (x, y, z) for first base A, then C, G, T.
    """
    window_total = len(sequence) - 1

    pair_tally = {}
    for start in range(window_total):
        pair = sequence[start:start + 2]
        pair_tally[pair] = pair_tally.get(pair, 0) + 1

    features = []
    for first in 'ACGT':
        a, c, g, t = (pair_tally.get(first + second, 0) for second in 'ACGT')
        features.extend([
            (a + g - c - t) / window_total,  # x
            (a + c - g - t) / window_total,  # y
            (a + t - g - c) / window_total,  # z
        ])

    return features


def Z_curve_36bit(sequence):
    """
    Phase-specific dinucleotide Z-curve features.

    Overlapping dinucleotides are tallied separately by the frame position
    (1, 2 or 3) of their first base. For each first base (A, C, G, T) and
    each frame position, the x, y, z components are formed over the second
    base and divided by len(sequence) - 1:
    x = (A + G) - (C + T), y = (A + C) - (G + T), z = (A + T) - (G + C).

    Returns:
        List of 36 floats: for each first base, (x, y, z) at position 1,
        then position 2, then position 3.
    """
    window_total = len(sequence) - 1

    # Index 0 holds windows starting at frame position 1, and so on.
    phase_tallies = ({}, {}, {})
    for start in range(window_total):
        pair = sequence[start:start + 2]
        tally = phase_tallies[start % 3]
        tally[pair] = tally.get(pair, 0) + 1

    features = []
    for first in 'ACGT':
        for tally in phase_tallies:
            a, c, g, t = (tally.get(first + second, 0) for second in 'ACGT')
            features.extend([
                (a + g - c - t) / window_total,  # x
                (a + c - g - t) / window_total,  # y
                (a + t - g - c) / window_total,  # z
            ])

    return features


def Z_curve_48bit(sequence):
    """
    Phase-independent trinucleotide Z-curve features.

    Overlapping trinucleotides are counted; for each two-base prefix
    (AA, AC, ..., TT) the x, y, z components are formed over the third
    base and divided by the number of windows (len(sequence) - 2):
    x = (A + G) - (C + T), y = (A + C) - (G + T), z = (A + T) - (G + C).

    Returns:
        List of 48 floats: (x, y, z) per prefix, prefixes in AA..TT order.
    """
    window_total = len(sequence) - 2

    triplet_tally = {}
    for start in range(window_total):
        triplet = sequence[start:start + 3]
        triplet_tally[triplet] = triplet_tally.get(triplet, 0) + 1

    features = []
    for first in 'ACGT':
        for second in 'ACGT':
            prefix = first + second
            a, c, g, t = (triplet_tally.get(prefix + third, 0) for third in 'ACGT')
            features.extend([
                (a + g - c - t) / window_total,  # x
                (a + c - g - t) / window_total,  # y
                (a + t - g - c) / window_total,  # z
            ])

    return features


def Z_curve_144bit(sequence):
    """
    Phase-specific trinucleotide Z-curve features.

    Overlapping trinucleotides are tallied separately by the frame position
    (1, 2 or 3) of their first base. For each two-base prefix (AA..TT) and
    each frame position, the x, y, z components are formed over the third
    base and divided by len(sequence) - 2:
    x = (A + G) - (C + T), y = (A + C) - (G + T), z = (A + T) - (G + C).

    Returns:
        List of 144 floats: for each prefix, (x, y, z) at position 1,
        then position 2, then position 3.
    """
    window_total = len(sequence) - 2

    # Index 0 holds windows starting at frame position 1, and so on.
    phase_tallies = ({}, {}, {})
    for start in range(window_total):
        triplet = sequence[start:start + 3]
        tally = phase_tallies[start % 3]
        tally[triplet] = tally.get(triplet, 0) + 1

    features = []
    for first in 'ACGT':
        for second in 'ACGT':
            prefix = first + second
            for tally in phase_tallies:
                a, c, g, t = (tally.get(prefix + third, 0) for third in 'ACGT')
                features.extend([
                    (a + g - c - t) / window_total,  # x
                    (a + c - g - t) / window_total,  # y
                    (a + t - g - c) / window_total,  # z
                ])

    return features


def PseDNC(sequence):
    lamada_val = 3
    weight = 0.05

    kmer_n = 2

    property_name = ['Twist', 'Tilt', 'Roll', 'Shift', 'Slide', 'Rise']
    property_value = {
        'Twist': [
            0.063, 1.502, 0.783, 1.071, -1.376, 0.063, -1.664, 0.783, -0.081,
            -0.081, 0.063, 1.502, -1.233, -0.081, -1.376, 0.063
        ],
        'Tilt': [
            0.502, 0.502, 0.359, 0.215, -1.364, 1.077, -1.22, 0.359, 0.502,
            0.215, 1.077, 0.502, -2.368, 0.502, -1.364, 0.502
        ],
        'Roll': [
            0.09, 1.19, -0.28, 0.83, -1.01, -0.28, -1.38, -0.28, 0.09, 2.3,
            -0.28, 1.19, -1.38, 0.09, -1.01, 0.09
        ],
        'Shift': [
            1.587, 0.126, 0.679, -1.019, -0.861, 0.56, -0.822, 0.679, 0.126,
            -0.348, 0.56, 0.126, -2.243, 0.126, -0.861, 1.587
        ],
        'Slide': [
            0.111, 1.289, -0.241, 2.513, -0.623, -0.822, -0.287, -0.241,
            -0.394, 0.646, -0.822, 1.289, -1.511, -0.394, -0.623, 0.111
        ],
        'Rise': [
            -0.109, 1.044, -0.623, 1.171, -1.254, 0.242, -1.389, -0.623, 0.711,
            1.585, 0.242, 1.044, -1.389, 0.711, -1.254, -0.109
        ]
    }

    NN = 'ACGT'
    nn_name = [''.join(item) for item in itertools.product(NN, repeat=kmer_n)]

    nn_arr = KmerArray(sequence, kmer_n)
    count = Counter(nn_arr)

    for k in count:
        count[k] /= len(sequence) - kmer_n + 1

    theta_arr = []
    for tmp_lamada in range(lamada_val):
        theta = 0
        for i in range(len(sequence) - tmp_lamada - kmer_n):
            kk = 0
            for p_name in property_name:
                kk += (property_value[p_name][nn_name.index(
                    sequence[i:i + kmer_n])] -
                       property_value[p_name][nn_name.index(
                           sequence[i + tmp_lamada + 1:i + tmp_lamada + 1 +
                                    kmer_n])])**2

            theta += kk / len(property_name)
        theta_arr.append(theta / (len(sequence) - tmp_lamada - kmer_n))

    res = []
    for k in nn_name:
        res.append(count[k] / (1 + weight * sum(theta_arr)))

    start_i = len(NN)**kmer_n + 1
    for k in range(start_i, start_i + lamada_val):
        res.append(
            (weight * theta_arr[k - start_i]) / (1 + weight * sum(theta_arr)))

    return res


def PseKNC(sequence):
    """
    return 1 x 66
    """
    lamada_val = 2
    weight = 0.1

    kmer_n = 3

    property_name = ['Twist', 'Tilt', 'Roll', 'Shift', 'Slide', 'Rise']
    property_value = {
        'Twist': [
            0.063, 1.502, 0.783, 1.071, -1.376, 0.063, -1.664, 0.783, -0.081,
            -0.081, 0.063, 1.502, -1.233, -0.081, -1.376, 0.063
        ],
        'Tilt': [
            0.502, 0.502, 0.359, 0.215, -1.364, 1.077, -1.22, 0.359, 0.502,
            0.215, 1.077, 0.502, -2.368, 0.502, -1.364, 0.502
        ],
        'Roll': [
            0.09, 1.19, -0.28, 0.83, -1.01, -0.28, -1.38, -0.28, 0.09, 2.3,
            -0.28, 1.19, -1.38, 0.09, -1.01, 0.09
        ],
        'Shift': [
            1.587, 0.126, 0.679, -1.019, -0.861, 0.56, -0.822, 0.679, 0.126,
            -0.348, 0.56, 0.126, -2.243, 0.126, -0.861, 1.587
        ],
        'Slide': [
            0.111, 1.289, -0.241, 2.513, -0.623, -0.822, -0.287, -0.241,
            -0.394, 0.646, -0.822, 1.289, -1.511, -0.394, -0.623, 0.111
        ],
        'Rise': [
            -0.109, 1.044, -0.623, 1.171, -1.254, 0.242, -1.389, -0.623, 0.711,
            1.585, 0.242, 1.044, -1.389, 0.711, -1.254, -0.109
        ]
    }

    NN = 'ACGT'
    nn_2_name = [''.join(item) for item in itertools.product(NN, repeat=2)]
    nn_3_name = [
        ''.join(item) for item in itertools.product(NN, repeat=kmer_n)
    ]

    nn_3_arr = KmerArray(sequence, kmer_n)
    count_3 = Counter(nn_3_arr)

    for k in count_3:
        count_3[k] /= len(sequence) - kmer_n + 1

    theta_arr = []
    kmer_n_ = 2
    for tmp_lamada in range(lamada_val):
        theta = 0
        for i in range(len(sequence) - tmp_lamada - kmer_n_):
            kk = 0
            for p_name in property_name:
                kk += (property_value[p_name][nn_2_name.index(
                    sequence[i:i + kmer_n_])] -
                       property_value[p_name][nn_2_name.index(
                           sequence[i + tmp_lamada + 1:i + tmp_lamada + 1 +
                                    kmer_n_])])**2

            theta += kk / len(property_name)
        theta_arr.append(theta / (len(sequence) - tmp_lamada - kmer_n_))

    res = []
    for k in nn_3_name:
        res.append(count_3[k] / (1 + weight * sum(theta_arr)))

    start_i = len(NN)**kmer_n + 1
    for k in range(start_i, start_i + lamada_val):
        res.append(
            (weight * theta_arr[k - start_i]) / (1 + weight * sum(theta_arr)))

    return res


def PCPseTNC(sequence):
    lamada_val = 3
    weight = 0.05

    kmer_n = 3

    property_name = ['Dnase I', 'Bendability (DNAse)']
    property_value = {
        'Dnase I': [
            2.274, 1.105, 0.193, 2.141, -0.153, -0.078, -0.074, 0.536, 0.109,
            -0.753, 0.039, 0.536, -0.491, 0.307, -1.112, 2.141, 0.166, -0.646,
            -0.762, -1.112, 0.917, -0.3, 0.558, 0.039, -0.834, -0.326, 0.558,
            -0.074, 0.062, -0.365, -0.762, 0.193, 0.474, -0.165, -0.365, 0.307,
            -0.702, -1.687, -0.326, -0.753, 0.066, -1.687, -0.3, -0.078, 0.031,
            -0.165, -0.646, 1.105, 0.206, 0.031, 0.062, -0.491, -1.103, 0.066,
            -0.834, 0.109, 4.522, -0.702, 0.917, -0.153, 0.206, 0.474, 0.166,
            -2.615
        ],
        'Bendability (DNAse)': [
            -2.087, -1.509, -0.506, -2.126, 0.111, -0.121, -0.121, -1.354,
            0.381, 0.304, -0.313, -1.354, 1.615, -0.737, 1.229, -2.126, 0.265,
            0.496, 1.576, 1.229, -1.856, 0.072, -0.969, -0.313, 0.111, -0.468,
            -0.969, -0.121, 0.882, 0.419, 1.576, -0.506, -0.159, 0.034, 0.419,
            -0.737, 0.766, 1.036, -0.468, 0.304, 0.265, 1.036, 0.072, -0.121,
            0.342, 0.034, 0.496, -1.509, 0.689, 0.342, 0.882, 1.615, 1.73,
            0.265, 0.111, 0.381, 1.73, 0.766, -1.856, 0.111, 0.689, -0.159,
            0.265, -2.087
        ]
    }

    NN = 'ACGT'
    nn_name = [''.join(item) for item in itertools.product(NN, repeat=kmer_n)]

    nn_arr = KmerArray(sequence, kmer_n)
    count = Counter(nn_arr)

    for k in count:
        count[k] /= len(sequence) - kmer_n + 1

    theta_arr = []
    kmer_n_ = kmer_n
    for tmp_lamada in range(lamada_val):
        theta = 0
        for i in range(len(sequence) - tmp_lamada - kmer_n_):
            kk = 0
            for p_name in property_name:
                kk += (property_value[p_name][nn_name.index(
                    sequence[i:i + kmer_n_])] -
                       property_value[p_name][nn_name.index(
                           sequence[i + tmp_lamada + 1:i + tmp_lamada + 1 +
                                    kmer_n_])])**2

            theta += kk / len(property_name)
        theta_arr.append(theta / (len(sequence) - tmp_lamada - kmer_n_))

    res = []
    for k in nn_name:
        res.append(count[k] / (1 + weight * sum(theta_arr)))

    start_i = len(NN)**kmer_n + 1
    for k in range(start_i, start_i + lamada_val):
        res.append(
            (weight * theta_arr[k - start_i]) / (1 + weight * sum(theta_arr)))

    return res


def SCPseDNC(sequence):
    """
    return 1 x 34
    """
    lamada_val = 3
    weight = 0.05

    kmer_n = 2

    property_name = ['Twist', 'Tilt', 'Roll', 'Shift', 'Slide', 'Rise']
    property_value = {
        'Twist': [
            0.063, 1.502, 0.783, 1.071, -1.376, 0.063, -1.664, 0.783, -0.081,
            -0.081, 0.063, 1.502, -1.233, -0.081, -1.376, 0.063
        ],
        'Tilt': [
            0.502, 0.502, 0.359, 0.215, -1.364, 1.077, -1.22, 0.359, 0.502,
            0.215, 1.077, 0.502, -2.368, 0.502, -1.364, 0.502
        ],
        'Roll': [
            0.09, 1.19, -0.28, 0.83, -1.01, -0.28, -1.38, -0.28, 0.09, 2.3,
            -0.28, 1.19, -1.38, 0.09, -1.01, 0.09
        ],
        'Shift': [
            1.587, 0.126, 0.679, -1.019, -0.861, 0.56, -0.822, 0.679, 0.126,
            -0.348, 0.56, 0.126, -2.243, 0.126, -0.861, 1.587
        ],
        'Slide': [
            0.111, 1.289, -0.241, 2.513, -0.623, -0.822, -0.287, -0.241,
            -0.394, 0.646, -0.822, 1.289, -1.511, -0.394, -0.623, 0.111
        ],
        'Rise': [
            -0.109, 1.044, -0.623, 1.171, -1.254, 0.242, -1.389, -0.623, 0.711,
            1.585, 0.242, 1.044, -1.389, 0.711, -1.254, -0.109
        ]
    }

    NN = 'ACGT'
    nn_name = [''.join(item) for item in itertools.product(NN, repeat=kmer_n)]

    nn_arr = KmerArray(sequence, kmer_n)
    count = Counter(nn_arr)

    for k in count:
        count[k] /= len(sequence) - kmer_n + 1

    theta_arr = []
    kmer_n_ = kmer_n
    for tmp_lamada in range(lamada_val):
        for p_name in property_name:
            theta = 0
            for i in range(len(sequence) - tmp_lamada - kmer_n_):
                theta += (property_value[p_name][nn_name.index(
                    sequence[i:i + kmer_n_])] *
                          property_value[p_name][nn_name.index(
                              sequence[i + tmp_lamada + 1:i + tmp_lamada + 1 +
                                       kmer_n_])])

            theta_arr.append(theta / (len(sequence) - tmp_lamada - kmer_n_))

    res = []
    for k in nn_name:
        res.append(count[k] / (1 + weight * sum(theta_arr)))

    start_i = len(NN)**kmer_n + 1
    for k in range(start_i, start_i + lamada_val * len(property_name)):
        res.append(
            (weight * theta_arr[k - start_i]) / (1 + weight * sum(theta_arr)))

    return res


def SCPseTNC(sequence):
    lamada_val = 3
    weight = 0.05

    kmer_n = 3

    property_name = ['Dnase I', 'Bendability (DNAse)']
    property_value = {
        'Dnase I': [
            2.274, 1.105, 0.193, 2.141, -0.153, -0.078, -0.074, 0.536, 0.109,
            -0.753, 0.039, 0.536, -0.491, 0.307, -1.112, 2.141, 0.166, -0.646,
            -0.762, -1.112, 0.917, -0.3, 0.558, 0.039, -0.834, -0.326, 0.558,
            -0.074, 0.062, -0.365, -0.762, 0.193, 0.474, -0.165, -0.365, 0.307,
            -0.702, -1.687, -0.326, -0.753, 0.066, -1.687, -0.3, -0.078, 0.031,
            -0.165, -0.646, 1.105, 0.206, 0.031, 0.062, -0.491, -1.103, 0.066,
            -0.834, 0.109, 4.522, -0.702, 0.917, -0.153, 0.206, 0.474, 0.166,
            -2.615
        ],
        'Bendability (DNAse)': [
            -2.087, -1.509, -0.506, -2.126, 0.111, -0.121, -0.121, -1.354,
            0.381, 0.304, -0.313, -1.354, 1.615, -0.737, 1.229, -2.126, 0.265,
            0.496, 1.576, 1.229, -1.856, 0.072, -0.969, -0.313, 0.111, -0.468,
            -0.969, -0.121, 0.882, 0.419, 1.576, -0.506, -0.159, 0.034, 0.419,
            -0.737, 0.766, 1.036, -0.468, 0.304, 0.265, 1.036, 0.072, -0.121,
            0.342, 0.034, 0.496, -1.509, 0.689, 0.342, 0.882, 1.615, 1.73,
            0.265, 0.111, 0.381, 1.73, 0.766, -1.856, 0.111, 0.689, -0.159,
            0.265, -2.087
        ]
    }

    NN = 'ACGT'
    nn_name = [''.join(item) for item in itertools.product(NN, repeat=kmer_n)]

    nn_arr = KmerArray(sequence, kmer_n)
    count = Counter(nn_arr)

    for k in count:
        count[k] /= len(sequence) - kmer_n + 1

    theta_arr = []
    kmer_n_ = kmer_n
    for tmp_lamada in range(lamada_val):
        for p_name in property_name:
            theta = 0
            for i in range(len(sequence) - tmp_lamada - kmer_n_):
                theta += (property_value[p_name][nn_name.index(
                    sequence[i:i + kmer_n_])] *
                          property_value[p_name][nn_name.index(
                              sequence[i + tmp_lamada + 1:i + tmp_lamada + 1 +
                                       kmer_n_])])

            theta_arr.append(theta / (len(sequence) - tmp_lamada - kmer_n_))

    res = []
    for k in nn_name:
        res.append(count[k] / (1 + weight * sum(theta_arr)))

    start_i = len(NN)**kmer_n + 1
    for k in range(start_i, start_i + lamada_val * len(property_name)):
        res.append(
            (weight * theta_arr[k - start_i]) / (1 + weight * sum(theta_arr)))

    return res


class DNAFeature:
    """Registry of hand-crafted DNA sequence encoders plus fit/transform helpers.

    ``embedding_func`` maps an encoder name to a callable that takes a DNA
    string and returns a flat numeric vector. The class can fit one
    ``MinMaxScaler`` per encoder (plus a PCA for ``'TPCP'``), persist them with
    pickle, and later apply the fitted transforms to new sequences.
    """

    def __init__(self):
        """
        18 DNA characterization methods
        """
        self.embedding_func = {
            'Kmer': Kmer,
            'RCKmer': RCKmer,
            'Mismatch': Mismatch,
            'CKSNAP': CKSNAP,
            'PseEIIP': PseEIIP,
            'DPCP': DPCP,
            'TPCP': TPCP,
            'MMI': MMI,
            'Z_curve_9bit': Z_curve_9bit,
            'Z_curve_12bit': Z_curve_12bit,
            'Z_curve_36bit': Z_curve_36bit,
            'Z_curve_48bit': Z_curve_48bit,
            'Z_curve_144bit': Z_curve_144bit,
            'PseDNC': PseDNC,
            'PseKNC': PseKNC,
            'PCPseTNC': PCPseTNC,
            'SCPseDNC': SCPseDNC,
            'SCPseTNC': SCPseTNC
        }
        print(f'Info: {len(self.embedding_func)} DNA characterization methods')

    def get_embedding_type(self):
        """Return the names of all registered encoders, in registration order."""
        return list(self.embedding_func.keys())

    def display_example(self):
        """Print the output length of every encoder on a fixed sample sequence."""
        seq = 'ATGGCAACGTCATGGTGCCGGGATTTTTGGCAGGCTTTTCGCCCTGGGATCCTACCGGGCAGCTTCCGAGGTGAGCTGGAAACCTTCCGTAAACTCGTCGAGCGCGACGCGCCGAGACGGGGCCTCGAGCACCACCACCACCACCACTGA'

        embedding_size = 0
        for i, (embedding_name, encoder) in enumerate(
                self.embedding_func.items(), start=1):
            embedding = encoder(seq)
            embedding_size += len(embedding)

            print(f'{i}-{embedding_name}: {len(embedding)}')

        # NOTE(review): the meaning of the 768 offset is not established here;
        # presumably it is the DNABERT-2 embedding width -- confirm.
        print('embedding_size', embedding_size, embedding_size - 768)

    def get_embedding(self, seq, feature_name):
        """Encode one sequence with a single named encoder.

        Args:
            seq: DNA string made up only of the characters A, C, G and T.
            feature_name: key of ``self.embedding_func``.

        Returns:
            Whatever the selected encoder returns for ``seq``.

        Raises:
            ValueError: if ``feature_name`` is unknown, or if ``seq`` is empty
                or contains a character other than A/C/G/T.
        """
        if feature_name not in self.embedding_func:
            raise ValueError(f'No {feature_name} embedding type.')

        # The previous check (`assert set('ACGT') == set(seq)`) demanded that
        # all four bases occur, so a valid sequence missing one base was
        # rejected, and the check vanished entirely under `python -O`.
        unexpected = set(seq) - set('ACGT')
        if not seq or unexpected:
            raise ValueError(
                f'Sequence must be non-empty and contain only A/C/G/T; '
                f'unexpected characters: {sorted(unexpected)}')

        return self.embedding_func[feature_name](seq)

    def init_feature_fit(self, seqs, embedding_fit_file):
        """Fit and persist the per-encoder scalers (and the TPCP PCA).

        Args:
            seqs: iterable of DNA strings used to fit the transforms.
            embedding_fit_file: pickle path. If it already exists its content
                is loaded first and the entries fitted here overwrite the
                matching keys, so a partial refit keeps the other entries.
        """
        if os.path.exists(embedding_fit_file):
            # Start from the stored fits so that individual encoders can be
            # refreshed without losing the rest.
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(embedding_fit_file, 'rb') as fit_in:
                embedding_fit = pickle.load(fit_in)
        else:
            embedding_fit = {}

        for i, (embedding_name, encoder) in enumerate(
                self.embedding_func.items(), start=1):
            # To refit only some encoders, skip the others here, e.g.
            # `if embedding_name not in ['TPCP']: continue`.
            embeddings = [
                encoder(seq)
                for seq in tqdm(seqs, total=len(seqs), desc=f'{i}-{embedding_name}')
            ]

            # Dimensionality reduction: TPCP is projected to 105 components
            # before scaling.
            if embedding_name == 'TPCP':
                pca = PCA(n_components=105)
                embeddings = pca.fit_transform(embeddings)

                embedding_fit[f'{embedding_name} PCA'] = pca

            scaler = MinMaxScaler()
            scaler.fit(embeddings)

            embedding_fit[f'{embedding_name} Scaler'] = scaler

        with open(embedding_fit_file, 'wb') as fit_out:
            pickle.dump(embedding_fit, fit_out)

        print(
            f'Info: wtire scaler and pca ({len(embedding_fit)}) to {embedding_fit_file}'
        )

    def exec_seqs_embedding(self, seqs, names, embedding_fit_file,
                            embedding_res_file):
        """Encode every sequence with all encoders and pickle the result.

        Each encoder output is passed through its fitted scaler (TPCP first
        through its PCA), negative values are clipped to 0, and the pieces are
        concatenated in encoder registration order.

        Args:
            seqs: DNA strings to encode.
            names: identifiers paired positionally with ``seqs``.
            embedding_fit_file: pickle written by ``init_feature_fit``.
            embedding_res_file: output pickle; holds
                ``{name: {'embedding': list_of_values}}``.
        """
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(embedding_fit_file, 'rb') as fit_in:
            embedding_fit = pickle.load(fit_in)

        seq_embeddings = {}
        for name, seq in tqdm(zip(names, seqs), total=len(seqs)):
            embeddings = []
            for embedding_name, encoder in self.embedding_func.items():
                # The fitted transformers expect a 2-D batch, hence the wrap.
                embedding = [encoder(seq)]

                if embedding_name == 'TPCP':
                    embedding = embedding_fit[
                        f'{embedding_name} PCA'].transform(embedding)

                embedding = embedding_fit[
                    f'{embedding_name} Scaler'].transform(embedding)

                # Scaled values can come out slightly negative (the original
                # note attributes this to MinMaxScaler rounding error); clip
                # them to 0. Values above 1 are left untouched.
                embedding[embedding < 0] = 0

                embeddings += list(embedding[0])

            seq_embeddings[name] = {'embedding': embeddings}

        with open(embedding_res_file, 'wb') as res_out:
            pickle.dump(seq_embeddings, res_out)

        print(
            f'Info: feature embedding finish. num is ({len(seq_embeddings)}) to {embedding_res_file}'
        )

# df = DNAFeature()
# et = df.get_embedding_type()
# df.display_example()

In [None]:
def Fearture_embedding(dna_seq_csv, save_embedding_res_file):
    """Compute the hand-crafted feature embeddings for every row of a CSV.

    The CSV's first row is treated as a header and skipped; column 0 is read
    as the sequence name and column 1 as the DNA sequence. The fitted
    scalers/PCA are read from ``{resource_dir}/feature-embedding-fit.pickle``
    and the result is pickled to ``save_embedding_res_file``.
    """
    extractor = DNAFeature()

    with open(dna_seq_csv, 'r', encoding='utf-8') as handle:
        rows = list(csv.reader(handle))

    # Drop the header, then slice the name and sequence columns out of the table.
    table = np.array(rows[1:])
    names, seqs = list(table[:, 0]), list(table[:, 1])

    extractor.exec_seqs_embedding(
        seqs,
        names,
        f'{resource_dir}/feature-embedding-fit.pickle',
        save_embedding_res_file,
    )

In [None]:
# Smoke test, executed only when `is_test` is set: run the hand-crafted
# feature embedding on the test CSV found under `res_dir`.
if is_test:
    dna_file = f'{res_dir}/{protein_name}_scg_dnas_test.csv'

    # Output pickle holding the feature-engineering embeddings
    dna_feature_embedding_pickle = f'{res_dir}/{protein_name}_scg_dnas_feature_embedding_test.pickle'

    Fearture_embedding(dna_file, dna_feature_embedding_pickle)

#     raise

### Transfer learning (DNABERT-2) feature extraction

In [None]:
class SeqenceDataset(Dataset):
    """Dataset over a CSV whose rows are ``name, sequence, label``.

    The first CSV row is treated as a header and discarded. All three fields
    are kept as the strings read from the file.
    """

    def __init__(self, data_file: str):
        super().__init__()

        # Read everything after the header row into memory.
        with open(data_file, "r") as handle:
            reader = csv.reader(handle)
            next(reader, None)
            rows = list(reader)

        self.names = [row[0] for row in rows]
        self.labels = [row[2] for row in rows]
        self.texts = [row[1] for row in rows]

    def __len__(self):
        """Number of data rows (header excluded)."""
        return len(self.labels)

    def __getitem__(self, i):
        """Return ``(name, sequence, label)`` for row ``i``."""
        return self.names[i], self.texts[i], self.labels[i]

In [None]:
def collate_fn(batch_samples, tokenizer):
    """Collate ``(name, sequence, label)`` samples into one tokenized batch.

    Args:
        batch_samples: sequence of ``(name, sequence, label)`` items.
        tokenizer: callable applied to the list of sequences.

    Returns:
        ``(names, tokenizer_output, labels)`` where ``labels`` is a list of ints.
    """
    names = [sample[0] for sample in batch_samples]
    sequences = [sample[1] for sample in batch_samples]
    labels = [int(sample[2]) for sample in batch_samples]

    # return_attention_mask: also return the mask that marks which positions
    #   hold real tokens and which are padding.
    # return_token_type_ids: segment ids that tell sentences apart; not
    #   requested here.
    encoded = tokenizer(
        sequences,
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_attention_mask=True,
        return_token_type_ids=False,
    )

    return names, encoded, labels

In [None]:
#
from tqdm.auto import tqdm
import pickle
import numpy as np

# embedding
def dnabert_embedding(model, dataloader, scaler, embedding_data_file, device):
    """Mean-pool the model's hidden states per sequence, scale, and pickle them.

    Args:
        model: callable invoked as ``model(input_ids)``; element ``[0]`` of its
            return value is indexed as ``[sample, token, feature]``.
        dataloader: iterable yielding ``(names, batch, labels)`` where
            ``batch`` has ``'input_ids'`` and ``'attention_mask'`` entries.
        scaler: object with a ``transform`` method taking a 2-D batch.
        embedding_data_file: path of the pickle to write.
        device: device the input ids are moved to.

    Returns:
        dict mapping each name to ``{'embedding': scaled_vector, 'label': label}``.
        The same dict is pickled to ``embedding_data_file``.
    """
    embedding_data = {}

    # Context manager so the bar is closed even when a batch raises.
    with tqdm(total=len(dataloader)) as bar:
        for gene_names, batch_X, labels in dataloader:
            # Number of non-padding tokens of every sample in the batch.
            batch_lens = [mask.sum().item() for mask in batch_X['attention_mask']]
            inputs = batch_X['input_ids'].to(device)

            # NOTE(review): the attention mask is used only to count tokens and
            # is not passed to the model, so padded positions may influence the
            # hidden states. Left unchanged because the fitted scaler and
            # downstream models presumably saw embeddings built this way -- confirm.
            with torch.no_grad():
                hidden_states = model(inputs)[0]

            for i, seq_len in enumerate(batch_lens):
                # Average over the tokens strictly between the first and last
                # non-padding positions (presumably the special tokens).
                # A sample with fewer than 3 tokens yields an empty slice, whose mean is NaN.
                representation = hidden_states[i, 1:seq_len - 1].mean(0).cpu().numpy()
                embedding_data[gene_names[i]] = {
                    'embedding': representation,
                    'label': labels[i],
                }

            bar.update(1)

    # Apply the pre-fitted scaler to every pooled embedding.
    for record in embedding_data.values():
        record['embedding'] = scaler.transform([record['embedding']])[0]

    with open(embedding_data_file, 'wb') as out:
        pickle.dump(embedding_data, out)

    return embedding_data

In [None]:
 def DNABERT_2_Embedding_Func(dna_file, dna_dnabert_embedding_pickle):
    # tokenizer and dnabert-2 model
#     checkpoint = f'{resource_dir}/DNABERT-2-Finetune-1400'
    checkpoint = DNABERT_2_checkpoint
    print(checkpoint)

    model_max_length = 636

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, model_max_length=model_max_length, trust_remote_code=True)
    dna_bert_2 = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

    # cuda
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Let's use, {device}!")

    dna_bert_2 = dna_bert_2.to(device)
    dna_bert_2.eval()

    # hyper-parameters
    batch_size = 16

    # dataset and dataloader
    dataset = SeqenceDataset(dna_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda x: collate_fn(x, tokenizer))

    #
    nesg_dnabert_scaler_file = f'{resource_dir}/nesg_dnabert_scaler.pickle'
    with open(nesg_dnabert_scaler_file, 'rb') as w:
        nesg_dnabert_scaler = pickle.load(w)

    scaler = nesg_dnabert_scaler['TT Scaler']

    # embedding
    return dnabert_embedding(dna_bert_2, dataloader, scaler, dna_dnabert_embedding_pickle, device)

In [None]:
# Smoke test, executed only when `is_test` is set: extract DNABERT-2
# embeddings for the test CSV and show how many sequences were embedded.
if is_test:
    dna_file = f'{res_dir}/{protein_name}_scg_dnas_test.csv'

    # Output pickle holding the DNABERT-2 embeddings
    dna_dnabert_embedding_pickle = f'{res_dir}/{protein_name}_scg_dnas_dnabert_embedding_test.pickle'

    dna_dnabert_embedding = DNABERT_2_Embedding_Func(dna_file, dna_dnabert_embedding_pickle)

    print(len(dna_dnabert_embedding))
    print(dna_dnabert_embedding.keys())

## Prediction of highly soluble expressed gene sequences using GELP models (SVM, LR, MLP, CNN-N-NF)

### SVM, LR, MLP

In [None]:
import pickle
import numpy as np

# Load the embedding data and merge the two feature sources
def load_merge_embedding(dnabert_e, feature_e, print_info=True):
    """Join DNABERT-2 and hand-crafted embeddings by sequence name.

    Args:
        dnabert_e: pickle of ``{name: {'embedding': ..., 'label': ...}}``.
        feature_e: pickle of ``{name: {'embedding': ...}}``.
        print_info: when true, print the file names, the list lengths and a
            preview of the first merged record.

    Returns:
        ``(names, embeddings, labels)`` as three parallel lists, ordered like
        the DNABERT-2 pickle. Each embedding is the DNABERT-2 vector followed
        by the hand-crafted vector.
    """
    with open(dnabert_e, 'rb') as handle:
        dnabert_records = pickle.load(handle)

    with open(feature_e, 'rb') as handle:
        feature_records = pickle.load(handle)

    gene_names, merged_embeddings, labels = [], [], []

    for gene, bert_item in dnabert_records.items():
        feat_item = feature_records[gene]

        gene_names.append(gene)
        merged_embeddings.append(
            np.concatenate([bert_item['embedding'], feat_item['embedding']]))
        labels.append(bert_item['label'])

    if print_info:
        first = merged_embeddings[0]
        print(dnabert_e, feature_e)
        print(len(gene_names), len(merged_embeddings), len(labels))
        print(gene_names[:3], labels[:3], len(first), first)

    return gene_names, merged_embeddings, labels

In [None]:
import pickle

def Sklearn_model(dnabert_2_dataset, dna_feature_dataset):
    """Score and rank every sequence with each stored scikit-learn model.

    Models are read from ``{resource_dir}/sklearn_best_model.pickle``; entries
    whose key contains ``'Acc'`` are skipped.

    Args:
        dnabert_2_dataset: pickle with the DNABERT-2 embeddings.
        dna_feature_dataset: pickle with the hand-crafted embeddings.

    Returns:
        ``{model_key: {sequence_name: [score, rank]}}`` where ``score`` is
        column 1 of ``predict_proba`` and rank 1 is the highest score. The
        inner dict is ordered by rank.
    """
    names, df_X, df_y = load_merge_embedding(dnabert_2_dataset, dna_feature_dataset)

    with open(f'{resource_dir}/sklearn_best_model.pickle', 'rb') as handle:
        fitted_models = pickle.load(handle)

    res = {}
    for key, model in fitted_models.items():
        if 'Acc' in key:
            continue

        print(key)

        positive_scores = model.predict_proba(df_X)[:, 1]

        # Sort names from highest to lowest score; ties keep their input order.
        score_by_name = dict(zip(names, positive_scores))
        ranked = sorted(score_by_name.items(), key=lambda pair: pair[1], reverse=True)

        res[key] = {
            name: [score, rank]
            for rank, (name, score) in enumerate(ranked, start=1)
        }

    return res

In [None]:
# Smoke test, executed only when `is_test` is set: score the test embeddings
# with the stored scikit-learn models and print the scores/ranks.
if is_test:
    dna_feature_embedding_pickle = f'{res_dir}/{protein_name}_scg_dnas_feature_embedding_test.pickle'
    dna_dnabert_embedding_pickle = f'{res_dir}/{protein_name}_scg_dnas_dnabert_embedding_test.pickle'

    res = Sklearn_model(dna_dnabert_embedding_pickle, dna_feature_embedding_pickle)
    print(res)

#     raise

### CNN-N-NF

In [None]:
# Codon -> integer id lookup covering all 64 codons with ids 1..64.
# Id 0 is not assigned to any codon.
nuclue_natural_number = {
    'AAA': 1, 'GAA': 2, 'GAT': 3, 'ATT': 4, 'AAT': 5, 'CTG': 6, 'TTT': 7, 'ATG': 8,
    'TAT': 9, 'GCA': 10, 'CAG': 11, 'GGT': 12, 'GTT': 13, 'GAG': 14, 'AAC': 15, 'TTA': 16,
    'GCT': 17, 'ATC': 18, 'GCC': 19, 'GGC': 20, 'GCG': 21, 'GAC': 22, 'AAG': 23, 'CTT': 24,
    'ATA': 25, 'GTG': 26, 'CAA': 27, 'ACA': 28, 'ACC': 29, 'TCA': 30, 'TTC': 31, 'TGG': 32,
    'GGA': 33, 'GTA': 34, 'TCT': 35, 'TTG': 36, 'ACT': 37, 'CGT': 38, 'AGC': 39, 'CAT': 40,
    'AGT': 41, 'CGC': 42, 'TAC': 43, 'CCG': 44, 'GGG': 45, 'GTC': 46, 'ACG': 47, 'CCA': 48,
    'CCT': 49, 'AGA': 50, 'CTC': 51, 'TCC': 52, 'CAC': 53, 'TCG': 54, 'CTA': 55, 'TGT': 56,
    'CGG': 57, 'TGC': 58, 'CCC': 59, 'CGA': 60, 'AGG': 61, 'TAA': 62, 'TGA': 63, 'TAG': 64,
}

In [None]:
from torch.utils.data import Dataset
import csv

class CNNSequenceDataset(Dataset):
    """Dataset pairing CSV rows with their pre-computed feature embeddings.

    The CSV's first row is treated as a header and discarded; the remaining
    rows are read as ``name, sequence, label`` with the label converted to
    ``int``. ``feature_file`` is a pickle of ``{name: {'embedding': ...}}``.
    """

    def __init__(self, data_file: str, feature_file: str):
        super().__init__()

        # Read everything after the header row into memory.
        with open(data_file, "r") as handle:
            reader = csv.reader(handle)
            next(reader, None)
            rows = list(reader)

        self.names = [row[0] for row in rows]
        self.texts = [row[1] for row in rows]
        self.labels = [int(row[2]) for row in rows]

        with open(feature_file, 'rb') as handle:
            self.feature_embedding_data = pickle.load(handle)

    def __len__(self):
        """Number of data rows (header excluded)."""
        return len(self.labels)

    def __getitem__(self, i):
        """Return ``(name, sequence, label, feature_embedding)`` for row ``i``."""
        name = self.names[i]
        features = self.feature_embedding_data[name]['embedding']
        return name, self.texts[i], self.labels[i], features

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

import pickle

def encode(seq: str, encode_type):
    """Turn a DNA string into a tensor of codon ids or one-hot rows.

    The sequence is cut into consecutive 3-character tokens; a token missing
    from ``nuclue_natural_number`` (including a short trailing one) gets id 0.

    Args:
        seq: DNA string.
        encode_type: ``None`` returns ``seq`` unchanged; ``'One-hot'`` returns
            a ``(tokens, len(nuclue_natural_number) + 1)`` tensor; any other
            value returns a 1-D tensor of ids.
    """
    if encode_type is None:
        return seq

    codon_ids = [
        nuclue_natural_number.get(seq[start:start + 3], 0)
        for start in range(0, len(seq), 3)
    ]

    if encode_type != 'One-hot':
        # Plain id sequence: one integer per codon.
        return torch.tensor(codon_ids)

    # One row per codon, one column per id (column 0 is the "unknown" id).
    one_hot = np.zeros([len(codon_ids), len(nuclue_natural_number) + 1])
    for row, codon_id in enumerate(codon_ids):
        one_hot[row][codon_id] = 1

    return torch.tensor(one_hot)

def cnn_collate_fn(batch_samples, encode_type=None):
    """Collate ``(name, sequence, label, features)`` samples for the CNN.

    Args:
        batch_samples: sequence of 4-tuples as produced by the dataset.
        encode_type: ``None``, ``'natural-number'`` or ``'One-hot'``; forwarded
            to ``encode``.

    Returns:
        ``(X, y, fea, names)``. ``X`` is the list of raw sequences when
        ``encode_type`` is ``None``, otherwise the encoded sequences padded
        with zeros to a common length. ``y`` is a label tensor and ``fea`` a
        float32 tensor of the feature embeddings.
    """
    assert encode_type in [None, 'natural-number', 'One-hot']

    names, encoded, labels, features = [], [], [], []
    for name, sequence, label, feature_vec in (sample[:4] for sample in batch_samples):
        names.append(name)
        encoded.append(encode(sequence, encode_type))
        labels.append(label)
        features.append(feature_vec)

    # Pad the encoded sequences to the same length (raw strings are passed through).
    X = encoded if encode_type is None else pad_sequence(
        encoded, batch_first=True, padding_value=0)

    return (
        X,
        torch.tensor(labels),
        torch.tensor(features, dtype=torch.float32),
        names,
    )


In [None]:
import torch.nn.functional as F

class CnnConfig:
    """Hyper-parameters of one convolution + pooling stage."""

    def __init__(self, inc, outc, ks, pks):
        # in_channel: number of input channels (e.g. 3 for an RGB image)
        # out_channel: number of convolution kernels (e.g. 16)
        self.in_channel, self.out_channel = inc, outc

        # kernel_size: convolution kernel size (e.g. (3, 25))
        # pool_kernel_size: pooling window (e.g. (2, 2))
        self.kernel_size, self.pool_kernel_size = ks, pks

class ModelConfig:
    """Hyper-parameters for ``KMersCNN``.

    One convolution branch is described per entry of ``first_kernel_list``.
    Every branch has ``len(cnn_channel) - 1`` stages; stage ``k`` maps
    ``cnn_channel[k]`` to ``cnn_channel[k + 1]`` channels and uses kernel size
    ``([first_kernel] + other_kernel_size)[k]``.
    """

    def __init__(self, cnn_channel: list, first_kernel_list: list, other_kernel_size: list, encode_type: str, is_add_feature=False):
        # True when the input is a sequence of codon ids that needs an embedding layer.
        self.encode_type = bool(encode_type == 'natural-number')
        self.is_add_feature = is_add_feature

        # Embedding parameters: ids (batch, seq_len) -> (batch, seq_len, embedding_dim)
        self.num_embeddings = 65
        self.embedding_dim = cnn_channel[0]

        # "Zipper" layout: parallel branches that differ only in their first kernel size.
        pool_kernel_size = 2
        stage_count = len(cnn_channel) - 1

        self.module_list = []
        for first_kernel in first_kernel_list:
            kernel_sizes = [first_kernel] + other_kernel_size
            self.module_list.append([
                CnnConfig(cnn_channel[stage],
                          cnn_channel[stage + 1],
                          kernel_sizes[stage],
                          pool_kernel_size)
                for stage in range(stage_count)
            ])

        # Fully connected head: one pooled vector per branch, plus 1024 extra
        # inputs when the hand-crafted features are appended.
        self.dropout_rate = 0.1
        self.in_features = cnn_channel[-1] * len(first_kernel_list) + (
            1024 if self.is_add_feature else 0)

        self.num_class = 2

class KMersCNN(nn.Module):
    """Multi-branch 1-D CNN classifier over codon sequences.

    Each branch is a stack of Conv1d -> BatchNorm1d -> ReLU -> Dropout1d ->
    MaxPool1d stages; every branch output is max-pooled over its whole length,
    the branches are concatenated, optionally joined with an external feature
    vector, and fed to a linear classifier.

    Submodule names and layer order are part of the checkpoint format
    (``embedding``, ``convs.<branch>.<layer>``, ``bn1``, ``bn2``, ``fc``) and
    must not change.
    """

    # https://blog.csdn.net/sunny_xsc1994/article/details/82969867
    def __init__(self, model_config: "ModelConfig"):
        super().__init__()

        self.args = model_config

        self.use_embedding = self.args.encode_type
        self.is_add_feature = self.args.is_add_feature

        # https://www.jianshu.com/p/63e7acc5e890
        # Token embedding; vocabulary size and width come from the config.
        if self.use_embedding:
            self.embedding = nn.Embedding(num_embeddings=self.args.num_embeddings,
                                          embedding_dim=self.args.embedding_dim)

        # Convolution branches
        self.convs = nn.ModuleList()
        for branch in self.args.module_list:
            layers = []
            for cnn in branch:
                # Conv1d expects input shaped (batch_size, input_channels, sequence_length)
                layers += [
                    nn.Conv1d(in_channels=cnn.in_channel,
                              out_channels=cnn.out_channel,
                              kernel_size=cnn.kernel_size),
                    nn.BatchNorm1d(cnn.out_channel),
                    nn.ReLU(),
                    nn.Dropout1d(p=self.args.dropout_rate),
                    nn.MaxPool1d(kernel_size=cnn.pool_kernel_size),
                ]

            self.convs.append(nn.Sequential(*layers))

        # bn1 normalizes the pooled CNN features; bn2 normalizes them again
        # after the 1024 external features have been appended.
        if self.is_add_feature:
            self.bn1, self.bn2 = nn.BatchNorm1d(self.args.in_features - 1024), nn.BatchNorm1d(self.args.in_features)
        else:
            self.bn1, self.bn2 = nn.BatchNorm1d(self.args.in_features), None

        # Fully connected classifier
        self.fc = nn.Linear(in_features=self.args.in_features,
                            out_features=self.args.num_class)

        # nn.Module.apply calls the given function on every submodule.
        self.apply(self.initialize_weights)

    def forward(self, x, batch_fea):
        """Return class logits of shape ``(batch, num_class)``.

        Args:
            x: ``(batch, seq_len)`` ids when the embedding layer is used,
                otherwise ``(batch, seq_len, channels)`` features.
            batch_fea: ``(batch, 1024)`` external features; only read when
                ``is_add_feature`` is set.
        """
        x = x.to(torch.float32)

        if self.use_embedding:
            x = self.embedding(x.to(torch.long))

        # (batch, seq_len, channels) -> (batch, channels, seq_len) for Conv1d
        x = x.permute(0, 2, 1)

        pooled = []
        for conv in self.convs:
            out = conv(x)

            # Global max pooling: the window spans the whole remaining length.
            pooled.append(F.max_pool1d(out, kernel_size=out.size(-1)))

        out = torch.cat(pooled, dim=1)
        out = out.view(-1, out.size(1))

        out = self.bn1(out)

        if self.is_add_feature:
            out = torch.cat([out, batch_fea], dim=1)

            out = self.bn2(out)

        # F.dropout defaults to training=True; without passing self.training
        # it kept dropping activations after model.eval(), which made
        # inference non-deterministic.
        out = F.dropout(out, p=self.args.dropout_rate, training=self.training)

        return self.fc(out)

    def initialize_weights(self, m):
        """Kaiming-uniform init for conv/linear weights, zeros for their biases."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [None]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef

def predict(model, dataloader, device, use_tqdm=False):
    """Score every sample of ``dataloader`` and rank the samples by score.

    Args:
        model: callable invoked as ``model(batch_X, batch_fea)`` that returns
            ``(batch, num_class)`` logits; ``model.eval()`` is called first.
        dataloader: iterable yielding ``(batch_X, batch_y, batch_fea, names)``.
        device: device the inputs are moved to.
        use_tqdm: show a progress bar over the batches.

    Returns:
        ``{name: [score, rank]}`` ordered by rank, where ``score`` is the
        softmax probability of class 1 and rank 1 is the highest score.
        An empty dataloader yields ``{}``.
    """
    progress_bar = tqdm(range(len(dataloader))) if use_tqdm else None

    model_output = []
    batch_names = []

    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for batch_X, _batch_y, batch_fea, batch_name in dataloader:
            batch_names += batch_name

            pred = model(batch_X.to(device), batch_fea.to(device))

            # dim=1 is what the implicit choice resolved to for 2-D logits;
            # stating it removes the deprecation warning.
            model_output.append(F.softmax(pred, dim=1).cpu().numpy())

            if progress_bar is not None:
                progress_bar.update(1)

    if progress_bar is not None:
        progress_bar.close()

    if not model_output:
        return {}

    # Class probabilities for all samples, shape (n_samples, num_class)
    probabilities = np.concatenate(model_output, axis=0)

    score_by_name = dict(zip(batch_names, probabilities[:, 1]))
    ranked = sorted(score_by_name.items(), key=lambda pair: pair[1], reverse=True)

    return {name: [score, rank] for rank, (name, score) in enumerate(ranked, start=1)}

In [None]:
# main
import sys
import torch

def CNN_NNF(data_csv, dna_feature_embedding):
    """Rank the sequences of ``data_csv`` with the stored CNN checkpoint.

    The architecture hyper-parameters below also form the checkpoint file
    name, so they must match the checkpoint shipped in ``resource_dir``.

    Args:
        data_csv: CSV with a header row and ``name, sequence, label`` columns.
        dna_feature_embedding: pickle with the hand-crafted embeddings.

    Returns:
        ``{'CNN-NNF': {name: [score, rank]}}``.
    """
    # Data: dataset and dataloader
    batch_size = 16
    encode_type = 'natural-number'
    # encode_type = 'One-hot'

    val_dataset = CNNSequenceDataset(data_csv, dna_feature_embedding)

    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: cnn_collate_fn(x, encode_type))

    # Model hyper-parameters
    min_kernel = 1
    max_kernel = 16
    cnn_channel = [512, 512, 512]
    first_kernel_list = list(range(min_kernel, max_kernel))
    other_kernel_size = [2]

    is_add_fea = False

    def get_list_info(_list):
        # Render [512, 512] as '512-512' for use in the checkpoint name.
        return '-'.join(str(_item) for _item in _list)

    tag = f'nesg-cnn-{encode_type}-{get_list_info(cnn_channel)}-{min_kernel}-{max_kernel}-{get_list_info(other_kernel_size)}-{is_add_fea}'

    checkpoint_path = f'{resource_dir}/{tag}.pt'
    print(checkpoint_path)

    model_config = ModelConfig(cnn_channel, first_kernel_list, other_kernel_size, encode_type, is_add_feature=is_add_fea)
    model = KMersCNN(model_config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # runtime, which the device fallback above is meant to support.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    curr_epoch = checkpoint['epoch'] + 1
    bast_acc = checkpoint['acc']

    print(f'resume: epoch {curr_epoch}, acc {bast_acc}')

    model = model.to(device)

    pre_res = predict(model, val_dataloader, device)

    return {
        'CNN-NNF': pre_res
    }

In [None]:
# Smoke test, executed only when `is_test` is set: rank the test sequences
# with the CNN model, then stop so the full pipeline below is not started.
if is_test:
    dna_file = f'{res_dir}/{protein_name}_scg_dnas_test.csv'
    dna_feature_embedding_pickle = f'{res_dir}/{protein_name}_scg_dnas_feature_embedding_test.pickle'

    res = CNN_NNF(dna_file, dna_feature_embedding_pickle)
    print(res)

    # Deliberate stop. A bare `raise` here produced the opaque
    # "RuntimeError: No active exception to reraise"; say why instead.
    raise RuntimeError(
        'is_test is enabled: stopping after the smoke tests. '
        'Set is_test = False to run the full pipeline.')

## Screening results

### Run

In [None]:
def run(predict_data_dir, protein_name, protein_seq, dna_seq_now):
    """Run the enabled pipeline stages for one protein and save the ranking.

    Which stages run is controlled by the module-level ``exec_func`` dict.
    A stage that is switched off reuses the files a previous run left in
    ``predict_data_dir``.

    Args:
        predict_data_dir: directory for all intermediate and result files.
        protein_name: prefix for file names and generated sequence names.
        protein_seq: protein sequence passed to the SCG model.
        dna_seq_now: mapping ``name -> DNA sequence`` of already known
            sequences that are scored alongside the generated ones.
    """
    # files
    print(f'Start: {protein_name}')

    # CSV with the candidate DNA sequences; input of both feature extractors
    dna_file = f'{predict_data_dir}/{protein_name}_scg_dnas.csv'

    # Output of the DNABERT-2 feature extraction
    dna_dnabert_embedding_pickle = f'{predict_data_dir}/{protein_name}_scg_dnas_dnabert_embedding.pickle'

    # Output of the hand-crafted feature encoding
    dna_feature_embedding_pickle = f'{predict_data_dir}/{protein_name}_scg_dna_feature_embedding.pickle'

    if exec_func['Exec SCG Model']:
        # beam_batch_sizes, beam_widths and beam_sizes are module-level settings.
        dnas = SCG_Transformer_predict(protein_seq, beam_batch_sizes, beam_widths, beam_sizes)

        # Rows are `name, sequence, label` (label -1 = unknown). The header
        # previously listed the columns in the wrong order; the readers in
        # this notebook skip the header and index columns by position.
        with open(dna_file, 'w', encoding='utf-8') as w:
            w.write('name,nucle-seq,label\n')
            for i, dna in enumerate(dnas):
                w.write(f'{protein_name}_scg_{i},{dna},-1\n')

            # Append the sequences that already exist
            for k in dna_seq_now:
                w.write(f'{k},{dna_seq_now[k]},-1\n')

    print()
    if exec_func['Exec DNA Feature']:
        Fearture_embedding(dna_file, dna_feature_embedding_pickle)

    print()
    if exec_func['Exec DNABERT-2 Model']:
        # DNABERT-2 feature extraction
        dna_dnabert_embedding = DNABERT_2_Embedding_Func(dna_file, dna_dnabert_embedding_pickle)

        print(len(dna_dnabert_embedding))
        print(list(dna_dnabert_embedding.keys())[-10:])

    # Collect the results of whichever prediction stages are enabled. The
    # earlier `sk_res.update(cnn_res)` raised NameError when either was off.
    predict_res = {}

    print()
    if exec_func['Exec Sklearn model']:
        predict_res.update(Sklearn_model(dna_dnabert_embedding_pickle, dna_feature_embedding_pickle))

    print()
    if exec_func['Exec CNN-N-NF model']:
        predict_res.update(CNN_NNF(dna_file, dna_feature_embedding_pickle))

    if not predict_res:
        # Nothing was predicted; do not overwrite an existing result file.
        print(f'Info: {protein_name} no prediction model enabled, nothing saved')
        print()
        return

    # save to pickle
    predict_res_pickle = f'{predict_data_dir}/{protein_name}_predict_res.pickle'
    with open(predict_res_pickle, 'wb') as w:
        pickle.dump(predict_res, w)

    print(f'Info: {protein_name} predict res save to {predict_res_pickle}')

    print()

In [None]:
# Run the pipeline for the configured protein; `dnas` is passed as the
# name -> sequence mapping of already known sequences. All files are written
# under `res_dir`.
run(res_dir, protein_name, protein_seq, dnas)

Start: SSS
modle: /content/SCG-GELP/tf_0.0001_512_8_3_3_512_0.1_1024_6.pt
device: cuda
resume train: epoch 47, loss 0.9707116135501553
12707396
None
beam search args: [4] [5] [8]
args: 4 5 8


  0%|          | 0/239 [00:00<?, ?it/s]

SCG return dna num: 8
Info: 18 DNA characterization methods


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


  0%|          | 0/10 [00:00<?, ?it/s]

Info: feature embedding finish. num is (10) to scg-gelp-res/SSS_scg_dna_feature_embedding.pickle
/content/DNABERT-2-Finetune-1400
Let's use, cuda!


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


  0%|          | 0/1 [00:00<?, ?it/s]

10
['SSS_scg_0', 'SSS_scg_1', 'SSS_scg_2', 'SSS_scg_3', 'SSS_scg_4', 'SSS_scg_5', 'SSS_scg_6', 'SSS_scg_7', 'SSS-WT', 'SSS-GS']
scg-gelp-res/SSS_scg_dnas_dnabert_embedding.pickle scg-gelp-res/SSS_scg_dna_feature_embedding.pickle
10 10 10
['SSS_scg_0', 'SSS_scg_1', 'SSS_scg_2'] [-1, -1, -1] 1792 [0.2559305  0.13096302 0.89234928 ... 0.21708597 0.39477342 0.44684992]
SVM
LR
MLP


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


/content/SCG-GELP/nesg-cnn-natural-number-512-512-512-1-16-2-False.pt
resume: epoch 60, acc 0.7333333333333333
Info: SSS predict res save to scg-gelp-res/SSS_predict_res.pickle



  model_output.append(F.softmax(pred).cpu().detach().numpy())


In [None]:
### Screening results

In [None]:
# Export the per-model scores and ranks of every sequence to one CSV.
predict_res_pickle = f'{res_dir}/{protein_name}_predict_res.pickle'
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(predict_res_pickle, 'rb') as r:
    res = pickle.load(r)

model_names = list(res.keys())

# Rows follow the ranking order of the model selected by `model_id`.
seq_names = list(res[model_names[model_id]].keys())

print(model_names, seq_names[:3], res[model_names[model_id]][seq_names[0]])

# save all data to csv
gs_acc, gs_rank = 0, 0
res_file = f'{res_dir}/{protein_name}_predict_res.csv'
with open(res_file, 'w', encoding='utf-8') as w:
    head_names_str = ''
    for model_name in model_names:
        head_names_str += f'{model_name} Acc,{model_name} Rank,'

    w.write(f'name,{head_names_str[:-1]},Total Acc,Total Rank,info\n')

    for seq_name in seq_names:
        acc_rank_str = ''
        acc_tol, rank_tol = 0, 0

        for model_name in model_names:
            item = res[model_name][seq_name]
            acc_rank_str += f'{item[0]},{item[1]},'

            acc_tol += item[0]
            rank_tol += item[1]

        # "Total Acc" / "Total Rank" hold the mean over all models. The
        # trailing comma writes an empty `info` field so that every row has as
        # many columns as the header (rows used to be one field short).
        w.write(f'{seq_name},{acc_rank_str[:-1]},{acc_tol/len(model_names)},{rank_tol/len(model_names)},\n')

print(f'Info: {protein_name} predict res save to {res_file}')

['SVM', 'LR', 'MLP', 'CNN-NNF'] ['SSS-WT', 'SSS-GS', 'SSS_scg_7'] [0.09325464500217194, 1]
Info: SSS predict res save to scg-gelp-res/SSS_predict_res.csv
