<a href="https://colab.research.google.com/github/phodmin/CodonConcierge/blob/main/ChatRNA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# The control panel
# Entry-point cell: prepares the data file / Python packages, then runs the
# util3 definitions. Both functions are defined in later cells of this
# notebook, so those cells have to be executed before this one (running this
# cell first raises NameError).
# NOTE(review): run_util3() only creates function-local definitions and
# returns nothing, so the helpers it defines are not visible to other cells
# after the call -- confirm how they are meant to be exported.
setup_colab_environment()
run_util3()




NameError: ignored

In [None]:
import os
import subprocess

# Vocabulary sizes used when building the models.
src_vocab_size = 22  # Amino acids (20 + '*' for stop + 'X')
tgt_vocab_size = 65  # Codons (64 + 1 'X' for padded, i.e. unknown codons)

# Transformer hyper-parameter presets, keyed by preset name.
MODEL_CONFIGS = {
    "small": dict(d_model=128, num_heads=4, num_layers=2, d_ff=512, dropout=0.1),
    "medium": dict(d_model=256, num_heads=8, num_layers=4, d_ff=1024, dropout=0.1),
    "large": dict(d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1),
    "wide": dict(d_model=1024, num_heads=16, num_layers=2, d_ff=4096, dropout=0.1),
    "shallow_multihead": dict(d_model=256, num_heads=16, num_layers=2, d_ff=1024, dropout=0.1),
    "deep_narrow": dict(d_model=128, num_heads=4, num_layers=10, d_ff=512, dropout=0.1),
}

# Diffusion-model hyper-parameter presets, keyed by preset name.
DIFFUSION_CONFIGS = {
    "small": dict(hidden_units=128, num_layers=2, dropout=0.1, num_diffusion_steps=50),
    "diffusion_medium": dict(hidden_units=256, num_layers=4, dropout=0.1, num_diffusion_steps=100),
    "diffusion_extended_steps": dict(hidden_units=256, num_layers=4, dropout=0.1, num_diffusion_steps=200),
    "diffusion_deep": dict(hidden_units=128, num_layers=8, dropout=0.1, num_diffusion_steps=100),
}

def setup_colab_environment():
    """Make sure the GENCODE transcript file and the extra packages are present.

    If ``./data/gencode/gencode.v44.pc_transcripts.fa`` does not exist, the
    gzipped release-44 file is downloaded with ``wget`` and unpacked with
    ``gunzip``. Afterwards ``Bio`` and ``tensorboardX`` are imported; each one
    that cannot be imported is installed with pip.

    Raises:
        subprocess.CalledProcessError: If the download, the decompression or a
            pip install exits with a non-zero status.
        FileNotFoundError: If ``wget`` or ``gunzip`` is not available.
    """
    # Function-scope imports keep this cell self-contained.
    import importlib
    import sys

    gencode_source_file_path = './data/gencode/gencode.v44.pc_transcripts.fa'
    archive_path = gencode_source_file_path + '.gz'
    download_url = (
        "https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/"
        "release_44/gencode.v44.pc_transcripts.fa.gz"
    )

    if not os.path.exists(gencode_source_file_path):
        os.makedirs(os.path.dirname(gencode_source_file_path), exist_ok=True)

        # Argument lists (no shell) and check=True: a failed download must not
        # be silently followed by an attempt to unpack a broken archive.
        subprocess.run(["wget", "-O", archive_path, download_url], check=True)
        subprocess.run(["gunzip", archive_path], check=True)

    # (importable module name, name passed to pip)
    required_packages = (("Bio", "Bio"), ("tensorboardX", "tensorboardX"))
    for module_name, pip_name in required_packages:
        try:
            importlib.import_module(module_name)
        except ImportError:
            # Install into the interpreter that is running this notebook, and
            # only the package that is actually missing.
            subprocess.run([sys.executable, "-m", "pip", "install", pip_name], check=True)
            # Let the import system notice the freshly installed package.
            importlib.invalidate_caches()



