# **DNA Sequences Generation using SeqGAN**

## **Importing Dependencies**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.distributions import Categorical
from sklearn.model_selection import train_test_split

## **Constants**

In [None]:
# Constants
SEQ_LENGTH = 56
VOCAB_SIZE = 4  # A, C, G, T
HIDDEN_DIM = 128
BATCH_SIZE = 32
NUM_EPOCHS = 100
G_PRETRAIN_EPOCHS = 50
D_PRETRAIN_EPOCHS = 50

## **Nucleotide mapping**

In [None]:
# Nucleotide mapping
NUCLEOTIDE_MAP = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
INV_NUCLEOTIDE_MAP = {v: k for k, v in NUCLEOTIDE_MAP.items()}

## **Load and preprocess data**

In [None]:
# Load and preprocess data
def load_data(file_path):
    df = pd.read_csv(file_path, header=None, names=['label', 'name', 'sequence'])
    sequences = df['sequence'].str.upper().tolist()
    # Convert sequences to integer indices
    seq_indices = [[NUCLEOTIDE_MAP[c] for c in seq if c in NUCLEOTIDE_MAP] for seq in sequences]
    # Filter sequences of correct length
    seq_indices = [seq for seq in seq_indices if len(seq) == SEQ_LENGTH]
    # One-hot encoding
    one_hot = np.zeros((len(seq_indices), SEQ_LENGTH, VOCAB_SIZE))
    for i, seq in enumerate(seq_indices):
        for j, idx in enumerate(seq):
            one_hot[i, j, idx] = 1
    return torch.tensor(one_hot, dtype=torch.float32)

In [None]:
# Constants
SEQ_LENGTH = 56
VOCAB_SIZE = 4  # A, C, G, T
NUCLEOTIDE_MAP = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

def load_data(file_path):
    """
    Load and preprocess DNA sequences from a CSV file.
    Args:
        file_path (str): Path to the CSV file.
    Returns:
        torch.Tensor: One-hot encoded sequences.
    Raises:
        ValueError: If no valid sequences are found or the file is malformed.
    """
    try:
        # Read CSV with flexible delimiter handling
        df = pd.read_csv(file_path, header=None, names=['label', 'name', 'sequence'])
        print(f"Loaded {len(df)} sequences from {file_path}")

        # Removing the tab '\t'
        seq = df['sequence'].str.upper().str.strip().tolist()
        new_seq = []
        for i in range(0, len(seq)):
          new_seq.append(seq[i][1:])

        new_seq_final = []
        for i in range(0, len(new_seq)):
          if new_seq[i][0] == '\t':
            new_seq_final.append(new_seq[i][1:])
          else:
            new_seq_final.append(new_seq[i])

        sequences = new_seq_final

        # Convert sequences to uppercase and strip whitespace
        print(f"Raw sequences count: {len(sequences)}")

        # Convert sequences to integer indices and filter invalid sequences
        seq_indices = []
        invalid_chars = set()
        skipped_sequences = 0
        for i, seq in enumerate(sequences):
            # Check if sequence is valid (not None and is a string)
            if not isinstance(seq, str):
                print(f"Skipping sequence at index {i}: Not a string (value: {seq})")
                skipped_sequences += 1
                continue
            try:
                indices = [NUCLEOTIDE_MAP[c] for c in seq if c in NUCLEOTIDE_MAP]
                # Only include sequences with valid length
                if len(indices) == SEQ_LENGTH:
                    seq_indices.append(indices)
                else:
                    print(f"Skipping sequence at index {i}: Length {len(indices)} (expected {SEQ_LENGTH}): {seq[:10]}...")
                    skipped_sequences += 1
            except KeyError as e:
                invalid_chars.add(str(e))
                print(f"Skipping sequence at index {i}: Invalid character {str(e)} in sequence: {seq[:10]}...")
                skipped_sequences += 1

        if invalid_chars:
            print(f"Invalid characters found: {invalid_chars}")
        if skipped_sequences > 0:
            print(f"Skipped {skipped_sequences} invalid or incorrect-length sequences")

        # Check if any valid sequences remain
        if not seq_indices:
            raise ValueError(f"No valid sequences found after filtering. Check input file for valid DNA sequences.")
        print(f"Valid sequences after filtering: {len(seq_indices)}")

        # One-hot encoding
        one_hot = np.zeros((len(seq_indices), SEQ_LENGTH, VOCAB_SIZE), dtype=np.float32)
        for i, seq in enumerate(seq_indices):
            for j, idx in enumerate(seq):
                one_hot[i, j, idx] = 1

        # Convert to torch tensor
        tensor_data = torch.tensor(one_hot, dtype=torch.float32)
        print(f"One-hot encoded tensor shape: {tensor_data.shape}")

        return tensor_data

    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find file: {file_path}")
    except pd.errors.EmptyDataError:
        raise ValueError(f"File {file_path} is empty or malformed")
    except Exception as e:
        raise ValueError(f"Error processing file {file_path}: {str(e)}")

