In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn

from transformers import EsmTokenizer, EsmModel

In [2]:
from utils import model

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load LLAMP model

In [3]:
# Instantiate LLAMP (hidden_feat=256, CLS-token pooling). The trained weights
# are loaded into this freshly built model in the next cell.
LLAMP_model = model.LLAMP(
    hidden_feat=256,
    pooling='CLS',
)

Some weights of the model checkpoint at Daehun/peptide_tuned_ESM-2 were not used when initializing EsmModel: ['lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'esm.contact_head.regression.bias', 'esm.contact_head.regression.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at Daehun/peptide_tuned_ESM-2 and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on

In [4]:
# Load the trained LLAMP weights. map_location remaps stored tensors onto the
# selected device, so a checkpoint saved from GPU tensors also loads on a
# CPU-only machine.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load('model_weight/LLAMP.pth', map_location=device)

# strict=False tolerates key mismatches between the checkpoint and the model.
# Report any mismatch instead of discarding it silently, so that a wrong or
# partial checkpoint gets noticed.
incompatible = LLAMP_model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Missing keys: {incompatible.missing_keys}")
    print(f"Unexpected keys: {incompatible.unexpected_keys}")

LLAMP_model.to(device)
# Everything below is inference. eval() switches dropout and other
# train-only layers to inference behaviour, which keeps predictions deterministic.
LLAMP_model.eval()

LLAMP(
  (bert): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0): EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affine=True)
  

# Load Dataset

In [5]:
from utils import dataset

In [6]:
# Held-out DBAASP test split. Read only the columns used downstream: this drops
# the leftover saved-index column (shown as "Unnamed: 0" when read naively), and
# pandas raises immediately if any required column is absent.
test_df = pd.read_csv('data/DBAASP/test.csv', usecols=['sequence', 'label', 'species'])
test_df

Unnamed: 0,sequence,label,species
0,AGRQTIAKYLRREIRKRGRKWVIAW,0.000000,Escherichia coli
1,KIAGKIAAIAGKIAKIAGAIAKIAGKIA,0.482921,Escherichia coli
2,FLPGLECVSGKIVPTVFCAITRIC,0.766642,Escherichia coli
3,IRPIIRPIIRPIIRPI,1.204120,Escherichia coli
4,TPFLLVGTQIDLR,0.380211,Escherichia coli
...,...,...,...
3736,GRLRNLIEKAGQNIRGKIQGIGRRIKDILKNLQPRPQV,0.634700,Staphylococcus haemolyticus
3737,FQRYFHRYARFLAKIWKG,1.000000,Staphylococcus haemolyticus
3738,YKRWKKWRSKAKKIL,0.296820,Staphylococcus haemolyticus
3739,IGRHFKRRNSIWGICWF,1.170262,Staphylococcus haemolyticus


In [7]:
# Per-species genome features. map_location='cpu' keeps any stored tensors on
# the CPU; batches are moved to `device` inside the evaluation loop.
# NOTE: torch.load unpickles the file; only load trusted files.
genome_feat_dict = torch.load('data/Genomic_featrues/genome_features.pt', map_location='cpu')

# Fail early, with a readable message, if a test species has no genome features.
missing_species = set(test_df['species']) - set(genome_feat_dict)
assert not missing_species, f"No genome features for species: {sorted(missing_species)}"

seqs = test_df['sequence'].astype(str).tolist()

genome_feats = dataset.get_features(test_df['species'], genome_feat_dict)

In [8]:
# Batch the test set: sequences, per-row genome features and float32 labels,
# 32 examples per batch.
test_dataloader = dataset.data_loader(
    seqs,
    genome_feats,
    torch.as_tensor(test_df['label'], dtype=torch.float32),
    BATCH_SIZE=32,
)



# Check test performance

In [9]:
from utils import utils

In [10]:
# Run the whole test set through the model and collect predictions next to the
# ground-truth labels, in the order the dataloader yields them.

# Make sure dropout and other train-only behaviour is off, even if this cell is
# run without the weight-loading cell having put the model in eval mode.
LLAMP_model.eval()

label_batches = []
pred_batches = []
with torch.no_grad():
    for input_ids, attention_mask, genome_feat, label in test_dataloader:
        output = LLAMP_model(
            input_ids.to(device),
            attention_mask.to(device),
            genome_feat.to(device),
        )
        # Move each batch back to the CPU right away, so device memory does not
        # grow with the size of the test set.
        label_batches.append(label.cpu())
        pred_batches.append(output.cpu())

# Final NumPy arrays consumed by the metrics cell.
labels = torch.cat(label_batches).numpy()
preds = torch.cat(pred_batches).numpy()