In [None]:
def run_util3():
    # @util3.py
    import os
    from Bio import SeqIO
    from Bio.Seq import Seq
    import pandas as pd
    from io import StringIO
    from typing import Tuple, List
    from itertools import compress
    from torch.nn.utils.rnn import pad_sequence
    import pytest

    # Define bases
    bases = ['A', 'T', 'G', 'C', 'N']

    def print_success(message):
        """Print *message* in bold green between two check-mark emojis inside an ASCII frame."""
        # ANSI escape sequences
        green, bold, reset = "\033[92m", "\033[1m", "\033[0m"
        check_mark = "✅"

        decorated = f"{green}{bold}{check_mark} {message} {check_mark}{reset}"

        # The frame is always eight dashes wider than the raw message text.
        frame = "+" + "-" * (len(message) + 8) + "+"

        print("\n".join((frame, f"| {decorated} |", frame)))

    # Mapping codons to amino acids, standard capitalised IUPAC codes
    # Padded codons (any that include N) are mapped to 'X'
    codon_to_aa = {
        'ATA':'I', 'ATC':'I', 'ATT':'I', 'ATG':'M',
        'ACA':'T', 'ACC':'T', 'ACG':'T', 'ACT':'T',
        'AAC':'N', 'AAT':'N', 'AAA':'K', 'AAG':'K',
        'AGC':'S', 'AGT':'S', 'AGA':'R', 'AGG':'R',
        'CTA':'L', 'CTC':'L', 'CTG':'L', 'CTT':'L',
        'CCA':'P', 'CCC':'P', 'CCG':'P', 'CCT':'P',
        'CAC':'H', 'CAT':'H', 'CAA':'Q', 'CAG':'Q',
        'CGA':'R', 'CGC':'R', 'CGG':'R', 'CGT':'R',
        'GTA':'V', 'GTC':'V', 'GTG':'V', 'GTT':'V',
        'GCA':'A', 'GCC':'A', 'GCG':'A', 'GCT':'A',
        'GAC':'D', 'GAT':'D', 'GAA':'E', 'GAG':'E',
        'GGA':'G', 'GGC':'G', 'GGG':'G', 'GGT':'G',
        'TCA':'S', 'TCC':'S', 'TCG':'S', 'TCT':'S',
        'TTC':'F', 'TTT':'F', 'TTA':'L', 'TTG':'L',
        'TAC':'Y', 'TAT':'Y', 'TAA':'*', 'TAG':'*',
        'TGC':'C', 'TGT':'C', 'TGA':'*', 'TGG':'W',
        # Adding the padded codons
        'ANN':'X', 'CNN':'X', 'GNN':'X', 'TNN':'X',
        'AAN':'X', 'CAN':'X', 'GAN':'X', 'TAN':'X',
        'ANA':'X', 'CNA':'X', 'GNA':'X', 'TNA':'X',
        'ANC':'X', 'CNC':'X', 'GNC':'X', 'TNC':'X',
        'ANG':'X', 'CNG':'X', 'GNG':'X', 'TNG':'X',
        'ANT':'X', 'CNT':'X', 'GNT':'X', 'TNT':'X',
        'AGN':'X', 'CGN':'X', 'GGN':'X', 'TGN':'X',
        'ATN':'X', 'CTN':'X', 'GTN':'X', 'TTN':'X',
        'ACN':'X', 'CCN':'X', 'GCN':'X', 'TCN':'X',
        'NAA':'X', 'NAC':'X', 'NAG':'X', 'NAT':'X',
        'NCA':'X', 'NCC':'X', 'NCG':'X', 'NCT':'X',
        'NGA':'X', 'NGC':'X', 'NGG':'X', 'NGT':'X',
        'NTA':'X', 'NTC':'X', 'NTG':'X', 'NTT':'X',
        'NAN':'X', 'NCN':'X', 'NGN':'X', 'NTN':'X',
        'NNN':'X'
    }

    # Mapping amino acids to integers, 1-20
    # Unknown Amino Acid ('X') is mapped to '0'
    # Stop codon ('*') is mapped to '21'
    aa_to_int = {
        'A': 1, 'C': 2, 'D': 3, 'E': 4,
        'F': 5, 'G': 6, 'H': 7, 'I': 8,
        'K': 9, 'L':10, 'M':11, 'N':12,
        'P':13, 'Q':14, 'R':15, 'S':16,
        'T':17, 'V':18, 'W':19, 'Y':20,
        # Unknown amino acid ('X') and stop codon ('*')
        'X': 0, '*':21
    }

    # Mapping codons to ints, 1-64
    # Padded codons (any that include N) are mapped to '0'
    codon_to_int = {
        'ATA': 1, 'ATC': 2, 'ATT': 3, 'ATG': 4,
        'ACA': 5, 'ACC': 6, 'ACG': 7, 'ACT': 8,
        'AAT': 9, 'AAC':10, 'AAA':11, 'AAG':12,
        'AGA':13, 'AGC':14, 'AGG':15, 'AGT':16,
        'CTA':17, 'CTC':18, 'CTT':19, 'CTG':20,
        'CCA':21, 'CCC':22, 'CCG':23, 'CCT':24,
        'CAT':25, 'CAC':26, 'CAA':27, 'CAG':28,
        'CGA':29, 'CGC':30, 'CGG':31, 'CGT':32,
        'GTA':33, 'GTC':34, 'GTT':35, 'GTG':36,
        'GCA':37, 'GCC':38, 'GCG':39, 'GCT':40,
        'GAT':41, 'GAC':42, 'GAA':43, 'GAG':44,
        'GGA':45, 'GGC':46, 'GGG':47, 'GGT':48,
        'TCA':49, 'TCC':50, 'TCT':51, 'TCG':52,
        'TTA':53, 'TTC':54, 'TTT':55, 'TTG':56,
        'TAT':57, 'TAC':58, 'TAA':59, 'TAG':60,
        'TGA':61, 'TGC':62, 'TGG':63, 'TGT':64,
        # Adding the padded codons
        'ANN':0, 'CNN':0, 'GNN':0, 'TNN':0,
        'AAN':0, 'CAN':0, 'GAN':0, 'TAN':0,
        'ANA':0, 'CNA':0, 'GNA':0, 'TNA':0,
        'ANC':0, 'CNC':0, 'GNC':0, 'TNC':0,
        'ANG':0, 'CNG':0, 'GNG':0, 'TNG':0,
        'ANT':0, 'CNT':0, 'GNT':0, 'TNT':0,
        'AGN':0, 'CGN':0, 'GGN':0, 'TGN':0,
        'ATN':0, 'CTN':0, 'GTN':0, 'TTN':0,
        'ACN':0, 'CCN':0, 'GCN':0, 'TCN':0,
        'NAA':0, 'NAC':0, 'NAG':0, 'NAT':0,
        'NCA':0, 'NCC':0, 'NCG':0, 'NCT':0,
        'NGA':0, 'NGC':0, 'NGG':0, 'NGT':0,
        'NTA':0, 'NTC':0, 'NTG':0, 'NTT':0,
        'NAN':0, 'NCN':0, 'NGN':0, 'NTN':0,
        'NNN':0
    }


    # 0
    def load_src_tgt_sequences(source_file: str, max_seq_length: int = 120000) -> Tuple[List[List[int]], List[List[int]]]:
        """
        Read a FASTA file and return integer-encoded amino acid / codon sequence pairs.

        Args:
            source_file (str): Path to the source FASTA file.
            max_seq_length (int): Upper bound, in nucleotides (codon count * 3),
                for a sequence to be kept.

        Returns:
            Tuple of two lists of integer lists:
            - aa_enc: Encoded amino acid sequences.
            - codon_enc: Encoded codon sequences.

        Raises:
            FileNotFoundError: If source_file does not exist.
        """
        if not os.path.exists(source_file):
            raise FileNotFoundError(f"Source file {source_file} not found")

        # Parse -> locate the CDS -> translate to integer sequences
        records = extract_cds_columns(parse_fasta(source_file))
        aa_seqs, codon_seqs = extract_sequences(records)

        # Keep only the pairs whose codon sequence fits into max_seq_length nucleotides
        kept = [
            (aa, codons)
            for aa, codons in zip(aa_seqs, codon_seqs)
            if len(codons) * 3 <= max_seq_length
        ]
        aa_seqs = [aa for aa, _ in kept]
        codon_seqs = [codons for _, codons in kept]

        return encode_amino_sequence(aa_seqs), encode_codon_sequence(codon_seqs)

    # 1
    def parse_fasta(fasta_file):
        """
        Parse a GENCODE-style FASTA file into a DataFrame with one row per record.

        The record description is split on '|'. The first seven fields are read
        positionally (transcript id, gene id, manual gene id, manual transcript
        id, gene symbol variant, gene name, sequence length). The remaining
        fields are searched for 'UTR5:', 'CDS:' and 'UTR3:' entries.

        The region fields are located by label rather than by position because
        they are optional in the header: a transcript without a 5' UTR carries
        its 'CDS:' field directly after the length, where a fixed index would
        miss it.

        Args:
            fasta_file: Path (or handle) passed to SeqIO.parse.

        Returns:
            pandas DataFrame with the columns transcript_id, gene_id,
            manual_gene_id, manual_transcript_id, gene_symbol_variant,
            gene_name, sequence_length, UTR5, CDS, UTR3 and sequence. A region
            that is absent from the header is stored as None.

        Raises:
            IndexError: If a header has fewer than seven '|'-separated fields.
            ValueError: If the sequence-length field is not an integer.
        """
        region_labels = ("UTR5", "CDS", "UTR3")
        parsed_records = []

        # Iterate the parser directly; no need to hold every record in a list first.
        for record in SeqIO.parse(fasta_file, "fasta"):
            header_parts = record.description.split("|")

            # Collect 'LABEL:coords' fields wherever they appear after the length.
            regions = {}
            for field in header_parts[7:]:
                label, separator, coords = field.partition(":")
                if separator and label in region_labels:
                    regions.setdefault(label, coords)

            parsed_records.append({
                "transcript_id": header_parts[0],
                "gene_id": header_parts[1],
                "manual_gene_id": header_parts[2],
                "manual_transcript_id": header_parts[3],
                "gene_symbol_variant": header_parts[4],
                "gene_name": header_parts[5],
                "sequence_length": int(header_parts[6]),
                "UTR5": regions.get("UTR5"),
                "CDS": regions.get("CDS"),
                "UTR3": regions.get("UTR3"),
                "sequence": str(record.seq),
            })

        return pd.DataFrame(parsed_records)

    # 2
    def extract_cds_columns(df):
        """
        Add integer 'cds_start' / 'cds_end' columns and drop rows with unusable ranges.

        The 'CDS' column is expected to hold 'start-end' strings (1-based,
        inclusive, as used by the slicing in extract_sequences). Rows whose CDS
        is missing or not of that form default to the whole sequence
        (1 .. len(sequence)).

        Args:
            df: DataFrame with at least the columns 'CDS' and 'sequence'.
                The two new columns are written onto this frame.

        Returns:
            The rows of df for which 0 < cds_start <= cds_end <= len(sequence).

        Raises:
            ValueError: If a 'start-end' entry contains non-integer parts.
        """
        cds_parts = df['CDS'].str.split('-')

        # A usable entry splits into exactly two parts. Missing values may come
        # back as None or NaN depending on the column dtype, so test for a real
        # list instead of relying on truthiness (NaN is truthy and has no len()).
        has_range = cds_parts.apply(
            lambda parts: isinstance(parts, list) and len(parts) == 2
        ).astype(bool)

        seq_lengths = df['sequence'].str.len()

        # Default: the coding region spans the whole sequence ...
        df['cds_start'] = 1
        df['cds_end'] = seq_lengths

        # ... overridden wherever the header supplied explicit coordinates.
        if has_range.any():
            df.loc[has_range, 'cds_start'] = cds_parts[has_range].str[0].astype(int)
            df.loc[has_range, 'cds_end'] = cds_parts[has_range].str[1].astype(int)

        # Ensure 'cds_start' and 'cds_end' are integers
        df['cds_start'] = df['cds_start'].astype(int)
        df['cds_end'] = df['cds_end'].astype(int)

        # An inverted range would slice to an empty coding sequence, so it is
        # rejected together with out-of-bounds coordinates.
        valid_rows = (
            (df['cds_start'] > 0)
            & (df['cds_start'] <= df['cds_end'])
            & (df['cds_end'] <= seq_lengths)
        )
        return df[valid_rows]

    # 3

    # 3.1 Codons -> amino acids
    def translate_codons_to_amino_acids(codon_seqs: List[str]) -> List[str]:
        """
        Convert nucleotide strings into amino acid strings via the codon_to_aa table.

        Each input string is right-padded with 'N' up to the next multiple of
        three and then read three bases at a time.

        Parameters:
        - codon_seqs (List[str]): Nucleotide strings, read as consecutive triplets.

        Returns:
        - List[str]: One amino acid string per input string, in the same order.

        Raises:
        - ValueError: If a triplet is not a key of codon_to_aa.
        """
        translated = []

        for nucleotides in codon_seqs:
            # Same result as appending 'N' until the length is divisible by 3
            padded = nucleotides + 'N' * (-len(nucleotides) % 3)

            residues = []
            for offset in range(0, len(padded), 3):
                triplet = padded[offset:offset + 3]
                try:
                    residues.append(codon_to_aa[triplet])
                except KeyError:
                    raise ValueError(f"Unrecognized codon: {triplet}") from None

            translated.append("".join(residues))

        return translated

    # 3.2 Amino acids -> ints
    def translate_amino_acids_to_ints(aa_seqs: List[str]) -> List[List[int]]:
        """
        Convert amino acid strings into integer lists via the aa_to_int table.

        Parameters:
        - aa_seqs (List[str]): Amino acid strings, one character per residue.

        Returns:
        - List[List[int]]: One integer list per input string, in the same order.

        Raises:
        - ValueError: If a character is not a key of aa_to_int.
        """
        encoded = []

        for residues in aa_seqs:
            codes = []
            for residue in residues:
                try:
                    codes.append(aa_to_int[residue])
                except KeyError:
                    raise ValueError(f"Unrecognized amino acid: {residue}") from None
            encoded.append(codes)

        return encoded

    # 3.3 Codons -> ints
    def translate_codons_to_ints(codon_seqs: List[str]) -> List[List[int]]:
        """
        Translate a list of nucleotide strings to lists of codon integer codes.

        If a string's length isn't a multiple of 3, it is padded with 'N'
        to the nearest multiple of 3 before being read three bases at a time.

        Parameters:
        - codon_seqs (List[str]): A list of nucleotide strings.
                                  Each codon is expected to be a triplet of nucleotide bases.

        Returns:
        - List[List[int]]: One list of codon codes (looked up in codon_to_int)
                           per input string, in the same order.

        Raises:
        - ValueError: If a provided codon is not recognized.
        """
        result = []

        for seq in codon_seqs:
            # Pad with 'N' if not multiple of 3
            seq = seq + 'N' * (-len(seq) % 3)

            int_values = []
            for i in range(0, len(seq), 3):
                codon = seq[i:i + 3]
                if codon not in codon_to_int:
                    raise ValueError(f"Unrecognized codon: {codon}")
                int_values.append(codon_to_int[codon])

            result.append(int_values)

        return result

    # 4
    def extract_sequences(df) -> Tuple[List[List[int]], List[List[int]]]:
        """
        Slice the coding region out of every row and encode it as integers.

        For each row the substring sequence[cds_start-1:cds_end] is taken
        (cds_start / cds_end are 1-based, inclusive) and translated twice: once
        to amino acid codes and once to codon codes. A trailing partial codon
        is padded with 'N' by the translate_* helpers.

        Args:
            df: A pandas DataFrame containing the 'sequence', 'cds_start', and 'cds_end' columns.

        Returns:
            A tuple containing two lists, aligned row by row:
            - aa_seqs_int: A list of amino acid sequences as integers.
            - codon_seqs_int: A list of codon sequences as integers.

        Raises:
            ValueError: If a codon or amino acid is not recognized by the
                translate_* helpers.
        """
        aa_seqs_int = []      # For storing amino acid sequences as integers
        codon_seqs_int = []   # For storing codon sequences as integers

        # Iterating the three columns directly avoids building a Series per row.
        for sequence, cds_start, cds_end in zip(df['sequence'], df['cds_start'], df['cds_end']):
            cds = sequence[cds_start - 1:cds_end]  # -1 because Python is 0-based

            # The helpers already split their input into triplets, so the whole
            # coding sequence is passed once instead of codon by codon.
            aa_text = translate_codons_to_amino_acids([cds])[0]
            aa_seqs_int.append(translate_amino_acids_to_ints([aa_text])[0])
            codon_seqs_int.append(translate_codons_to_ints([cds])[0])

        return aa_seqs_int, codon_seqs_int


    # 5 Dummy encode functions
    def encode_amino_sequence(aa_seqs: List[List[int]]) -> List[List[int]]:
        """Placeholder encoder: return the amino acid integer sequences unchanged."""
        return aa_seqs

    def encode_codon_sequence(codon_seqs: List[List[int]]) -> List[List[int]]:
        """Placeholder encoder: return the codon integer sequences unchanged."""
        return codon_seqs

    # 6
    def collate_fn(batch):
        """
        Collate (source, target) tensor pairs into two padded batch tensors.

        Every column of the batch is padded with pad_sequence's default fill
        value to the length of its longest member, batch dimension first.

        Args:
            batch: Iterable of (src_tensor, tgt_tensor) pairs.

        Returns:
            Tuple (src_sequences, tgt_sequences) of padded tensors.
        """
        src_sequences, tgt_sequences = (
            pad_sequence(column, batch_first=True) for column in zip(*batch)
        )
        return src_sequences, tgt_sequences

    # 7

    # 7.1 Ints -> Amino acids
    def translate_ints_to_amino_acids(int_seqs: List[List[int]]) -> List[str]:
        """
        Decode integer sequences back into amino acid strings (inverse of aa_to_int).

        Parameters:
        - int_seqs (List[List[int]]): Integer sequences representing amino acids.

        Returns:
        - List[str]: One amino acid string per input sequence, in the same order.

        Raises:
        - ValueError: If an integer is not a value of aa_to_int.
        """
        # Reverse lookup table: integer code -> amino acid letter
        int_to_aa = dict(zip(aa_to_int.values(), aa_to_int.keys()))

        decoded = []
        for codes in int_seqs:
            letters = []
            for code in codes:
                letter = int_to_aa.get(code)
                if letter is None:
                    raise ValueError(f"Unrecognized integer: {code}")
                letters.append(letter)
            decoded.append("".join(letters))
        return decoded

    # 7.2 Ints -> Codons
    def translate_ints_to_codons(int_seqs: List[List[int]]) -> List[str]:
        """
        Decode integer sequences back into codon strings (inverse of codon_to_int).

        Several padded codons share the code 0 in codon_to_int; the reverse
        table keeps the last of them in dictionary order, so 0 decodes to 'NNN'.

        Parameters:
        - int_seqs (List[List[int]]): Integer sequences representing codons.

        Returns:
        - List[str]: One concatenated codon string per input sequence, in the same order.

        Raises:
        - ValueError: If an integer is not a value of codon_to_int.
        """
        # Reverse lookup table: integer code -> codon (later duplicates win)
        int_to_codon = dict(zip(codon_to_int.values(), codon_to_int.keys()))

        decoded = []
        for codes in int_seqs:
            triplets = []
            for code in codes:
                triplet = int_to_codon.get(code)
                if triplet is None:
                    raise ValueError(f"Unrecognized integer: {code}")
                triplets.append(triplet)
            decoded.append("".join(triplets))
        return decoded

