In [None]:
from google.colab import drive
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import matplotlib.pyplot as plt

In [None]:
# Attach Google Drive at /content/drive; later cells read and write under this mount.
drive.mount("/content/drive")


Mounted at /content/drive


In [None]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's RNG before the model is built and before the DataLoader
# shuffles, so weight initialisation and batch order repeat across runs.
SEED = 42
torch.manual_seed(SEED)

# --- Data / vocabulary -----------------------------------------------------
MAX_LEN = 1160
AMINO_VOCAB_SIZE = 22     # 20 amino acids + <PAD> + <SOS>
CODON_VOCAB_SIZE = 67     # 64 codons + <PAD> + <SOS> + <EOS>

# --- Model -----------------------------------------------------------------
EMBED_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 4
# Multi-head attention splits the embedding evenly across heads.
assert EMBED_DIM % NUM_HEADS == 0, "EMBED_DIM must be divisible by NUM_HEADS"

# --- Training --------------------------------------------------------------
BATCH_SIZE = 16
EPOCHS = 300
PATIENCE = 10             # epochs without improvement tolerated before stopping
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Read the four saved .npy arrays (train/test inputs X and targets Y) from Drive.
X_train, Y_train, X_test, Y_test = (
    np.load(npy_path)
    for npy_path in (
        "/content/drive/MyDrive/X_train_t.npy",
        "/content/drive/MyDrive/Y_train_t.npy",
        "/content/drive/MyDrive/X_test_t.npy",
        "/content/drive/MyDrive/Y_test_t.npy",
    )
)

In [None]:
# Pair each input array with its target array as 64-bit integer tensors.
train_data, test_data = (
    TensorDataset(torch.LongTensor(inputs), torch.LongTensor(targets))
    for inputs, targets in ((X_train, Y_train), (X_test, Y_test))
)


In [None]:
# Only the training split is shuffled; the test split keeps its stored order.
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_data, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
class ProteinToDNATransformer(nn.Module):
    """Encoder-decoder Transformer that maps amino-acid token ids to codon logits.

    The encoder consumes amino-acid ids, the decoder consumes (shifted) codon
    ids, and a final linear layer projects decoder states to
    ``CODON_VOCAB_SIZE`` logits per target position.

    Two things ``nn.Transformer`` does not do on its own are handled here:

    * sinusoidal positional encodings are added to both embeddings, because
      attention alone is insensitive to token order;
    * a causal mask is applied to decoder self-attention, so position ``i``
      cannot look at target tokens after ``i`` during teacher forcing.

    Neither adds parameters or buffers, so ``state_dict`` keys are unchanged.
    Weights trained without these two pieces should be retrained, though,
    since the forward computation differs.
    """

    AMINO_PAD_ID = 21  # amino <PAD>
    CODON_PAD_ID = 65  # codon <PAD>

    def __init__(self):
        super().__init__()
        self.encoder_embed = nn.Embedding(AMINO_VOCAB_SIZE, EMBED_DIM)
        self.decoder_embed = nn.Embedding(CODON_VOCAB_SIZE, EMBED_DIM)

        self.transformer = nn.Transformer(
            d_model=EMBED_DIM,
            nhead=NUM_HEADS,
            num_encoder_layers=NUM_LAYERS,
            num_decoder_layers=NUM_LAYERS,
            batch_first=True,
        )

        # Project with EMBED_DIM (not a literal) so the head follows the config.
        self.fc_out = nn.Linear(EMBED_DIM, CODON_VOCAB_SIZE)

    def generate_padding_mask(self, seq, pad_token):
        """Return a bool tensor shaped like ``seq``; True where ``seq == pad_token``."""
        return (seq == pad_token)

    @staticmethod
    def generate_causal_mask(size, device=None):
        """Return a ``(size, size)`` bool mask; True marks disallowed attention.

        Entry ``[i, j]`` is True when ``j > i``, i.e. query position ``i`` may
        not attend to a later key position ``j``.
        """
        idx = torch.arange(size, device=device)
        return idx.unsqueeze(0) > idx.unsqueeze(1)

    @staticmethod
    def positional_encoding(length, dim, device=None, dtype=torch.float32):
        """Return a ``(length, dim)`` sinusoidal positional-encoding table.

        Even columns hold ``sin(pos / 10000^(2k/dim))`` and odd columns the
        matching ``cos``. It is computed on the fly for the requested length,
        so no maximum sequence length has to be fixed in advance.
        """
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        even_idx = torch.arange(0, dim, 2, device=device, dtype=torch.float32)
        inv_freq = 10000.0 ** (-even_idx / dim)
        angles = positions * inv_freq  # (length, ceil(dim / 2))

        table = torch.zeros(length, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]  # slice handles odd dim
        return table.to(dtype)

    def forward(self, src, tgt):
        """Compute codon logits for every target position.

        Args:
            src: integer tensor of amino-acid ids, batch-first ``(batch, src_len)``.
            tgt: integer tensor of decoder-input codon ids, ``(batch, tgt_len)``.

        Returns:
            Float tensor ``(batch, tgt_len, CODON_VOCAB_SIZE)`` of unnormalised logits.
        """
        src_pad_mask = self.generate_padding_mask(src, pad_token=self.AMINO_PAD_ID)
        tgt_pad_mask = self.generate_padding_mask(tgt, pad_token=self.CODON_PAD_ID)
        # Bool causal mask, same dtype as the padding masks.
        causal_mask = self.generate_causal_mask(tgt.size(1), device=tgt.device)

        src_embed = self.encoder_embed(src)
        tgt_embed = self.decoder_embed(tgt)

        # Inject order information; broadcast over the batch dimension.
        src_embed = src_embed + self.positional_encoding(
            src.size(1), src_embed.size(-1), device=src_embed.device, dtype=src_embed.dtype
        )
        tgt_embed = tgt_embed + self.positional_encoding(
            tgt.size(1), tgt_embed.size(-1), device=tgt_embed.device, dtype=tgt_embed.dtype
        )

        out = self.transformer(
            src_embed,
            tgt_embed,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=src_pad_mask,
        )
        return self.fc_out(out)

