In [103]:
import pandas as pd
import numpy as np
from Bio import SeqIO
import sys
import json
import torch
import re
import os
import h5py
import lightning.pytorch as pl
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
SEED = 42
warnings.simplefilter('ignore')  # silence Python warnings in the notebook output
np.random.seed(SEED)  # seeds NumPy's global RNG only

### Preparation

In [104]:
def retrieve_json(path):
    """Open the file at ``path`` and return its parsed JSON content."""
    with open(path) as handle:
        return json.load(handle)


In [105]:
# Class encoding for the 11-class model. Keys are output-column indices as
# strings; element [1] of each value is the label that `interpreter` prints.
# Load a different encoding file to match whichever checkpoint is used for inference.
class_encode = retrieve_json(path="label_encode/class_encode-11.json")
N_METALS = len(class_encode)  # number of output classes / logit columns

In [106]:
# Protein language model (Ankh large) used to embed each sequence fragment.
import ankh

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model, tokenizer = ankh.load_large_model()
model.eval()  # inference only (e.g. disables dropout)
model.to(device=device)  # last expression: the notebook displays the module repr

Some weights of the model checkpoint at ElnaggarLab/ankh-large were not used when initializing T5EncoderModel: ['decoder.block.19.layer.0.SelfAttention.k.weight', 'decoder.block.22.layer.0.layer_norm.weight', 'decoder.block.12.layer.0.SelfAttention.v.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.0.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'decoder.block.20.layer.0.SelfAttention.q.weight', 'decoder.block.13.layer.0.SelfAttention.q.weight', 'decoder.block.17.layer.1.EncDecAttention.k.weight', 'decoder.block.12.layer.0.SelfAttention.k.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.19.layer.2.DenseReluDense.wo.weight', 'decoder.block.17.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.18.layer.2.DenseReluDense.wo.weight', 'decoder.block.17.layer.0.layer_norm.weight', 'decoder.block.9.layer.2.lay

T5EncoderModel(
  (shared): Embedding(144, 1536)
  (encoder): T5Stack(
    (embed_tokens): Embedding(144, 1536)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1536, out_features=1024, bias=False)
              (k): Linear(in_features=1536, out_features=1024, bias=False)
              (v): Linear(in_features=1536, out_features=1024, bias=False)
              (o): Linear(in_features=1024, out_features=1536, bias=False)
              (relative_attention_bias): Embedding(64, 16)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=1536, out_features=3840, bias=False)
              (wi_1): Linear(in_features=1536, out_features=3840, bias=False)
              (wo): Lin

In [107]:


class Self_Attention(nn.Module):
    """Parameter-free multi-head dot-product attention.

    The inputs are used directly as queries/keys/values (there are no learned
    projections), and the scores are NOT divided by sqrt(head_size). Both
    choices are kept as-is: changing either would alter the outputs of models
    trained with this module.

    Args:
        num_hidden: width of the last axis of q/k/v; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        weight_matrix: stored on the instance but not used by this class.

    Raises:
        ValueError: if ``num_hidden`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_hidden, num_heads=4, weight_matrix=False):
        super().__init__()
        if num_hidden % num_heads != 0:
            # Otherwise transpose_for_scores cannot split the last axis into heads.
            raise ValueError(
                f"num_hidden ({num_hidden}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.attention_head_size = num_hidden // num_heads
        self.all_head_size = self.num_heads * self.attention_head_size
        self.weight_matrix = weight_matrix

    def transpose_for_scores(self, x):
        """Split heads: (batch, seq_len, num_hidden) -> (batch, heads, seq_len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_heads,
                                       self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v`` independently per head.

        Args:
            q, k, v: tensors of shape (batch, seq_len, num_hidden).
            mask: optional (batch, seq_len) key mask, 1/True for positions that
                may be attended to and 0/False for positions to ignore. Float,
                int and bool masks are accepted.

        Returns:
            Tensor of shape (batch, seq_len, num_hidden) with the heads merged back.
        """
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        attention_scores = torch.matmul(q, k.transpose(-1, -2))

        if mask is not None:
            if mask.dtype == torch.bool:
                # `1.0 - bool_tensor` is not supported by torch; use 0/1 values instead.
                mask = mask.to(attention_scores.dtype)
            # Large negative bias on masked keys, broadcast over heads and query positions.
            attention_mask = (1.0 - mask) * -10000
            attention_scores = attention_scores + \
                attention_mask.unsqueeze(1).unsqueeze(1)

        attention_probs = torch.softmax(attention_scores, dim=-1)

        outputs = torch.matmul(attention_probs, v)

        # Merge heads: (batch, heads, seq_len, head_size) -> (batch, seq_len, num_hidden).
        outputs = outputs.permute(0, 2, 1, 3).contiguous()
        new_output_shape = outputs.size()[:-2] + (self.all_head_size,)
        outputs = outputs.view(*new_output_shape)
        return outputs

class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: Linear -> LeakyReLU -> Linear.

    Args:
        num_hidden: input and output width.
        num_ff: width of the intermediate layer.
    """

    def __init__(self, num_hidden, num_ff):
        super().__init__()
        # Expand to num_ff, then project back to num_hidden.
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

    def forward(self, h_V):
        """Return W_out(leaky_relu(W_in(h_V))); the last axis keeps width num_hidden."""
        return self.W_out(F.leaky_relu(self.W_in(h_V)))


class TransformerLayer(nn.Module):
    """Post-norm encoder layer built from Self_Attention and PositionWiseFeedForward.

    Each sublayer output goes through dropout, is added to its input (residual),
    and then goes through a LayerNorm. When a mask is given, the result is
    multiplied by it, zeroing positions where the mask is 0.

    Args:
        num_hidden: model width.
        num_heads: attention heads.
        dropout: dropout probability applied to both sublayer outputs.
    """

    def __init__(self, num_hidden=64, num_heads=4, dropout=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([
            nn.LayerNorm(num_hidden, eps=1e-6),
            nn.LayerNorm(num_hidden, eps=1e-6),
        ])
        self.attention = Self_Attention(num_hidden, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, mask=None):
        attn_norm, ff_norm = self.norm
        h_V = attn_norm(h_V + self.dropout(self.attention(h_V, h_V, h_V, mask)))
        h_V = ff_norm(h_V + self.dropout(self.dense(h_V)))
        if mask is None:
            return h_V
        # Broadcast the per-position mask over the feature axis.
        return mask.unsqueeze(-1) * h_V


class MetalBPredictor(nn.Module):
    """Per-position multi-label classifier over protein language-model features.

    Layout:
    - input_block: LayerNorm -> Linear -> LeakyReLU
    - hidden_block: LayerNorm -> Dropout -> Linear -> LeakyReLU -> LayerNorm
    - ``num_encoder_layers`` TransformerLayer blocks
    - dense -> dropout -> out_proj (one logit per class)

    Args:
        feature_dim: width of the last axis of the input features.
        hidden_dim: width used inside the model.
        num_encoder_layers: number of TransformerLayer blocks.
        num_heads: attention heads per TransformerLayer (should divide hidden_dim).
        dropout: dropout probability used throughout.
        num_metals: number of output classes. When None, the notebook-level
            ``N_METALS`` (``len(class_encode)``) is used, so existing callers and
            checkpoints behave exactly as before.
    """

    def __init__(self, feature_dim, hidden_dim=64, num_encoder_layers=2, num_heads=4, dropout=0.2,
                 num_metals=None):
        super(MetalBPredictor, self).__init__()
        if num_metals is None:
            num_metals = N_METALS
        self.input_block = nn.Sequential(
            nn.LayerNorm(feature_dim, eps=1e-6),
            nn.Linear(feature_dim, hidden_dim),
            nn.LeakyReLU(),
        )

        self.hidden_block = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-6),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.LayerNorm(hidden_dim, eps=1e-6),
        )
        self.encoder_layers = nn.ModuleList([
            TransformerLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.dense = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim, num_metals)
        # Xavier-uniform init for every weight matrix; 1-D parameters (biases,
        # LayerNorm scales/shifts) keep their default initialization.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, protein_feat, mask):
        """Compute per-position class logits.

        Args:
            protein_feat: (batch, seq_len, feature_dim) features.
            mask: presumably a (batch, seq_len) 1/0 mask of valid positions, or
                None; it is passed unchanged to every TransformerLayer.

        Returns:
            Logits of shape (batch * seq_len, num_metals); the batch and position
            axes are flattened together.
        """
        h_V = self.input_block(protein_feat)
        h_V = self.hidden_block(h_V)

        for layer in self.encoder_layers:
            h_V = layer(h_V, mask)

        x = self.dense(h_V)
        x = self.dropout(x)
        logits = self.out_proj(x)
        logits = torch.flatten(logits, end_dim=1)
        return logits


In [108]:
class TransformerModel(pl.LightningModule):
    """LightningModule wrapping a MetalBPredictor for per-residue metal-binding prediction.

    Only ``__init__`` and ``forward`` are defined in this cell. Apart from the
    encoder, the attributes set in ``__init__`` are not read by any code shown
    here (presumably they served the training/validation loop -- TODO confirm).
    """
    def __init__(self, train_pos, train_neg, feature_dim, hidden_dim=64, num_encoder_layers=2, num_heads=4, dropout=0.2, lr=1e-3, label_weight=[0.228, 5.802], batch_size=32, thres_tune=False):
        """Build the encoder and record the constructor arguments as hyperparameters.

        Args:
            train_pos, train_neg: stored as attributes; not used in this cell.
            feature_dim, hidden_dim, num_encoder_layers, num_heads, dropout:
                forwarded positionally to MetalBPredictor.
            lr: stored as ``self.learning_rate``.
            label_weight: stored as ``self.label_weight``. NOTE(review): this is
                a mutable list default stored without copying, so in-place edits
                to ``self.label_weight`` would leak into the default for later
                instances.
            batch_size: stored as ``self.batch_size``.
            thres_tune: stored as ``self.thres_tune``.
        """
        super().__init__()
        self.encoder = MetalBPredictor(
            feature_dim, hidden_dim, num_encoder_layers, num_heads, dropout)
        # Lightning: records the __init__ arguments in self.hparams (and in saved
        # checkpoints), which is what lets load_from_checkpoint rebuild the model
        # from just a checkpoint path.
        self.save_hyperparameters()

        self.val_loss = 0
        self.test_loss = 0
        self.learning_rate = lr
        self.label_weight = label_weight
        self.batch_size = batch_size
        self.train_pos = train_pos
        self.train_neg = train_neg
        # Accumulators for labels/predictions; not filled by any code in this cell.
        self.val_y = []
        self.val_pred = []
        self.test_y = []
        self.test_pred = []
        self.thres_tune = thres_tune

    def forward(self, x, mask):
        """Return the encoder's logits with every size-1 dimension squeezed out.

        NOTE(review): torch.squeeze removes ALL size-1 dims. A single position
        (batch * seq_len == 1) therefore yields shape (N_METALS,) instead of
        (1, N_METALS), and N_METALS == 1 drops the class axis.
        """
        x = self.encoder(x, mask)
        return torch.squeeze(x)


In [110]:
# Checkpoint of the 11-class model. Built with os.path.join rather than a
# backslash-separated literal so the path also resolves on Linux/macOS.
DEFAULT_CHECKPOINT = os.path.join(
    "models", "TFE-11", "checkpoints",
    "epoch=79-val_loss=0.046311-MCC=0.593-AUPR=0.614.ckpt",
)

# Predictors already loaded, keyed by checkpoint path. inference() is called once
# per sequence chunk, so without this cache the checkpoint would be re-read from
# disk for every chunk.
_PREDICTOR_CACHE = {}


def _embed_fragment(seq_fragment):
    """Embed one sequence fragment with the global Ankh ``model``/``tokenizer`` on ``device``.

    Returns a tensor of shape (1, T - 1, hidden), where T is the tokenized length.
    Presumably T - 1 == len(seq_fragment), i.e. one token per residue plus one
    trailing special token -- TODO confirm against the Ankh tokenizer.
    """
    # Map the residue codes U, Z, O and B to X before tokenizing.
    seq_fragment = re.sub(r"[UZOB]", "X", seq_fragment)
    encoded = tokenizer.batch_encode_plus([list(seq_fragment)],
                                          add_special_tokens=True,
                                          padding=True,
                                          is_split_into_words=True,
                                          return_tensors="pt")
    with torch.no_grad():
        hidden = model(input_ids=encoded['input_ids'].to(
            device=device)).last_hidden_state
    # Drop the final token position (assumed to be the special token appended by
    # add_special_tokens=True) and restore the batch axis.
    return hidden[0][0:-1].unsqueeze(0)


def _load_predictor(checkpoint_path):
    """Return the TransformerModel at ``checkpoint_path``, on ``device`` and in eval mode.

    The checkpoint is loaded from disk only on first use.
    """
    predictor = _PREDICTOR_CACHE.get(checkpoint_path)
    if predictor is None:
        predictor = TransformerModel.load_from_checkpoint(
            checkpoint_path=checkpoint_path, map_location=None)
        predictor = predictor.to(device)
        predictor.eval()
        _PREDICTOR_CACHE[checkpoint_path] = predictor
    return predictor


def inference(seq_fragment, thres, checkpoint_path=DEFAULT_CHECKPOINT):
    """Predict binary metal-binding labels for one sequence fragment.

    Args:
        seq_fragment: amino-acid string (one-letter codes).
        thres: per-class probability thresholds. ``thres[i]`` applies to output
            column ``i``. Must contain at least ``N_METALS`` values; extra
            values are ignored.
        checkpoint_path: predictor checkpoint to use (defaults to the 11-class model).

    Returns:
        int8 array of shape (n_positions, N_METALS): 1 where sigmoid(logit) is
        strictly greater than that column's threshold, else 0.

    Raises:
        ValueError: if ``thres`` has fewer than ``N_METALS`` values.
    """
    embedding = _embed_fragment(seq_fragment)
    predictor = _load_predictor(checkpoint_path)

    with torch.no_grad():
        y_hat = predictor(embedding, None)

    # The predictor's forward() squeezes its output, so a single-position fragment
    # (e.g. a 1-residue final chunk) comes back as shape (N_METALS,). The reshape
    # restores the (positions, classes) layout that the comparison below expects.
    prob = torch.sigmoid(y_hat).detach().cpu().numpy().reshape(-1, N_METALS)

    # Thresholds are cast to prob's dtype so the comparison runs at prob's
    # precision, as an element-wise comparison with Python floats would.
    thresholds = np.asarray(thres, dtype=prob.dtype).reshape(-1)
    if thresholds.size < N_METALS:
        raise ValueError(
            f"thres must provide at least {N_METALS} thresholds, got {thresholds.size}")
    return (prob > thresholds[:N_METALS]).astype(np.int8)

In [111]:
def predict_metal_binding_site(seq, thres, chunk_size=512):
    """Predict per-residue metal-binding labels for a full-length sequence.

    The sequence is cut into consecutive, non-overlapping chunks of at most
    ``chunk_size`` residues. Each chunk is run through ``inference``
    independently, so no context is shared across chunk boundaries. The
    per-chunk results are stacked in sequence order.

    Args:
        seq: amino-acid sequence string.
        thres: per-class thresholds, forwarded to ``inference``.
        chunk_size: maximum residues per chunk. The default is 512, presumably
            the fragment length the model was built for (TODO confirm).

    Returns:
        0/1 array with the rows of every chunk's ``inference`` output stacked in order.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    # range() yields nothing for an empty string. Fall back to a single (empty)
    # chunk so np.vstack still receives one array.
    chunks = [seq[start:start + chunk_size]
              for start in range(0, len(seq), chunk_size)] or [seq]
    return np.vstack([inference(chunk, thres) for chunk in chunks])

In [123]:
def interpreter(outputs):
    """Print one line for each entry equal to 1 in a 2-D 0/1 prediction matrix.

    Row indices are reported 1-based as residue positions. Column indices are
    looked up in ``class_encode`` (as string keys), and element [1] of that
    entry is printed as the label.
    """
    hits = np.where(outputs == 1)
    for residue_idx, class_idx in zip(hits[0], hits[1]):
        label = class_encode[str(class_idx)][1]
        print(f"{label} at position {residue_idx + 1}")

### Inference

In [124]:


# Per-class decision thresholds: thres[i] is applied to output column i, whose
# label is class_encode[str(i)] (11 entries for the 11-class encoding).
thres = [0.99, 0.950, 0.980, 0.990, 0.980, 0.990, 0.990, 0.950, 0.730, 0.980, 0.990]
# Fail fast if the thresholds do not match the class encoding / checkpoint in use.
assert len(thres) == N_METALS, f"expected {N_METALS} thresholds, got {len(thres)}"

# Query protein sequence (one-letter amino-acid codes).
seq = "MMKFSVIVPTYNSEKYITELLNSLAKQDFPKTEFEVVVVDDCSTDQTLQIVEKYRNKLNLKVSQLETNSGGPGKPRNVALKQAEGEFVLFVDSDDYINKETLKDAAAFIDEHHSDVLLIKMKGVNGRGVPQSMFKETAPEVTLLNSRIIYTLSPTKIYRTALLKDNDIYFPEELKSAEDQLFTMKAYLNANRISVLSDKAYYYATKREGEHMSSAYVSPEDFYEVMRLIAVEILNADLEEAHKDQILAEFLNRHFSFSRTNGFSLKVKLEEQPQWINALGDFIQAVPERVDALVMSKLRPLLHYARAKDIDNYRTVEESYRQGQYYRFDIVDGKLNIQFNEGEPYFEGIDIAKPKVKMTAFKFDNHKIVTELTLNEFMIGEGHYDVRLKLHSRNKKHTMYVPLSVNANKQYRFNIMLEDIKAYLPKEKIWDVFLEVQIGTEVFEVRVGNQRNKYAYTAETSALIHLNNDFYRLTPYFTKDFNNISLYFTAITLTDSISMKLKGKNKIILTGLDRGYVFEEGMASVVLKDDMIMGMLSQTSENEVEILLSKDIKKRDFKNIVKLNTAHMTYSLK"

In [125]:
# Run the full-length prediction, then print every predicted binding site.
outputs = predict_metal_binding_site(seq=seq, thres=thres)
interpreter(outputs=outputs)

Mn(2+) at position 94
a divalent metal cation at position 94
a metal cation at position 94