# codonFormer.py


In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import math
import time
from datetime import datetime
from tensorboardX import SummaryWriter
from sklearn.model_selection import train_test_split
#from util import *
#from util3 import *


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Inputs to forward are (batch, seq, d_model) tensors; d_model is split
    evenly across num_heads heads of width d_k = d_model // num_heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Query / key / value projections followed by the output projection
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return softmax(Q K^T / sqrt(d_k)) V; positions where mask == 0 are scored -1e9."""
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        return weights @ V

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        per_head = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq, d_k) back into (batch, seq, d_model)."""
        batch_size, _, seq_length, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, merge the heads and apply W_o."""
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))

        context = self.scaled_dot_product_attention(queries, keys, values, mask)
        return self.W_o(self.combine_heads(context))

class PositionWiseFeedForward(nn.Module):
    """Two linear layers (d_model -> d_ff -> d_model) with a ReLU in between."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1, ReLU and fc2 to the last dimension of x."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a (batch, seq, d_model) tensor.

    A (1, max_seq_length, d_model) table is precomputed once and stored as the
    non-trainable buffer ``pe``: even feature columns hold sines, odd feature
    columns hold cosines of position * 10000^(-2i / d_model).

    Args:
        d_model: Feature width of the embeddings (even or odd).
        max_seq_length: Number of positions to precompute; inputs to forward
            must not be longer than this.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one odd column fewer than even columns,
        # so the cosine block is trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return x plus the encodings for its first x.size(1) positions."""
        return x + self.pe[:, :x.size(1)]

