In [1]:
# Importing modules
import pandas as pd
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

In [2]:
# Custom dataset
class datasets:
    
    def __init__(self, data):
        self.items = data.to_numpy()
        self.length = seq_data.shape[0]

    @staticmethod
    def OneHotEncoder(seqArray):
        seq_lis = list(seqArray)
        mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N':4}
        int_seq = [mapping[base] for base in seq_lis]
        encoded_seq = (np.eye(len(mapping))[int_seq]).flatten()
        return encoded_seq
        
    def __getitem__(self, index):
        seq, label = self.items[index]
        seq = self.OneHotEncoder(seq)
        seq_tensor, label_tensor = torch.tensor(seq).float(), torch.tensor(label)
        return seq_tensor, label_tensor

    def __len__(self):
        return self.length

In [3]:
# AutoEncoder Model
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Bottleneck layer
        self.hidden = 100

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(755, 256),
            nn.ReLU(),
            nn.Linear(256, self.hidden),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 755),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded_x = self.encoder(x)
        decoded_x = self.decoder(encoded_x)
        return encoded_x, decoded_x

In [4]:
# Data
seq_data = pd.read_csv("sequence_data.csv")

seq_data

Unnamed: 0,sequence,class
0,AATGTACAGTATTGCGTTTTGGAAAGAGTCTGGATTTTTAGGGCTC...,1
1,AATGTACAGTATTGCGTTTTGCCTCCACCTCATTCCAGGCCTAAGA...,1
2,CACCGGTGGGAGATTGGAGTCCTAGCCCGACTCGCCGGGCAGAGCG...,1
3,AATGTACAGTATTGCGTTTTGCTGTGCCAGGGACCTTACCTTATAC...,1
4,TTAATCGCGTCCATTGAAGTCCTCTACCGTGCAGCTCATCACGCAG...,1
...,...,...
99995,AATGTACAGTATTGCGTTTTGGACTGTTTGGGAGTTGATGACCTTT...,1
99996,AATGTACAGTATTGCGTTTTGCCAGCTGCTCAGGAGTCATGCTTAG...,1
99997,ATTGTGAATTAAATTGGAGTCCTGCCATCGGAACTGCTGTCTGCAT...,1
99998,AATGTACAGTATTGCGTTTTGAGAGTAAAGTAGATGATGGAAATAT...,1


In [5]:
# Dataset
train_datasets = datasets(seq_data)

# Dataloader
train_dataloader = DataLoader(train_datasets, batch_size=32)

# Device
device = torch.device("cpu")

# Model 
model = AutoEncoder()
model = model.to(device)

# Optimizer 
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Loss function
criterion = nn.MSELoss()

In [6]:
# Training loop
for epoch in range(10):
    final_loss = 0
    model.train()
    for batch in train_dataloader:
        optimizer.zero_grad()
        X, _ = batch
        X = X.to(device)
        X_encoded, X_decoded = model(X)
        loss = criterion(X_decoded, X)
        loss.backward()
        optimizer.step()
        final_loss += loss.item()
    print(f"Epoch: {epoch+1} | Loss: {final_loss/len(train_dataloader)}")

Epoch: 1 | Loss: 0.07601029044032097
Epoch: 2 | Loss: 0.06104164396047592
Epoch: 3 | Loss: 0.05885224168896675
Epoch: 4 | Loss: 0.057776012491583824
Epoch: 5 | Loss: 0.057172994408607486
Epoch: 6 | Loss: 0.05672843850195408
Epoch: 7 | Loss: 0.05638966010928154
Epoch: 8 | Loss: 0.05608514138042927
Epoch: 9 | Loss: 0.055840651872754096
Epoch: 10 | Loss: 0.05559725811779499


In [7]:
# Function for converting tensor to sequence
def tensor2seq(tensor):
    tensor = tensor.int()
    tensor = tensor.reshape(151, 5).numpy()
    mapping = {0: 'A', 1: 'C', 2: 'G', 3: 'T', 4: 'N'}
    int_seq = np.argmax(tensor, axis=1)
    seq = ''.join([mapping[i] for i in int_seq])
    return seq

# Test data
test_data = train_datasets[0][0].float()

# Prediction
model.eval()
encoded, decoded = model(test_data)

# Print original and generated sequence
print(f"Original Sequence: {tensor2seq(test_data)}\n\nGenerated Sequence: {tensor2seq(decoded)}")

Original Sequence: AATGTACAGTATTGCGTTTTGGAAAGAGTCTGGATTTTTAGGGCTCATACTATCCTCCGTGGTCATGCTCCAATAAATTCACTGCTTTGTGGCGCGACCCTTAGGTATTCTGCATTTTCAGCTGTGGAGCCCTTAAAGATGCCATTTGGCT

Generated Sequence: AATGTACAGTAATGCGTTTTGGAAAGAGTCTAGAATATAAGGGCTCATACAATCCAAAGAAAACAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