In [None]:
# Instantiate the network and move its parameters onto the selected device.
model = ProteinToDNATransformer()
model.to(DEVICE)

# Target positions holding id 65 (<PAD>) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=65)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Bookkeeping for the training loop: loss history and early-stopping state.
losses = list()
best_loss = float("inf")
patience_counter = 0

# Checkpoints are written into this folder; create it if it does not exist yet.
os.makedirs("/content/drive/MyDrive/models", exist_ok=True)

In [None]:
def _run_epoch(loader, train):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``train`` is True the model is put in training mode and the optimizer
    is stepped after every batch. Otherwise the model is put in eval mode and
    gradients are disabled, so the pass only measures the loss.
    """
    model.train(train)
    running_loss = 0.0
    with torch.set_grad_enabled(train):
        for src, tgt in loader:
            src, tgt = src.to(DEVICE), tgt.to(DEVICE)

            tgt_input = tgt[:, :-1]  # input starts with <SOS>
            tgt_output = tgt[:, 1:]  # output ends with <EOS>

            logits = model(src, tgt_input)
            # reshape (not view) also works if the tensor is non-contiguous.
            loss = criterion(logits.reshape(-1, CODON_VOCAB_SIZE), tgt_output.reshape(-1))

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
    return running_loss / len(loader)


# Held-out loss per epoch, kept alongside the training-loss history in `losses`.
val_losses = []

for epoch in range(EPOCHS):
    avg_loss = _run_epoch(train_loader, train=True)
    val_loss = _run_epoch(test_loader, train=False)
    losses.append(avg_loss)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {avg_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Early stopping on held-out loss: the training loss keeps creeping down
    # even while the model overfits, so it cannot signal when to stop.
    # NOTE(review): test_loader doubles as the validation set here. Carve out a
    # separate validation split if test metrics must stay unbiased.
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), "/content/drive/MyDrive/models/best_transformer_with_sos.pt")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break


Epoch 1/300 - Loss: 2.0136
Epoch 2/300 - Loss: 1.3622
Epoch 3/300 - Loss: 1.2088
Epoch 4/300 - Loss: 1.1822
Epoch 5/300 - Loss: 1.1706
Epoch 6/300 - Loss: 1.1634
Epoch 7/300 - Loss: 1.1580
Epoch 8/300 - Loss: 1.1532
Epoch 9/300 - Loss: 1.1479
Epoch 10/300 - Loss: 1.1440
Epoch 11/300 - Loss: 1.1409
Epoch 12/300 - Loss: 1.1387
Epoch 13/300 - Loss: 1.1369
Epoch 14/300 - Loss: 1.1350
Epoch 15/300 - Loss: 1.1332
Epoch 16/300 - Loss: 1.1315
Epoch 17/300 - Loss: 1.1301
Epoch 18/300 - Loss: 1.1293
Epoch 19/300 - Loss: 1.1278
Epoch 20/300 - Loss: 1.1267
Epoch 21/300 - Loss: 1.1258
Epoch 22/300 - Loss: 1.1252
Epoch 23/300 - Loss: 1.1243
Epoch 24/300 - Loss: 1.1235
Epoch 25/300 - Loss: 1.1230
Epoch 26/300 - Loss: 1.1224
Epoch 27/300 - Loss: 1.1217
Epoch 28/300 - Loss: 1.1213
Epoch 29/300 - Loss: 1.1207
Epoch 30/300 - Loss: 1.1200
Epoch 31/300 - Loss: 1.1196
Epoch 32/300 - Loss: 1.1191
Epoch 33/300 - Loss: 1.1189
Epoch 34/300 - Loss: 1.1182
Epoch 35/300 - Loss: 1.1175
Epoch 36/300 - Loss: 1.1172
E