In [None]:
# Mount Google Drive so files under /content/drive (test arrays, model
# checkpoint) can be read by the cells below. Only works inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import numpy as np
from difflib import SequenceMatcher
import itertools

In [None]:
# Load the held-out test arrays from the mounted Google Drive.
# NOTE(review): presumably X_test holds encoded protein token ids and Y_test the
# matching codon token ids (one row per example) — confirm against the
# preprocessing code that produced these files.
X_test  = np.load("/content/drive/MyDrive/X_test.npy")
Y_test  = np.load("/content/drive/MyDrive/Y_test.npy")

In [None]:
# Codon vocabulary: the 64 nucleotide triplets get ids 0-63 in
# itertools.product order over A, T, C, G; three special tokens follow.
nucleotides = ["A", "T", "C", "G"]
codons = [
    first + second + third
    for first, second, third in itertools.product(nucleotides, repeat=3)
]

codon_vocab = dict(zip(codons, range(len(codons))))
codon_vocab.update({"<PAD>": 64, "<SOS>": 65, "<EOS>": 66})

# Reverse lookup (token id -> codon / special-token string) used when decoding.
id_to_codon = {token_id: token for token, token_id in codon_vocab.items()}

In [None]:
class ProteinToDNATransformer(nn.Module):
    """Encoder-decoder transformer mapping protein token ids to codon logits.

    The encoder side embeds source token ids, the decoder side embeds codon
    token ids, and a linear head projects the decoder output onto the
    67-entry codon vocabulary.

    NOTE(review): ``forward`` passes no attention or padding masks to
    ``nn.Transformer`` and adds no positional encoding to the embeddings.
    The attribute names below define the checkpoint's state-dict keys, so
    they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        hidden_size = 128
        protein_vocab_size = 22  # 20 amino acids + PAD + SOS
        codon_vocab_size = 67  # 64 codons + PAD + SOS + EOS

        self.encoder_embed = nn.Embedding(protein_vocab_size, hidden_size)
        self.decoder_embed = nn.Embedding(codon_vocab_size, hidden_size)
        self.transformer = nn.Transformer(
            d_model=hidden_size,
            nhead=8,
            num_encoder_layers=4,
            num_decoder_layers=4,
            batch_first=True,
        )
        self.fc_out = nn.Linear(hidden_size, codon_vocab_size)

    def forward(self, src, tgt):
        """Return codon-vocabulary logits for every position of ``tgt``.

        Args:
            src: Integer tensor of source token ids, batch-first.
            tgt: Integer tensor of decoder input token ids, batch-first.

        Returns:
            Float tensor whose last dimension has size 67 (one logit per
            codon-vocabulary entry) for each target position.
        """
        encoder_input = self.encoder_embed(src)
        decoder_input = self.decoder_embed(tgt)
        decoder_states = self.transformer(encoder_input, decoder_input)
        return self.fc_out(decoder_states)

In [None]:
# Rebuild the architecture, restore the trained weights onto the CPU, and
# switch to inference mode (disables dropout).
model = ProteinToDNATransformer()

checkpoint_path = "/content/drive/MyDrive/models/best_transformer_with_sos.pt"
cpu_device = torch.device("cpu")
trained_weights = torch.load(checkpoint_path, map_location=cpu_device)

model.load_state_dict(trained_weights)
model.eval()

ProteinToDNATransformer(
  (encoder_embed): Embedding(22, 128)
  (decoder_embed): Embedding(67, 128)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(

In [None]:
def beam_search_decode(model, src, beam_width=3, max_len=1160, start_token=65, pad_token=64, eos_token=66):
    """Decode one target sequence from ``src`` with beam search.

    Every beam starts as ``[start_token]``. At each step each unfinished beam
    is extended with its ``beam_width`` most likely next tokens, and the
    ``beam_width`` highest-scoring candidates (by summed log-probability)
    survive. A beam is finished once it emits ``eos_token`` or ``pad_token``;
    finished beams are carried forward unchanged instead of being extended
    with tokens after the terminator. Decoding stops as soon as every
    surviving beam is finished, or after ``max_len`` steps.

    Args:
        model: Callable as ``model(src, tgt)`` returning logits indexed as
            ``[batch, position, vocab]``; must expose ``parameters()``.
        src: Source tensor with a leading batch dimension of 1.
        beam_width: Number of beams kept per step (must be >= 1).
        max_len: Maximum number of decoding steps.
        start_token: Token id that seeds every beam.
        pad_token: Padding token id; treated as a terminator.
        eos_token: End-of-sequence token id; treated as a terminator.

    Returns:
        1-D tensor of token ids for the best beam. It begins with
        ``start_token`` and, if the beam finished, ends with the terminator.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    device = next(model.parameters()).device
    src = src.to(device)
    stop_tokens = {pad_token, eos_token}
    sequences = [[torch.tensor([start_token], device=device), 0.0]]  # [sequence, score]

    with torch.no_grad():
        for _ in range(max_len):
            all_candidates = []
            expanded_any = False

            for seq, score in sequences:
                # A beam that already emitted a terminator is complete: keep it
                # as-is so it still competes on score but grows no further.
                if seq.numel() > 1 and seq[-1].item() in stop_tokens:
                    all_candidates.append([seq, score])
                    continue

                expanded_any = True
                out = model(src, seq.unsqueeze(0))
                log_probs = torch.nn.functional.log_softmax(out[:, -1, :], dim=-1)

                # Never ask topk for more entries than the vocabulary holds.
                k = min(beam_width, log_probs.size(-1))
                topk_log_probs, topk_ids = torch.topk(log_probs, k)

                for log_prob, token in zip(topk_log_probs[0].tolist(), topk_ids[0].tolist()):
                    new_seq = torch.cat([seq, torch.tensor([token], device=device)])
                    all_candidates.append([new_seq, score + log_prob])

            if not expanded_any:
                # Every surviving beam is finished; further steps change nothing.
                break

            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

        return sequences[0][0]