class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer output passes through dropout, is added to its input and
    the sum is layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run x through the attention sub-layer (using mask), then the feed-forward sub-layer."""
        attended = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        transformed = self.feed_forward(x)
        return self.norm2(x + self.dropout(transformed))

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention and feed-forward sub-layers.

    Each sub-layer output passes through dropout, is added to its input and
    the sum is layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        """Attend over x (tgt_mask), then over enc_output (src_mask), then apply the feed-forward."""
        self_attended = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Queries come from the decoder state, keys/values from the encoder output
        cross_attended = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        transformed = self.feed_forward(x)
        return self.norm3(x + self.dropout(transformed))

class Transformer(nn.Module):
    """Encoder-decoder model mapping source token ids to target-vocabulary logits.

    Token id 0 is treated as padding when the attention masks are built.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build the attention masks for a (batch, seq) source / target id pair.

        Returns:
            src_mask: (batch, 1, 1, src_len) bool, False where src == 0.
            tgt_mask: (batch, 1, tgt_len, tgt_len) bool, False for rows whose
                target id is 0 and for every position after the query position.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_padding = (tgt != 0).unsqueeze(1).unsqueeze(3)

        # Lower-triangular (incl. diagonal) mask, built on the target's device
        tgt_len = tgt.size(1)
        causal = torch.ones(1, tgt_len, tgt_len, device=tgt_padding.device).tril().bool()

        return src_mask, tgt_padding & causal

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        # Embedding -> positional encoding -> dropout, source first, then target
        enc_output = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        dec_output = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)

        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)

class SequenceDataset(torch.utils.data.Dataset):
    """Dataset of aligned (source, target) integer sequences.

    Items are returned as a pair of tensors. __getitem__ passes the stored
    sequences through encode_amino_sequence / encode_codon_sequence, which are
    looked up when an item is requested.
    """

    def __init__(self, src_sequences, tgt_sequences):
        assert len(src_sequences) == len(tgt_sequences), "Source and target sequences must have the same length."
        self.src_sequences = src_sequences
        self.tgt_sequences = tgt_sequences

    def __len__(self):
        """Number of sequence pairs."""
        return len(self.src_sequences)

    def __getitem__(self, index):
        """Return (source_tensor, target_tensor) for the pair at index."""
        encoded_src = encode_amino_sequence(self.src_sequences[index])
        encoded_tgt = encode_codon_sequence(self.tgt_sequences[index])
        return torch.tensor(encoded_src), torch.tensor(encoded_tgt)