In [11]:
# Regression metrics on the held-out test set (predictions first, targets second).
utils.compute_metrics(preds, labels)

{'R_squre': 0.5364365222074534,
 'mae': 0.37941855,
 'mse': 0.27161157,
 'rmse': 0.52116364,
 'Pearson Q': 0.7346180686280134}

# Inference for a single sequence against one of the 45 species

In [12]:
# Species with available genome features. The `species` used for single-sequence
# inference must be one of these keys, because it is looked up in genome_feat_dict.
genome_feat_dict.keys()

dict_keys(['Escherichia coli', 'Staphylococcus aureus', 'Pseudomonas aeruginosa', 'Salmonella enterica', 'Bacillus subtilis', 'Klebsiella pneumoniae', 'Staphylococcus epidermidis', 'Acinetobacter baumannii', 'Enterococcus faecalis', 'Micrococcus luteus', 'Listeria monocytogenes', 'Bacillus cereus', 'Enterococcus faecium', 'Enterobacter cloacae', 'Streptococcus mutans', 'Streptococcus pyogenes', 'Bacillus megaterium', 'Pseudomonas syringae', 'Streptococcus pneumoniae', 'Proteus mirabilis', 'Klebsiella aerogenes', 'Stenotrophomonas maltophilia', 'Serratia marcescens', 'Shigella dysenteriae', 'Proteus vulgaris', 'Streptococcus agalactiae', 'Aeromonas salmonicida', 'Vibrio parahaemolyticus', 'Listeria innocua', 'Aeromonas hydrophila', 'Pasteurella multocida', 'Klebsiella oxytoca', 'Lactococcus lactis', 'Pectobacterium carotovorum', 'Staphylococcus haemolyticus', 'Vibrio alginolyticus', 'Pseudomonas putida', 'Shigella flexneri', 'Staphylococcus xylosus', 'Vibrio anguillarum', 'Corynebacteri

In [13]:
# Example query: one peptide sequence and the target species to predict MIC for.
# The species must be a key of genome_feat_dict.
sequence, species = 'SSSSSSAAAAARRRRRRRGGGGGGGG', 'Escherichia coli'

In [14]:
# Tokenizer for the same peptide-tuned ESM-2 checkpoint that LLAMP's backbone is
# initialised from (see the checkpoint name in the weight-loading log above).
tokenizer = EsmTokenizer.from_pretrained(
    pretrained_model_name_or_path='Daehun/peptide_tuned_ESM-2',
)

def get_inputs(seq, species, genome_features=None, tok=None):
    """Build batch-of-one LLAMP inputs for a single peptide and target species.

    Args:
        seq: Peptide amino-acid sequence string.
        species: Species name. Must be a key of the genome-feature mapping.
        genome_features: Mapping from species name to stored genome features.
            Defaults to the notebook-level ``genome_feat_dict``.
        tok: Tokenizer exposing ``batch_encode_plus``. Defaults to the
            notebook-level ``tokenizer``.

    Returns:
        Tuple ``(input_id, attention_mask, genome_feat)`` of tensors, each with
        a leading batch dimension of 1. ``genome_feat`` is float32.

    Raises:
        KeyError: if ``species`` has no genome features. The message lists the
            species that are available.
    """
    if genome_features is None:
        genome_features = genome_feat_dict
    if tok is None:
        tok = tokenizer

    if species not in genome_features:
        known = ', '.join(sorted(genome_features))
        raise KeyError(f"No genome features for species {species!r}. Known species: {known}")

    # Encode a one-element batch so the tensors get a leading batch dimension.
    inputs = tok.batch_encode_plus([seq])
    input_id = torch.tensor(inputs['input_ids'])
    attention_mask = torch.tensor(inputs['attention_mask'])

    # Only element [0] of the stored entry is used. Presumably this is the
    # per-species feature vector; confirm against how the dict was built.
    genome_feat = torch.as_tensor(genome_features[species][0], dtype=torch.float32).unsqueeze(0)

    return input_id, attention_mask, genome_feat

# Build the single-example inputs and place every tensor on the model's device.
input_id, attention_mask, genome_feat = (
    tensor.to(device) for tensor in get_inputs(sequence, species)
)

In [15]:
# One forward pass without autograd bookkeeping; return the prediction as NumPy.
with torch.no_grad():
    output = LLAMP_model(input_id, attention_mask, genome_feat).cpu().numpy()

In [16]:
# `output` is on a log10 scale; raising 10 to it gives the MIC in uM.
print(
    f"MIC (log uM) about {species} : {output}",
    f"MIC (uM) about {species} : {10**output}",
    sep="\n",
)

MIC (log uM) about Escherichia coli : 1.8840848207473755
MIC (uM) about Escherichia coli : 76.57461478816683