## **Generator model**

In [None]:
# Generator model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.lstm = nn.LSTM(VOCAB_SIZE, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, hidden=None):
        output, hidden = self.lstm(x, hidden)
        output = self.fc(output)
        output = self.softmax(output)
        return output, hidden

## **Discriminator model**

In [None]:
# Discriminator model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv1d(VOCAB_SIZE, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(128 * 14, 256)
        self.fc2 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = x.transpose(1, 2)  # (batch, seq_len, vocab) -> (batch, vocab, seq_len)
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

## **Monte Carlo rollout for reward estimation**

In [None]:
# Monte Carlo rollout for reward estimation
def rollout(generator, seq, num_rollouts=10):
    rewards = []
    for _ in range(num_rollouts):
        current_seq = seq.clone()
        hidden = None
        for t in range(SEQ_LENGTH):
            if t < len(seq[0]):
                input_t = torch.zeros(seq.size(0), 1, VOCAB_SIZE).to(seq.device)
                input_t.scatter_(2, seq[:, t:t+1].unsqueeze(-1), 1)
            else:
                input_t = torch.zeros(seq.size(0), 1, VOCAB_SIZE).to(seq.device)
                input_t[:, :, 0] = 1  # Default to A
            probs, hidden = generator(input_t, hidden)
            dist = Categorical(probs.squeeze(1))
            next_nucleotide = dist.sample().unsqueeze(1)
            current_seq = torch.cat([current_seq[:, :t], next_nucleotide], dim=1)
        # Convert to one-hot for discriminator
        one_hot = torch.zeros(seq.size(0), SEQ_LENGTH, VOCAB_SIZE).to(seq.device)
        one_hot.scatter_(2, current_seq.unsqueeze(-1), 1)
        reward = discriminator(one_hot).detach()
        rewards.append(reward)
    return torch.mean(torch.stack(rewards), dim=0)

## **Loading Data**

In [None]:
# Loading data
data = load_data('data.csv')
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

Loaded 106 sequences from data.csv
Raw sequences count: 106
Valid sequences after filtering: 106
One-hot encoded tensor shape: torch.Size([106, 56, 4])


## **Initializing models**

In [None]:
# Initializing models
generator = Generator()
discriminator = Discriminator()
g_optimizer = optim.Adam(generator.parameters(), lr=0.001)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
criterion = nn.BCELoss()
cross_entropy = nn.CrossEntropyLoss()

## **Pretrain generator with MLE**

In [None]:
# Pretrain generator with MLE
for epoch in range(G_PRETRAIN_EPOCHS):
    total_loss = 0
    num_batches = 0
    for i in range(0, len(train_data), BATCH_SIZE):
        batch = train_data[i:i+BATCH_SIZE]
        g_optimizer.zero_grad()
        input_seq = batch[:, :-1]  # [batch_size, 56, 4]
        target_seq = batch[:, 1:]  # [batch_size, 56, 4]
        target_indices = torch.argmax(target_seq, dim=-1)  # [batch_size, 56]
        probs, _ = generator(input_seq)  # [batch_size, 56, 4]
        # Clip probabilities to avoid log(0)
        probs = torch.clamp(probs, min=1e-10, max=1.0)
        # Use CrossEntropyLoss
        loss = cross_entropy(probs.view(-1, VOCAB_SIZE), target_indices.view(-1))
        loss.backward()
        # Clip gradients to prevent explosion
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        g_optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    avg_loss = total_loss / num_batches if num_batches > 0 else float('nan')
    print(f'Generator Pretrain Epoch {epoch+1}, Loss: {avg_loss}')