def get_gpu_memory_usage(device):
    """Return torch.cuda.memory_allocated(device) converted from bytes to MB (divided by 1e6)."""
    allocated_bytes = torch.cuda.memory_allocated(device)
    return allocated_bytes / 1e6

def train_model(model, dataloader, val_dataloader, tgt_vocab_size, epochs=100, lr=0.0001, verbose=False, start_epoch=0, optimizer_state_dict=None):
    """Train model with Adam and cross-entropy loss (target id 0 is ignored).

    Each batch is fed as model(src, tgt[:, :-1]) and scored against
    tgt[:, 1:]. Batches are moved to the device the model's parameters live
    on, so no module-level ``device`` variable is required.

    Args:
        model: Module called as model(src, tgt) and returning logits whose
            last dimension is tgt_vocab_size.
        dataloader: Iterable of (src_data, tgt_data) tensor batches.
        val_dataloader: Batches passed to validate() every 5th epoch; only
            used when verbose is True.
        tgt_vocab_size: Size of the logits' last dimension.
        epochs: Epoch index at which training stops (exclusive).
        lr: Adam learning rate.
        verbose: If True, print loss, timing, gradient norm and memory usage
            after every epoch.
        start_epoch: First epoch index, for resuming a run.
        optimizer_state_dict: Optional optimizer state to restore.

    Returns:
        None. The model is updated in place.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 is used for padding.
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    # Load optimizer state if provided
    if optimizer_state_dict:
        optimizer.load_state_dict(optimizer_state_dict)

    # Follow the model: data has to be on the same device as its parameters.
    device = next(model.parameters()).device

    model.train()

    loss = None       # loss of the most recent batch
    grad_norm = None  # gradient norm of the most recent batch (verbose only)

    for epoch in range(start_epoch, epochs):
        if verbose:
            start_time = time.time()
            current_memory_usage = get_gpu_memory_usage(device)

        for src_data, tgt_data in dataloader:
            optimizer.zero_grad()
            src_data, tgt_data = src_data.to(device), tgt_data.to(device)
            output = model(src_data, tgt_data[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
            loss.backward()

            if verbose:
                # NOTE(review): clip_grad_norm_ also rescales the gradients, so
                # verbose runs train with clipping (max_norm=1) while
                # non-verbose runs do not. Kept as-is; confirm which is intended.
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

            optimizer.step()

        # Nothing to report when not verbose or when no batch has been seen yet.
        if not verbose or loss is None:
            continue

        # The reported training loss is that of the last processed batch.
        if (epoch + 1) % 5 == 0:
            # Computing the Validation Loss every 5th epoch
            val_loss = validate(model, val_dataloader, criterion, tgt_vocab_size)
            end_time = time.time()
            current_memory_usage = torch.cuda.memory_allocated(device) / 1e6  # in MBs
            print(f"\033[1;44mEpoch: {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss}\033[0m")
        else:
            # Print without validation loss
            end_time = time.time()
            print(f"\033[1mEpoch: {epoch+1}, Training Loss: {loss.item()}\033[0m")

        print(f"\tTime taken for Epoch {epoch+1}: {end_time - start_time:.2f} seconds")
        print(f"\tGradient Norm: {grad_norm:.2f}")
        print(f"\tMemory Usage: {current_memory_usage:.2f} MBs")

def validate(model, dataloader, criterion, tgt_vocab_size):
    """Return the mean per-batch loss of model over dataloader, without gradients.

    Batches are scored exactly as in training: model(src, tgt[:, :-1])
    against tgt[:, 1:]. They are moved to the device the model's parameters
    live on. The model is switched to eval mode for the pass and put back
    into train mode afterwards, even if a batch raises.

    Args:
        model: Module called as model(src, tgt).
        dataloader: Sized iterable of (src_data, tgt_data) tensor batches.
        criterion: Loss callable taking (logits, targets).
        tgt_vocab_size: Size of the logits' last dimension.

    Returns:
        float: Sum of batch losses divided by len(dataloader); NaN when the
        dataloader is empty.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # No data to average over; NaN makes that visible instead of dividing by zero.
        return float("nan")

    # Follow the model: data has to be on the same device as its parameters.
    device = next(model.parameters()).device

    model.eval()
    total_loss = 0
    try:
        with torch.no_grad():
            for src_data, tgt_data in dataloader:
                src_data, tgt_data = src_data.to(device), tgt_data.to(device)
                output = model(src_data, tgt_data[:, :-1])
                loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
                total_loss += loss.item()
    finally:
        model.train()
    return total_loss / num_batches

# def predict(model, src_sequence):
#     model.eval()
#     with torch.no_grad():
#         src_sequence = torch.tensor(encode_amino_sequence(src_sequence)).unsqueeze(0) # assuming the encoding function is available
#         tgt_start_token = torch.tensor([SOME_START_TOKEN_INDEX])  # You'll need a start token for decoding
#         output = model(src_sequence, tgt_start_token)
#         predicted = output.argmax(dim=-1)
#         return decode_codon_sequence(predicted[0]) # assuming a decoding function is available




