In [1]:
import time
from pathlib import Path
from urllib import request
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5EncoderModel, T5Tokenizer

# Device setup
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print("Using device: {}".format(device))

# CNN Model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 32, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Conv2d(32, 20, kernel_size=(7, 1), padding=(3, 0))
        )

    def forward(self, x):
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)  # Reshape for convolution
        Yhat = self.classifier(x)
        Yhat = Yhat.squeeze(dim=-1)
        return Yhat

# Functions for loading models and processing
def get_T5_model():
    print("Loading T5 model...")
    model = T5EncoderModel.from_pretrained("Rostlab/ProstT5_fp16").to(device)
    vocab = T5Tokenizer.from_pretrained("Rostlab/ProstT5_fp16")
    return model.eval(), vocab

def load_predictor():
    print("Loading CNN predictor...")
    model = CNN()
    checkpoint_path = Path("AA_to_3Di_prostt5_cnn_model.pt")
    if not checkpoint_path.exists():
        raise FileNotFoundError("Model checkpoint not found. Please provide the file.")
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state["state_dict"])
    return model.eval().to(device)

# Helper to process sequence
def process_sequence(sequence, t5_model, tokenizer, cnn_model):
    print(f"Processing sequence: {sequence}")
    prefix = "<AA2fold>"
    seq = prefix + ' ' + ' '.join(list(sequence))
    token_encoding = tokenizer(seq, return_tensors="pt").to(device)
    with torch.no_grad():
        embedding_repr = t5_model(**token_encoding)
        embedding = embedding_repr.last_hidden_state[:, 1:, :]  # Skip special token
        prediction = cnn_model(embedding)
        prediction = prediction.argmax(dim=1).squeeze().cpu().numpy()
    return prediction

# Run an example sequence
t5_model, tokenizer = get_T5_model()
cnn_model = load_predictor()

# Example amino acid sequence
example_sequence = "ACDEFGHIKLMNPQRSTVWY"
predicted_3Di = process_sequence(example_sequence, t5_model, tokenizer, cnn_model)

# Map predictions to symbols
ss_mapping = {
    0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H", 7: "I",
    8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q", 14: "R", 15: "S",
    16: "T", 17: "V", 18: "W", 19: "Y"
}
predicted_3Di_symbols = "".join([ss_mapping[p] for p in predicted_3Di])

print(f"Original Sequence: {example_sequence}")
print(f"Predicted 3Di Sequence: {predicted_3Di_symbols}")


Using device: cuda:0
Loading T5 model...


Some weights of the model checkpoint at Rostlab/ProstT5_fp16 were not used when initializing T5EncoderModel: ['decoder.block.7.layer.1.EncDecAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'decoder.block.17.layer.0.SelfAttention.q.weight', 'decoder.block.19.layer.1.EncDecAttention.k.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.22.layer.0.SelfAttention.v.weight', 'decoder.block.13.layer.1.layer_norm.weight', 'decoder.block.21.layer.1.EncDecAttention.q.weight', 'decoder.block.12.layer.2.DenseReluDense.wi.weight', 'decoder.block.16.layer.0.layer_norm.weight', 'decoder.block.20.layer.0.SelfAttention.k.weight', 'decoder.block.16.layer.0.SelfAttention.o.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.1.layer_norm.weight', 'decoder.block.9.layer.2.DenseReluDense.wi.weight', 'decoder.block.10.layer.0.layer_norm.weight', 'decoder.block.22.layer.0.SelfAtte

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/283 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.40k [00:00<?, ?B/s]

Loading CNN predictor...
Processing sequence: ACDEFGHIKLMNPQRSTVWY
Original Sequence: ACDEFGHIKLMNPQRSTVWY
Predicted 3Di Sequence: DDDPVVDDDDDDPVPDDDIDD
