In [22]:
import random
import networkx as nx
import numpy as np
import torch
from torch_geometric.utils.convert import from_networkx
import random
import networkx as nx
from collections import Counter, defaultdict
from torch_geometric.loader import DataLoader
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import src.utils.utils as ut


In [23]:

# Functions for generating sequences and mutations
def generate_random_genome_sequence(length):
    return ''.join(random.choice('ATCG') for _ in range(length))

def mutate_sequence(sequence, num_mutations):
    sequence = list(sequence)
    length = len(sequence)
    for _ in range(num_mutations):
        pos = random.randint(0, length - 1)
        original_base = sequence[pos]
        new_base = random.choice([b for b in 'ATCG' if b != original_base])
        sequence[pos] = new_base
    return ''.join(sequence)

def generate_genome_sequences_with_mutations(n, l, k, num_mutations):
    genome_data = []
    for i in range(n):
        original_sequence = generate_random_genome_sequence(l)
        genome_data.append([original_sequence, i])
        
        for j in range(k):
            mutated_sequence = mutate_sequence(original_sequence, num_mutations)
            genome_data.append([mutated_sequence, i])
    
    return genome_data

# Function to generate k-mers from a sequence
def generate_kmers(sequence, k):
    return [sequence[i:i+k] for i in range(len(sequence) - k + 1)]

def kmer_to_index(kmer):
    """Converts a kmer (string) to an index."""
    base_to_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    index = 0
    for char in kmer:
        index = 4 * index + base_to_index[char]
    return index

def subkmer_frequencies_in_kmer(kmer, subkmer_length):
    """Calculates the frequency of each subkmer in a kmer."""
    subkmer_counts = Counter(kmer[i:i+subkmer_length] for i in range(len(kmer) - subkmer_length + 1))
    frequencies = np.zeros(4**subkmer_length)
    for subkmer, count in subkmer_counts.items():
        index = kmer_to_index(subkmer)
        frequencies[index] = count
    return frequencies

In [24]:
# Example usage
num_classes = 5  # number of original sequences
l = 100  # length of each sequence
k = 99
num_mutations = 20
genome_sequences = generate_genome_sequences_with_mutations(num_classes, l, k, num_mutations)

In [25]:
seqs_n = len(genome_sequences)

In [26]:
graphs = []
kmer_len = 4
subkmer_len = 2
num_features = 4**subkmer_len

for seq in genome_sequences:
    G = nx.DiGraph()
    kmers = generate_kmers(seq[0], kmer_len)
    nodes = []
    for kmer in kmers:
        nodes.append((kmer, {"x": torch.as_tensor(subkmer_frequencies_in_kmer(kmer, subkmer_len)/(kmer_len-1), dtype=torch.float32)}))
    G.add_nodes_from(nodes)
    # edges = []
    transition_counts = defaultdict(int)
    for i in range(len(kmers)-1):
        transition_counts[(kmers[i], kmers[i+1])] += 1
    max_count = max(transition_counts.values())
    for key in transition_counts.keys():
        G.add_edge(key[0], key[1], weight=transition_counts[key]/max_count)
    #     edges.append((kmers[i], kmers[i+1]))
    # G.add_edges_from(edges)
    torch_graph = from_networkx(G)
    torch_graph['y'] = torch.tensor([seq[1]])
    graphs.append(torch_graph)


In [27]:
random.shuffle(graphs)
split_value = 0.8
train_dataset = graphs[:int(seqs_n*split_value)]
val_dataset = graphs[int(seqs_n*split_value):]

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [28]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_weight, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

In [29]:
def train():
    model.train()
    total_loss = 0
    correct = 0
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.weight, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
    return total_loss / len(train_loader), correct / len(train_dataset)

# Validation function
def validate(loader):
    model.eval()
    correct = 0
    for data in loader:
        with torch.no_grad():
            out = model(data.x, data.edge_index, data.weight, data.batch)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
    return correct / len(loader.dataset)

# Initialize the model, optimizer, and loss function
model = GCN(num_features=num_features, hidden_channels=32, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    train_loss, train_acc = train()
    val_acc = validate(val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')


Epoch 1, Train Loss: 1.6111, Train Acc: 0.1900, Val Acc: 0.1300
Epoch 2, Train Loss: 1.6000, Train Acc: 0.2125, Val Acc: 0.3600
Epoch 3, Train Loss: 1.5811, Train Acc: 0.2425, Val Acc: 0.4000
Epoch 4, Train Loss: 1.5305, Train Acc: 0.3350, Val Acc: 0.4000
Epoch 5, Train Loss: 1.4432, Train Acc: 0.4075, Val Acc: 0.4800
Epoch 6, Train Loss: 1.2616, Train Acc: 0.4750, Val Acc: 0.6000
Epoch 7, Train Loss: 1.1134, Train Acc: 0.6275, Val Acc: 0.5700
Epoch 8, Train Loss: 0.9425, Train Acc: 0.6450, Val Acc: 0.6500
Epoch 9, Train Loss: 0.8764, Train Acc: 0.6150, Val Acc: 0.7000
Epoch 10, Train Loss: 0.7994, Train Acc: 0.7100, Val Acc: 0.8100
Epoch 11, Train Loss: 0.6349, Train Acc: 0.7350, Val Acc: 0.7300
Epoch 12, Train Loss: 0.5668, Train Acc: 0.7300, Val Acc: 0.8100
Epoch 13, Train Loss: 0.5513, Train Acc: 0.7550, Val Acc: 0.8000
Epoch 14, Train Loss: 0.5280, Train Acc: 0.7225, Val Acc: 0.7700
Epoch 15, Train Loss: 0.4632, Train Acc: 0.7600, Val Acc: 0.8600
Epoch 16, Train Loss: 0.4363, Trai