In [None]:
from datetime import datetime
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def initialise_training(max_seq_length=2000, config="large", batch=64, n_epochs=100, gencode_source_file_path=None):
    """
    Build the dataloaders and the model needed for a training run.

    Loads the source/target sequences, holds out a validation split
    (test_size=0.1), drops every pair whose source is longer than
    max_seq_length, wraps both splits in DataLoaders, and instantiates the
    Transformer for the chosen configuration on the selected device.

    Parameters:
    - max_seq_length (int): Longest allowable sequence
    - config (str): Key into MODEL_CONFIGS selecting the model hyperparameters
    - batch (int): Batch size used for both dataloaders
    - n_epochs (int): Accepted for call-site compatibility; not used here
    - gencode_source_file_path (str): Path to the source file

    Returns:
    - train_dataloader, val_dataloader, transformer, device

    Raises:
    - KeyError: if `config` is not a key of MODEL_CONFIGS
    - ValueError: if either split is empty after length filtering
    """
    # Fail fast on an unknown configuration name, before the expensive data load.
    if config not in MODEL_CONFIGS:
        raise KeyError(f"Unknown model config {config!r}; expected one of {sorted(MODEL_CONFIGS)}")

    print_success("Done setting the constants.")

    # Define, set, and check the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # A TensorBoard writer is deliberately not created here: a writer local to
    # this function could not reach the caller through the 4-tuple return
    # value, so it would only leak an open event file.

    src_sequences, tgt_sequences = load_src_tgt_sequences(source_file=gencode_source_file_path, max_seq_length=max_seq_length)
    src_train, src_val, tgt_train, tgt_val = train_test_split(src_sequences, tgt_sequences, test_size=0.1)

    print("Done loading the data.")

    def _keep_short_pairs(src_split, tgt_split, label):
        """Return (srcs, tgts) tuples holding only pairs with len(src) <= max_seq_length."""
        kept = [(src, tgt) for src, tgt in zip(src_split, tgt_split) if len(src) <= max_seq_length]
        print(f"Filtered from {len(src_split)} to {len(kept)} {label} sequences.")
        if not kept:
            # zip(*[]) yields nothing, so unpacking into two names would fail
            # with an unhelpful "not enough values to unpack" message.
            raise ValueError(f"No {label} sequences of length <= {max_seq_length} remain after filtering.")
        return zip(*kept)

    # Filtering training and validation data
    src_train, tgt_train = _keep_short_pairs(src_train, tgt_train, "training")
    src_val, tgt_val = _keep_short_pairs(src_val, tgt_val, "validation")

    # Create training and validation datasets and dataloaders
    train_dataset = SequenceDataset(src_train, tgt_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True, collate_fn=collate_fn)
    print("Done creating the training dataloader.")

    # Validation order does not matter, so the data is not shuffled.
    val_dataset = SequenceDataset(src_val, tgt_val)
    val_dataloader = DataLoader(val_dataset, batch_size=batch, shuffle=False, collate_fn=collate_fn)
    print("Done creating the validation dataloader.")

    config_params = MODEL_CONFIGS[config]
    print(f"Done setting - Model Configuration: {config_params}")
    transformer = Transformer(src_vocab_size, tgt_vocab_size, **config_params, max_seq_length=max_seq_length)

    transformer = transformer.to(device)
    # Report the actual device instead of always claiming a GPU.
    print(f"Model loaded to the device ({device}).")

    return train_dataloader, val_dataloader, transformer, device


# Training

In [None]:
if __name__ == "__main__":
    # End-to-end training cell: set the run constants, load and split the
    # data, build the dataloaders and the model, train, save the model
    # locally, then copy the TensorBoard logs and the model to Google Drive.
    max_seq_length = 200 # Longest allowable sequence
    config = "large" # Hyperparameters
    batch = 256 # Size: 64, 128, 256, 512, 1024, 2048
    # (A100 runs out of memory at 2048 + medium config)
    # (A100 cannot do max_seq_length = 2000 & large at more than 64 batch size)
    n_epochs = 100 # Choose num of epochs

    print_success("Done setting the constants.")

    # Define, set, and check the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: " + str(device))

    # Get the current date and time
    current_time = datetime.now()

    # Format the timestamp as year/month/day_hour:minute:second.
    # It is only consumed by the commented-out run_name below.
    timestamp = current_time.strftime('%Y/%m/%d_%H:%M:%S')

    # Initialize a name for the training run
    #run_name = timestamp + "_LengthCap-" + str(max_seq_length) + "_HPsize-" + config + "_Batch-" + str(batch)
    run_name = "saturday_big_model"
    # Initialize a writer
    writer = SummaryWriter('runs/' + run_name)

    # After training, navigate to the parent directory of 'runs' in the terminal.
    # Start TensorBoard by running the following command:
    # tensorboard --logdir=runs
    #
    # Then, open a web browser and navigate to the URL displayed in the terminal (usually http://localhost:6006/).
    # You should be able to see your logs and navigate between different experiments using the unique names.

    # NOTE(review): gencode_source_file_path is not assigned in this cell; it
    # is presumably a notebook-level variable set by an earlier cell - confirm.
    #src_sequences, tgt_sequences = load_src_tgt_sequences()
    src_sequences, tgt_sequences = load_src_tgt_sequences(source_file=gencode_source_file_path,max_seq_length=max_seq_length)

    # Split the data into training and validation sets
    src_train, src_val, tgt_train, tgt_val = train_test_split(src_sequences, tgt_sequences, test_size=0.1)

    print("Done loading the data.")

    # Filtering the training data
    filtered_train_data = [(src, tgt) for src, tgt in zip(src_train, tgt_train) if len(src) <= max_seq_length]
    print(f"Filtered from {len(src_train)} to {len(filtered_train_data)} training sequences.")
    src_train, tgt_train = zip(*filtered_train_data)

    # Filtering the validation data
    filtered_val_data = [(src, tgt) for src, tgt in zip(src_val, tgt_val) if len(src) <= max_seq_length]
    print(f"Filtered from {len(src_val)} to {len(filtered_val_data)} validation sequences.")
    src_val, tgt_val = zip(*filtered_val_data)

    # Create training dataset and dataloader
    train_dataset = SequenceDataset(src_train, tgt_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True, collate_fn=collate_fn)
    print("Done creating the training dataloader.")

    # Create validation dataset and dataloader
    val_dataset = SequenceDataset(src_val, tgt_val)
    val_dataloader = DataLoader(val_dataset, batch_size=batch, shuffle=False, collate_fn=collate_fn)  # No need to shuffle the validation data
    print("Done creating the validation dataloader.")

    # Create the transformer model using the chosen configuration
    config_params = MODEL_CONFIGS[config]
    print(f"Done setting - Model Configuration: {config_params}")
    transformer = Transformer(src_vocab_size, tgt_vocab_size, **config_params, max_seq_length=max_seq_length)

    # Move the model to GPU if available
    transformer = transformer.to(device)
    # Report the actual device instead of always claiming a GPU.
    print(f"Model loaded to the device ({device}).")

    # Train the model
    print_success("INITIALISE TRAINING.")
    try:
        train_model(transformer, train_dataloader, val_dataloader, tgt_vocab_size=tgt_vocab_size, epochs=n_epochs, verbose=True)
    finally:
        # Close the writer even if training raises or is interrupted, so the
        # event file is flushed and its handle released.
        writer.close()

    # Report the number of params
    num_params = sum(p.numel() for p in transformer.parameters())
    print_success(f"Training finished! The model has {num_params} parameters.")

    # Save the model to a file
    save_path = 'model_' + run_name
    os.makedirs(save_path, exist_ok=True)

    # Save the weights alone, and additionally the whole pickled module.
    torch.save(transformer.state_dict(), os.path.join(save_path, 'my_model.pt'))
    torch.save(transformer, os.path.join(save_path, 'my_model_complete.pt'))

    from google.colab import drive
    drive.mount('/content/gdrive')
    import shutil

    # Source paths
    # These absolute paths match 'runs/' + run_name and save_path only when
    # the working directory is /content.
    #src_path1 = '/content/runs/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
    #src_path2 = '/content/model_2023'
    src_path1 = '/content/runs/saturday_big_model'
    src_path2 = '/content/model_saturday_big_model'


    # Destination paths
    # NOTE(review): these destinations are fixed names that do not follow
    # run_name (they look left over from an earlier run), and copytree raises
    # FileExistsError when a destination already exists - update them before
    # copying a second run.
    dst_path1 = '/content/gdrive/My Drive/Thesis/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
    dst_path2 = '/content/gdrive/My Drive/Thesis/model_2023'

    # Copy the directories to the destination paths
    shutil.copytree(src_path1, dst_path1)
    shutil.copytree(src_path2, dst_path2)