Generator Pretrain Epoch 1, Loss: 1.3873674074808757
Generator Pretrain Epoch 2, Loss: 1.3863003651301067
Generator Pretrain Epoch 3, Loss: 1.3852418263753254
Generator Pretrain Epoch 4, Loss: 1.3840083678563435
Generator Pretrain Epoch 5, Loss: 1.382477839787801
Generator Pretrain Epoch 6, Loss: 1.3823171059290569
Generator Pretrain Epoch 7, Loss: 1.3816022078196208
Generator Pretrain Epoch 8, Loss: 1.3815967241923015
Generator Pretrain Epoch 9, Loss: 1.3816067377726238
Generator Pretrain Epoch 10, Loss: 1.381461461385091
Generator Pretrain Epoch 11, Loss: 1.3812662760416667
Generator Pretrain Epoch 12, Loss: 1.3811468680699666
Generator Pretrain Epoch 13, Loss: 1.3810993830362956
Generator Pretrain Epoch 14, Loss: 1.380990982055664
Generator Pretrain Epoch 15, Loss: 1.3808532158533733
Generator Pretrain Epoch 16, Loss: 1.3807512521743774
Generator Pretrain Epoch 17, Loss: 1.380648096402486
Generator Pretrain Epoch 18, Loss: 1.3805132706960042
Generator Pretrain Epoch 19, Loss: 1.3803

## **Pretrain Discriminator**

In [None]:
# Pretrain discriminator
for epoch in range(D_PRETRAIN_EPOCHS):
    for i in range(0, len(train_data), BATCH_SIZE):
        batch = train_data[i:i+BATCH_SIZE]
        d_optimizer.zero_grad()
        # Real sequences
        real_labels = torch.ones(batch.size(0), 1)
        real_preds = discriminator(batch)
        real_loss = criterion(real_preds, real_labels)
        # Generate fake sequences
        noise = torch.randint(0, VOCAB_SIZE, (batch.size(0), SEQ_LENGTH)).long()
        fake_seq = torch.zeros(batch.size(0), SEQ_LENGTH, VOCAB_SIZE)
        fake_seq.scatter_(2, noise.unsqueeze(-1), 1)
        fake_labels = torch.zeros(batch.size(0), 1)
        fake_preds = discriminator(fake_seq)
        fake_loss = criterion(fake_preds, fake_labels)
        loss = (real_loss + fake_loss) / 2
        loss.backward()
        d_optimizer.step()
    print(f'Discriminator Pretrain Epoch {epoch+1}, Loss: {loss.item()}')

Discriminator Pretrain Epoch 1, Loss: 0.6892892122268677
Discriminator Pretrain Epoch 2, Loss: 0.6638418436050415
Discriminator Pretrain Epoch 3, Loss: 0.6594439744949341
Discriminator Pretrain Epoch 4, Loss: 0.5913381576538086
Discriminator Pretrain Epoch 5, Loss: 0.5629907846450806
Discriminator Pretrain Epoch 6, Loss: 0.5201982259750366
Discriminator Pretrain Epoch 7, Loss: 0.4893595576286316
Discriminator Pretrain Epoch 8, Loss: 0.4469603896141052
Discriminator Pretrain Epoch 9, Loss: 0.47076040506362915
Discriminator Pretrain Epoch 10, Loss: 0.374267041683197
Discriminator Pretrain Epoch 11, Loss: 0.3711191415786743
Discriminator Pretrain Epoch 12, Loss: 0.3366257846355438
Discriminator Pretrain Epoch 13, Loss: 0.3141634464263916
Discriminator Pretrain Epoch 14, Loss: 0.3903343081474304
Discriminator Pretrain Epoch 15, Loss: 0.25405827164649963
Discriminator Pretrain Epoch 16, Loss: 0.26806288957595825
Discriminator Pretrain Epoch 17, Loss: 0.253582626581192
Discriminator Pretrain

## **Adversial Training**

