<a href="https://colab.research.google.com/github/shubhamsnehil07/Test-Repository/blob/main/DNA_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd

In [5]:
df = pd.read_csv("data.csv", delimiter="\t", header=None)

In [6]:
df.head()

Unnamed: 0,0,1,2
0,"+,S10,",,tactagcaatacgcttgcgttcggtggttaagtatgtataatgcgc...
1,"+,AMPC,",,tgctatcctgacagttgtcacgctgattggtgtcgttacaatctaa...
2,"+,AROH,",,gtactagagaactagtgcattagcttatttttttgttatcatgcta...
3,"+,DEOP2,",aattgtgatgtgtatcgaagtgtgttgcggagtagatgttagaata...,
4,"+,LEU1_TRNA,",tcgataattaactattgacgaaaagctgaaaaccactagaatgcgc...,


In [7]:
# Extract the column that contains DNA (likely column 2 or 3)
sequences = df[2].dropna().str.upper().tolist()

# Filter sequences with only A/T/C/G
clean_sequences = [s for s in sequences if set(s).issubset(set("ATCG"))]


In [101]:
len(clean_sequences[1])

57

## 1. Prepare and Encode DNA Data

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

# Define vocab
vocab = {'A': 0, 'T': 1, 'C': 2, 'G': 3}
inv_vocab = {v: k for k, v in vocab.items()}
max_len = 57

def encode_sequence(seq, max_len):
    encoded = [vocab[char] for char in seq if char in vocab]
    return encoded[:max_len] + [0] * (max_len - len(encoded))

class DNADataset(Dataset):
    def __init__(self, sequences):
        self.sequences = [torch.tensor(encode_sequence(seq, max_len), dtype=torch.long) for seq in sequences]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

dataset = DNADataset(clean_sequences)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


## 2. Define Generator and Discriminator

In [17]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, vocab_size=4, hidden_dim=64, seq_len=50):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.seq_len = seq_len

    def forward(self, x, hidden=None):
        x = self.emb(x)
        out, hidden = self.rnn(x, hidden)
        logits = self.fc(out)
        return logits, hidden

    def sample(self, batch_size):
        samples = []
        input = torch.zeros((batch_size, 1), dtype=torch.long)
        hidden = None
        for _ in range(self.seq_len):
            logits, hidden = self.forward(input, hidden)
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            input = torch.multinomial(probs, num_samples=1)
            samples.append(input)
        return torch.cat(samples, dim=1)

class Discriminator(nn.Module):
    def __init__(self, vocab_size=4, hidden_dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.emb(x)
        _, h = self.rnn(x)
        return torch.sigmoid(self.fc(h.squeeze(0)))


## 3. Rollout Module for Reward Estimation

In [14]:
class Rollout:
    def __init__(self, generator, update_rate=0.9):
        self.generator = generator
        self.update_rate = update_rate

    def rollout(self, partial_seq):
        x = partial_seq
        hidden = None
        for _ in range(x.size(1), self.generator.seq_len):
            logits, hidden = self.generator(x, hidden)
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat([x, next_token], dim=1)
        return x

    def get_reward(self, x, discriminator, rollout_num=8):
        batch_size, seq_len = x.size()
        rewards = torch.zeros((batch_size, seq_len))
        for t in range(1, seq_len + 1):
            samples = []
            for _ in range(rollout_num):
                sample = self.rollout(x[:, :t])
                samples.append(sample)
            samples = torch.stack(samples).view(-1, seq_len)
            with torch.no_grad():
                preds = discriminator(samples).view(rollout_num, batch_size)
            rewards[:, t - 1] = preds.mean(0)
        return rewards


## 4. Training Functions

In [15]:
def train_generator_PG(generator, rollout, discriminator, optimizer, batch_size):
    fake_samples = generator.sample(batch_size)
    rewards = rollout.get_reward(fake_samples, discriminator)

    optimizer.zero_grad()
    logits, _ = generator(fake_samples)
    log_probs = torch.log_softmax(logits, dim=-1)

    loss = 0
    for t in range(generator.seq_len):
        log_prob = log_probs[:, t, :].gather(1, fake_samples[:, t].unsqueeze(1)).squeeze()
        loss += -log_prob * rewards[:, t]
    loss = loss.mean()
    loss.backward()
    optimizer.step()
    return loss.item()

def train_discriminator(discriminator, real_data, generator, optimizer):
    fake_data = generator.sample(real_data.size(0)).detach()
    real_labels = torch.ones(real_data.size(0), 1)
    fake_labels = torch.zeros(real_data.size(0), 1)

    optimizer.zero_grad()
    real_loss = nn.BCELoss()(discriminator(real_data), real_labels)
    fake_loss = nn.BCELoss()(discriminator(fake_data), fake_labels)
    loss = real_loss + fake_loss
    loss.backward()
    optimizer.step()
    return loss.item()


## 5. Training Loop

In [22]:
generator = Generator()
discriminator = Discriminator()
rollout = Rollout(generator)

g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

for epoch in range(100):
    for real_batch in dataloader:
        d_loss = train_discriminator(discriminator, real_batch, generator, d_optimizer)
        g_loss = train_generator_PG(generator, rollout, discriminator, g_optimizer, real_batch.size(0))

    print(f"[Epoch {epoch}] G Loss: {g_loss:.4f}, D Loss: {d_loss:.4f}")


[Epoch 0] G Loss: 31.0750, D Loss: 1.0658
[Epoch 1] G Loss: 24.9204, D Loss: 1.1442
[Epoch 2] G Loss: 18.4034, D Loss: 0.6139
[Epoch 3] G Loss: 13.8152, D Loss: 0.8339
[Epoch 4] G Loss: 8.4554, D Loss: 0.7083
[Epoch 5] G Loss: 6.3021, D Loss: 0.3417
[Epoch 6] G Loss: 4.6552, D Loss: 0.3445
[Epoch 7] G Loss: 2.5529, D Loss: 0.2649
[Epoch 8] G Loss: 1.7606, D Loss: 0.2702
[Epoch 9] G Loss: 2.0057, D Loss: 0.1509
[Epoch 10] G Loss: 1.1085, D Loss: 0.1293
[Epoch 11] G Loss: 1.5090, D Loss: 0.1502
[Epoch 12] G Loss: 1.3309, D Loss: 0.1335
[Epoch 13] G Loss: 1.5051, D Loss: 0.2098
[Epoch 14] G Loss: 0.9225, D Loss: 0.1886
[Epoch 15] G Loss: 2.0198, D Loss: 0.1597
[Epoch 16] G Loss: 1.3998, D Loss: 0.1340
[Epoch 17] G Loss: 0.8872, D Loss: 2.6478
[Epoch 18] G Loss: 1.3954, D Loss: 0.2509
[Epoch 19] G Loss: 0.9957, D Loss: 2.0993
[Epoch 20] G Loss: 0.5747, D Loss: 1.7160
[Epoch 21] G Loss: 0.2876, D Loss: 1.5054
[Epoch 22] G Loss: 0.9360, D Loss: 1.4290
[Epoch 23] G Loss: 0.2602, D Loss: 1.399

## 6. Generate Synthetic DNA

In [100]:
gen_seq = generator.sample(1)[0]
synth_dna = ''.join([inv_vocab[int(i)] for i in gen_seq])
print("Synthetic DNA:", synth_dna)
print(f'len {len(synth_dna)}')

Synthetic DNA: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
len 50