## Training output

In [None]:
import torch

# Snapshot the current model weights to the Colab VM's local disk.
# The epoch number and loss are typed in by hand from the training log.
model_save_path = "/content/model_checkpoint.pth"
#optimizer_save_path = "/content/optimizer_checkpoint.pth"

checkpoint_payload = {
    'epoch': 47,  # current epoch number
    'model_state_dict': transformer.state_dict(),
    #'optimizer_state_dict': optimizer.state_dict(),  # assuming you named your optimizer 'optimizer'
    'loss': 0.5647182464599609,  # latest loss value
}
torch.save(checkpoint_payload, model_save_path)

In [None]:
# Mount Google Drive and make sure the checkpoint directory exists there.
# NOTE(review): this cell mounts at '/content/drive' while other cells in the
# notebook mount at '/content/gdrive' - confirm which mount point is intended.
from google.colab import drive
drive.mount('/content/drive')
# Change 'MyDrive/checkpoints/' to your desired path within Google Drive
# This only rebinds model_save_path; nothing is written to it in this cell.
model_save_path = "/content/drive/MyDrive/checkpoints/model_checkpoint.pth"
import os
save_directory = "/content/drive/MyDrive/checkpoints/"
# Create the directory (and any missing parents) on first use.
if not os.path.exists(save_directory):
    os.makedirs(save_directory)

In [None]:
# Save a checkpoint on the Colab VM and offer it as a browser download.
# The import must be active: files.download below raises NameError otherwise.
from google.colab import files

# First, save to Colab VM
model_save_path = "/content/model_checkpoint.pth"
torch.save({
    'epoch': 47,  # typed in by hand from the training log
    'model_state_dict': transformer.state_dict(),
    #'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.5647182464599609,  # typed in by hand from the training log
}, model_save_path)

# Then, download
files.download(model_save_path)

# Resume training

In [None]:
import torch
import os

def train_and_save_model(model, dataloader, val_dataloader, tgt_vocab_size, epochs=100, lr=0.0001, verbose=False, start_epoch=0, optimizer_state_dict=None, save_path="/content/"):
    """
    Train `model` with train_model and write a checkpoint of the result.

    All arguments except `save_path` are forwarded positionally to
    train_model. When at least one epoch was requested
    (start_epoch < epochs), a single checkpoint named
    "checkpoint_epoch_<epochs>.pth" is written into `save_path`. Its 'epoch'
    field is epochs - 1, the zero-based index of the last epoch.

    Only one file is written because training has already finished when the
    save happens, so per-epoch files would all contain the same final
    weights. The optimizer lives inside train_model and is not reachable
    from here, so the checkpoint holds no 'optimizer_state_dict' key.

    Parameters:
    - model: Model to train; its state_dict() is saved
    - dataloader: Training dataloader
    - val_dataloader: Validation dataloader
    - tgt_vocab_size (int): Target vocabulary size
    - epochs (int): Epoch count passed to train_model
    - lr (float): Learning rate passed to train_model
    - verbose (bool): Verbosity flag passed to train_model
    - start_epoch (int): First epoch index passed to train_model
    - optimizer_state_dict: Optional optimizer state passed to train_model
    - save_path (str): Directory the checkpoint is written to (created if missing)
    """
    def save_checkpoint(epoch, model, filename, optimizer=None):
        """Write the epoch index, model weights and, if given, optimizer state to `filename`."""
        payload = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        # Calling .state_dict() on a missing optimizer would raise
        # AttributeError, so the key is only added when one is supplied.
        if optimizer is not None:
            payload['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(payload, filename)

    # Train the model
    train_model(model, dataloader, val_dataloader, tgt_vocab_size, epochs, lr, verbose, start_epoch, optimizer_state_dict)

    # Nothing was trained, so there is nothing new to save.
    if start_epoch >= epochs:
        return

    # Save the final state once, labelled with the last epoch.
    os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, f"checkpoint_epoch_{epochs}.pth")
    save_checkpoint(epochs - 1, model, filename)

def initialize_or_load_checkpoint(model_path, model):
    """
    Load a checkpoint into `model` if `model_path` exists; otherwise leave it untouched.

    Two on-disk formats are accepted:
    - a dict with 'model_state_dict' and 'epoch' keys and an optional
      'optimizer_state_dict' key;
    - a bare model state_dict, as written by
      torch.save(model.state_dict(), path). Training then starts at epoch 0
      with no optimizer state.

    Tensors are mapped onto the device of the model's first parameter (CPU
    when the model has no parameters), so a checkpoint saved on a GPU can be
    loaded on a CPU-only runtime.

    Parameters:
    - model_path (str): Path of the checkpoint file
    - model: Module whose weights are overwritten in place when the file exists

    Returns:
    - model, start_epoch, optimizer_state_dict: start_epoch is 0 and
      optimizer_state_dict is None when there is nothing to resume from
    """
    start_epoch = 0
    optimizer_state_dict = None

    if not os.path.exists(model_path):
        return model, start_epoch, optimizer_state_dict

    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    checkpoint = torch.load(model_path, map_location=target_device)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer_state_dict = checkpoint.get('optimizer_state_dict')
        start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
    else:
        # No wrapper keys: treat the file as a bare state_dict.
        model.load_state_dict(checkpoint)

    return model, start_epoch, optimizer_state_dict

# Resume-training driver: rebuild the dataloaders and the model, restore the
# latest checkpoint if one exists, then continue training.
# Only the source path is passed, so initialise_training runs with its own
# defaults for max_seq_length, config and batch.
# NOTE(review): gencode_source_file_path is not assigned in this cell; it is
# presumably a notebook-level variable set by an earlier cell - confirm.
train_dataloader, val_dataloader, transformer, device = initialise_training(gencode_source_file_path=gencode_source_file_path)

# Initialize or load checkpoint
# start_epoch is 0 and optimizer_state_dict is None when the file does not exist.
model_path = "/content/model_checkpoint.pth"
transformer, start_epoch, optimizer_state_dict = initialize_or_load_checkpoint(model_path, transformer)

# Redefine number of epochs
n_epochs = 100

# Adjust the training call to pass optimizer state dict if available
train_and_save_model(transformer, train_dataloader, val_dataloader, tgt_vocab_size=tgt_vocab_size, epochs=n_epochs, verbose=True, start_epoch=start_epoch, optimizer_state_dict=optimizer_state_dict)