In [None]:
# Adversarial Training
for epoch in range(NUM_EPOCHS):
    for i in range(0, len(train_data), BATCH_SIZE):
        batch = train_data[i:i+BATCH_SIZE]
        # Train discriminator
        d_optimizer.zero_grad()
        real_labels = torch.ones(batch.size(0), 1)
        real_preds = discriminator(batch)
        real_loss = criterion(real_preds, real_labels)
        # Generate sequences
        noise = torch.randint(0, VOCAB_SIZE, (batch.size(0), SEQ_LENGTH)).long()
        fake_seq = torch.zeros(batch.size(0), SEQ_LENGTH, VOCAB_SIZE)
        fake_seq.scatter_(2, noise.unsqueeze(-1), 1)
        probs, _ = generator(fake_seq)
        dist = Categorical(probs)
        gen_seq = dist.sample()
        fake_seq = torch.zeros(batch.size(0), SEQ_LENGTH, VOCAB_SIZE)
        fake_seq.scatter_(2, gen_seq.unsqueeze(-1), 1)
        fake_labels = torch.zeros(batch.size(0), 1)
        fake_preds = discriminator(fake_seq)
        fake_loss = criterion(fake_preds, fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        # Train generator with REINFORCE
        g_optimizer.zero_grad()
        rewards = rollout(generator, gen_seq)
        probs, _ = generator(fake_seq)
        dist = Categorical(probs)
        log_probs = dist.log_prob(gen_seq)
        g_loss = -torch.mean(log_probs * rewards)
        g_loss.backward()
        g_optimizer.step()
    print(f'Adversarial Epoch {epoch+1}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

Adversarial Epoch 1, D Loss: 0.12251493334770203, G Loss: 0.10263233631849289
Adversarial Epoch 2, D Loss: 0.13568075001239777, G Loss: 0.10845666378736496
Adversarial Epoch 3, D Loss: 0.04440082609653473, G Loss: 0.013211316429078579
Adversarial Epoch 4, D Loss: 0.04993482306599617, G Loss: 0.12127697467803955
Adversarial Epoch 5, D Loss: 0.042251456528902054, G Loss: 0.015261474996805191
Adversarial Epoch 6, D Loss: 0.047793373465538025, G Loss: 0.008318142034113407
Adversarial Epoch 7, D Loss: 0.01615300215780735, G Loss: 0.04711707681417465
Adversarial Epoch 8, D Loss: 0.005504402332007885, G Loss: 0.08652516454458237
Adversarial Epoch 9, D Loss: 0.00948171317577362, G Loss: 0.011658132076263428
Adversarial Epoch 10, D Loss: 0.03556269407272339, G Loss: 0.0048830704763531685
Adversarial Epoch 11, D Loss: 0.009837744757533073, G Loss: 0.030247816815972328
Adversarial Epoch 12, D Loss: 0.024226395413279533, G Loss: 0.020878203213214874
Adversarial Epoch 13, D Loss: 0.0097550181671977

## **Generate sample sequences**

In [None]:
# Generate sample sequences
def generate_sequences(generator, num_samples=10):
    generator.eval()
    sequences = []
    with torch.no_grad():
        noise = torch.randint(0, VOCAB_SIZE, (num_samples, SEQ_LENGTH)).long()
        seq = torch.zeros(num_samples, SEQ_LENGTH, VOCAB_SIZE)
        seq.scatter_(2, noise.unsqueeze(-1), 1)
        probs, _ = generator(seq)
        dist = Categorical(probs)
        gen_seq = dist.sample()
        for seq in gen_seq:
            seq_str = ''.join([INV_NUCLEOTIDE_MAP[idx.item()] for idx in seq])
            sequences.append(seq_str)
    return sequences

## **Generate and print sequences**

In [None]:
# Generate and print sequences
generated = generate_sequences(generator)
for seq in generated:
    print(f'Generated Sequence: {seq}')

Generated Sequence: TAATTAGTGCTGTATAACGTCTGAGAAATGATACGTGGCTAAGATGGATAGAAGGT
Generated Sequence: CAATCTTGGAAAATGGAACAGGAAAAGGAGTTTAGTCATTAATTTAACAAGACAAA
Generated Sequence: CAAAGGATTGGTTTAGTGTCAGGATACTGGTAAATAAAGTGAGATACACTAAGGGA
Generated Sequence: GGAGATAAGATTTATGTGGTGCGAATAGATATGAGGTGAGAAGAGAGGAGAGTTCA
Generated Sequence: CCAGCAAATGCTTCGATTGGTAACAGGGTAAGTTACTTCTAAAGGCTGAAAAGTGA
Generated Sequence: TCAAGCAAGAGATATGCAATTTCTGTTGATAGAGAATGGTGGACCGTGGTTTTGTT
Generated Sequence: GATAGTAATAAAGTTAAAGGATTAGGGGAACTAAGTGTTTGTCGATAACGGGTATA
Generated Sequence: TCGAAGAATGAGAAGCAACTGGATTATTAGGCATCACAATAAATAGTGGAAGGAGA
Generated Sequence: GTAATCGGGGGCCTTAGCCGATCGGAAAATAAAGTGGTACATAATTAATATCATAA
Generated Sequence: GCCTAAAAGGTCAGATTTGGTAGAGGGAGTTTATATTTGCCACGAATGAATGTGTT


In [None]:
torch.save(generator, 'generator.pth')

In [None]:
torch.save(discriminator, 'discriminator.pth')

In [None]:
torch.save(g_optimizer, 'g_optimizer.pth')

In [None]:
torch.save(g_optimizer, 'd_optimizer.pth')

In [None]:
torch.save(generator.state_dict, 'generator_state_dict.pth')
torch.save(discriminator.state_dict, 'discriminator_state_dict.pth')
torch.save(g_optimizer.state_dict, 'g_optimizer_state_dict.pth')
torch.save(d_optimizer.state_dict, 'd_optimizer_state_dict.pth')

**------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------**