In [None]:
def codon_match(real, pred):
    """Return the percentage of aligned positions where the codons agree.

    The two sequences are compared position by position over their common
    prefix length, i.e. ``min(len(real), len(pred))`` positions.

    Args:
        real: Sequence of reference codons.
        pred: Sequence of predicted codons.

    Returns:
        Match percentage in the range 0-100. Returns 0.0 when either sequence
        is empty, since there is nothing to compare (this previously raised
        ``ZeroDivisionError``).
    """
    compared = min(len(real), len(pred))
    if compared == 0:
        return 0.0
    matches = sum(1 for a, b in zip(real, pred) if a == b)
    return (matches / compared) * 100


In [None]:
def sequence_similarity(a, b):
    """Return the difflib similarity ratio of ``a`` and ``b`` as a percentage.

    ``autojunk`` is disabled on purpose. With the default heuristic, once the
    second sequence has 200 or more items, any item that makes up more than
    about 1% of it is treated as "popular" and ignored when seeding matches.
    Nucleotide strings use only four distinct characters, so every character
    is popular, and long DNA strings get a near-zero ratio even when they are
    almost identical.

    Args:
        a: First sequence (e.g. a nucleotide string).
        b: Second sequence.

    Returns:
        ``SequenceMatcher.ratio()`` scaled to 0-100.
    """
    return SequenceMatcher(None, a, b, autojunk=False).ratio() * 100

In [None]:
# Number of test examples to decode and score in the evaluation loop.
num_samples = 2
# Running sums of the per-sample metrics, accumulated by the evaluation loop.
total_codon_match, total_similarity = 0, 0


In [None]:
# Score the first `num_samples` test examples: decode each one with beam
# search, then compare the predicted codons against the reference codons.
pad_id, sos_id, eos_id = codon_vocab["<PAD>"], codon_vocab["<SOS>"], codon_vocab["<EOS>"]
special_ids = {pad_id, sos_id, eos_id}

# Reset the accumulators here so re-running this cell does not double-count.
total_codon_match, total_similarity = 0, 0

for i in range(num_samples):
    src = torch.LongTensor(X_test[i]).unsqueeze(0)
    pred_seq = beam_search_decode(model, src, beam_width=3)

    # Drop every special token, not just PAD. The decoded sequence always
    # begins with <SOS>; keeping it shifts the prediction one position against
    # the reference and puts the literal text "<SOS>" into the DNA string.
    real_codons = [
        id_to_codon[token_id]
        for token_id in map(int, Y_test[i])
        if token_id not in special_ids
    ]

    pred_ids = pred_seq.tolist()
    if eos_id in pred_ids:
        # Anything generated after <EOS> is not part of the prediction.
        pred_ids = pred_ids[:pred_ids.index(eos_id)]
    pred_codons = [
        id_to_codon[token_id]
        for token_id in pred_ids
        if token_id in id_to_codon and token_id not in special_ids
    ]

    # An empty side leaves nothing to align, so score it as 0 rather than
    # dividing by a zero-length overlap.
    cm = codon_match(real_codons, pred_codons) if real_codons and pred_codons else 0.0
    sim = sequence_similarity("".join(real_codons), "".join(pred_codons))

    total_codon_match += cm
    total_similarity += sim

    print(f"\nSample {i+1}")
    print("Real:     ", " ".join(real_codons[:15]))
    print("Predicted:", " ".join(pred_codons[:15]))
    print(f"Codon Match: {cm:.2f}% | Sequence Similarity: {sim:.2f}%")


Sample 1
Real:      ATG GCC ATC CCT GCT TTT GGT TTA GGT ACT TTT AGG CTA AAG GAC
Predicted: <SOS>
Codon Match: 0.00% | Sequence Similarity: 0.00%

Sample 2
Real:      ATG CCA GGG AAT CGC CCA CAC TAT GGG CGG TGG CCG CAG CAC GAT
Predicted: <SOS> ATG AAA AAA CGT TCG AAA AAA AAA CGT TCG AAA AAA AAA CGT
Codon Match: 1.80% | Sequence Similarity: 0.00%


In [None]:
dengichko ra!!