In [None]:
if __name__ == "__main__":
    # Re-run variant of the training cell: it does not load or split the
    # data itself. src_train, tgt_train, src_val and tgt_val are read from
    # the notebook namespace, so presumably an earlier cell must have run.
    max_seq_length = 200 # Longest allowable sequence
    config = "large" # Hyperparameters
    batch = 256 # Size: 64, 128, 256, 512, 1024, 2048 (A100 runs out of memory at 2048 + medium config)
    n_epochs = 100 # Choose num of epochs

    print_success("Done setting the constants.")

    # Define, set, and check the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: " + str(device))

    # Get the current date and time
    current_time = datetime.now()

    # Format the timestamp as year/month/day_hour:minute:second.
    # It is only consumed by the commented-out run_name below.
    timestamp = current_time.strftime('%Y/%m/%d_%H:%M:%S')

    # Initialize a name for the training run
    #run_name = timestamp + "_LengthCap-" + str(max_seq_length) + "_HPsize-" + config + "_Batch-" + str(batch)
    run_name = "saturday_big_model"
    # Initialize a writer
    writer = SummaryWriter('runs/' + run_name)

    # Filtering the validation data
    # Only the validation split is re-filtered here; the training split is
    # used as found in the notebook namespace.
    filtered_val_data = [(src, tgt) for src, tgt in zip(src_val, tgt_val) if len(src) <= max_seq_length]
    print(f"Filtered from {len(src_val)} to {len(filtered_val_data)} validation sequences.")
    src_val, tgt_val = zip(*filtered_val_data)

    # Create training dataset and dataloader
    train_dataset = SequenceDataset(src_train, tgt_train)
    train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True, collate_fn=collate_fn)
    print("Done creating the training dataloader.")

    # Create validation dataset and dataloader
    val_dataset = SequenceDataset(src_val, tgt_val)
    val_dataloader = DataLoader(val_dataset, batch_size=batch, shuffle=False, collate_fn=collate_fn)  # No need to shuffle the validation data
    print("Done creating the validation dataloader.")

    # Create the transformer model using the chosen configuration
    config_params = MODEL_CONFIGS[config]
    print(f"Done setting - Model Configuration: {config_params}")
    transformer = Transformer(src_vocab_size, tgt_vocab_size, **config_params, max_seq_length=max_seq_length)

    # Move the model to GPU if available
    transformer = transformer.to(device)
    # Report the actual device instead of always claiming a GPU.
    print(f"Model loaded to the device ({device}).")

    ##########     Train the model    ##########
    print_success("INITIALISE TRAINING.")
    try:
        train_model(transformer, train_dataloader, val_dataloader, tgt_vocab_size=tgt_vocab_size, epochs=n_epochs, verbose=True)
    finally:
        # Close the writer even if training raises or is interrupted, so the
        # event file is flushed and its handle released.
        writer.close()

    ##########  ##########  ##########  ##########  ##########

    # Report the number of params
    num_params = sum(p.numel() for p in transformer.parameters())
    print_success(f"Training finished! The model has {num_params} parameters.")

    # Save the model to a file
    save_path = 'model_' + run_name
    os.makedirs(save_path, exist_ok=True)

    # Save the weights alone, and additionally the whole pickled module.
    torch.save(transformer.state_dict(), os.path.join(save_path, 'my_model.pt'))
    torch.save(transformer, os.path.join(save_path, 'my_model_complete.pt'))

    from google.colab import drive
    drive.mount('/content/gdrive')
    import shutil

    # Source paths
    # These absolute paths match 'runs/' + run_name and save_path only when
    # the working directory is /content.
    #src_path1 = '/content/runs/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
    #src_path2 = '/content/model_2023'
    src_path1 = '/content/runs/saturday_big_model'
    src_path2 = '/content/model_saturday_big_model'


    # Destination paths
    # NOTE(review): these destinations are fixed names that do not follow
    # run_name (they look left over from an earlier run), and copytree raises
    # FileExistsError when a destination already exists - update them before
    # copying a second run.
    dst_path1 = '/content/gdrive/My Drive/Thesis/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
    dst_path2 = '/content/gdrive/My Drive/Thesis/model_2023'

    # Copy the directories to the destination paths
    shutil.copytree(src_path1, dst_path1)
    shutil.copytree(src_path2, dst_path2)

# Infrastructure

In [None]:
# Save the current weights in the working directory and copy them to Drive.
# NOTE(review): this file holds a bare state_dict, unlike the
# {'epoch', 'model_state_dict', ...} dicts written by the checkpoint cells;
# make sure whatever loads 'model_checkpoint.pth' expects this format.
torch.save(transformer.state_dict(), 'model_checkpoint.pth')
from google.colab import drive
drive.mount('/content/gdrive')
# IPython shell escape: copies the file into the root of "My Drive".
!cp model_checkpoint.pth '/content/gdrive/My Drive/'


In [None]:
# Manual memory cleanup: alternate CUDA cache release with garbage collection.
# Uncomment the next line to drop the model reference first.
#del transformer
torch.cuda.empty_cache()
import gc
# Repeated collection passes, as written by the author.
gc.collect()
gc.collect()
gc.collect()
torch.cuda.empty_cache()
# As the cell's last expression, this return value is what the notebook displays.
gc.collect()

In [None]:

# Query TensorFlow for a GPU device name, then drop the model and release
# PyTorch's cached CUDA memory.
import tensorflow as tf
# NOTE(review): the returned name is neither stored nor printed, and this is
# not the cell's last expression, so it is presumably not shown - confirm.
tf.test.gpu_device_name()

# Raises NameError if `transformer` is not currently defined.
del transformer  # or any other variable
torch.cuda.empty_cache()

In [None]:
# Offer the run log and the saved model as browser downloads.
from google.colab import files

# NOTE(review): both paths are hard-coded and look like directories rather
# than single files - confirm files.download accepts them, or archive first.
files.download('/content/runs/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024')  # replace this path with your own file's path
files.download('/content/model_2023')

In [None]:
# Mount Google Drive at /content/gdrive so later cells can copy files into it.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:

# Copy a run's TensorBoard log directory and saved-model directory to Drive.
import shutil

# Source paths
# Hard-coded for one specific run; edit them to match the run being copied.
src_path1 = '/content/runs/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
src_path2 = '/content/model_2023'

# Destination paths
# The paths use '/content/gdrive', so Drive must be mounted there first.
dst_path1 = '/content/gdrive/My Drive/Thesis/2023/09/14_02:19:50_LengthCap:500_HPsize:medium_Batch:1024'
dst_path2 = '/content/gdrive/My Drive/Thesis/model_2023'

# Copy the directories to the destination paths
# copytree raises FileExistsError if a destination directory already exists.
shutil.copytree(src_path1, dst_path1)
shutil.copytree(src_path2, dst_path2)

# stableTranslation.py


In [None]:
# @transformerCodonConcierge.py


