In [33]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from torch.nn.utils.rnn import pad_sequence
import random
from torch.utils.data import random_split
# Step 1: Load the data
lineage_file = "./dataset1.csv"
text_files_folder = "./text_files"

# Step 2: Load all the text files to converted_df
# Each file holds one encoded sequence; the file name without its extension is
# the sequence ID. The ID is wrapped in literal double quotes because the IDs
# in the lineage CSV carry them too (see the printed lineage_df).
converted_data = []
# sorted() makes the row order reproducible; os.listdir() order is arbitrary.
for file_path in sorted(os.listdir(text_files_folder)):
    full_path = os.path.join(text_files_folder, file_path)
    # Skip hidden entries and sub-directories (e.g. Jupyter's
    # .ipynb_checkpoints), which are not sequence files and would make
    # open() fail or add bogus rows.
    if file_path.startswith(".") or not os.path.isfile(full_path):
        continue
    file_name = os.path.splitext(file_path)[0]
    with open(full_path, "r") as file:
        content = file.read().strip()
    converted_data.append(['"' + file_name + '"', content])

converted_df = pd.DataFrame(converted_data, columns=["ID", "Enc_Sequence"])
# Length of the longest encoded sequence.
max_length = converted_df["Enc_Sequence"].str.len().max()
####################################################################################

# Step 3: Load lineage file as DataFrame
lineage_df = pd.read_csv(lineage_file)
print(lineage_df)
# Left-join so every sequence keeps its row. Duplicate IDs are dropped from the
# right-hand side first: otherwise the merge would emit extra rows and the
# index-aligned assignment below would attach labels to the wrong sequences.
merged_df = pd.merge(
    converted_df,
    lineage_df.drop_duplicates(subset="ID"),
    on="ID",
    how="left",
)
converted_df["lineage"] = merged_df["lineage"]

# Sequences without a lineage cannot be mapped to a class id later on, so make
# them visible here instead of failing deep inside the training loop.
num_unlabeled = int(converted_df["lineage"].isna().sum())
if num_unlabeled:
    print(f"Warning: {num_unlabeled} sequence(s) have no lineage in {lineage_file}")

print(converted_df)


# Map every distinct lineage label to a contiguous integer class id.
lineages = lineage_df["lineage"].unique()
num_classes = len(lineages)
lineage_to_id = {lineage: i for i, lineage in enumerate(lineages)}

#####################################################################################

# Step 4: Prepare the data for training
class ViralSequencesDataset(Dataset):
    """Dataset of digit-encoded sequences paired with lineage class ids.

    Args:
        converted_df: DataFrame with an "Enc_Sequence" column (a string of
            digit characters per row) and a "lineage" column.
        max_length: Length every sequence is padded to.
        lineage_to_id: Mapping from lineage label to integer class id.
    """

    def __init__(self, converted_df, max_length, lineage_to_id):
        self.data = converted_df
        self.max_length = max_length
        self.lineage_to_id = lineage_to_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(padded_sequence, lineage_id, original_length)`` for a row.

        Raises:
            ValueError: If the sequence contains a non-digit character.
            KeyError: If the row's lineage is not in ``lineage_to_id``.
        """
        # Fetch the row once instead of once per column.
        row = self.data.iloc[index]
        sequence = row["Enc_Sequence"]

        # Decode the digit characters in one vectorised step instead of
        # converting a Python list of single characters element by element.
        # Non-ASCII text makes encode() raise UnicodeEncodeError, which is a
        # ValueError subclass.
        codes = np.frombuffer(sequence.encode("ascii"), dtype=np.uint8)
        enc_sequence = codes - np.uint8(ord("0"))
        # Characters below "0" wrap around to large values in uint8, so a
        # single upper-bound check catches everything that is not a digit.
        if enc_sequence.size and enc_sequence.max() > 9:
            raise ValueError(
                f"Sequence at index {index} contains non-digit characters"
            )

        # Pad the sequence with zeros
        padded_sequence = nn.functional.pad(
            torch.tensor(enc_sequence),
            pad=(0, self.max_length - len(enc_sequence)),
        )

        lineage = row["lineage"]
        try:
            lineage_id = self.lineage_to_id[lineage]
        except KeyError:
            raise KeyError(
                f"Row {index} has lineage {lineage!r}, which is not in lineage_to_id"
            ) from None

        return padded_sequence, lineage_id, len(enc_sequence)  # Return sequence length

    def __str__(self):
        dataset_info = f"ViralSequencesDataset\nDataset size: {len(self)}\nMax length: {self.max_length}"
        data_info = f"Data:\n{self.data.head()}"
        return f"{dataset_info}\n\n{data_info}"


# Wrap the labelled sequences in a Dataset and show a short summary of it.
dataset = ViralSequencesDataset(converted_df, max_length, lineage_to_id)
print(dataset)
# Set the batch size and split ratio
batch_size = 64
train_ratio = 0.8

train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
# A seeded generator keeps the train/test membership identical across runs, so
# a model reloaded from disk is never evaluated on samples it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)


# Create data loaders for training and testing
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
#####################################################################################

# Step 5: Define the model architecture
class LineageClassifier(nn.Module):
    """Fully connected classifier: three hidden ReLU layers and a linear head.

    Args:
        input_size: Number of features per input sample.
        hidden_size: Width of each hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Attribute names are part of the state-dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, length=None):
        """Return the class logits for a batch.

        Args:
            x: Input batch; it is cast to float before the first layer.
            length: Unused by the network. Accepted (and optional) so callers
                that pass per-sample lengths and callers that do not both work.
        """
        out = self.fc1(x.float())
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        out = self.relu(out)
        out = self.fc4(out)
        return out
#####################################################################################

# Step 6: Train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Set the sequence length (plain int so it can be used as a layer size)
sequence_length = int(max_length)

# One output logit per lineage in the label mapping. Deriving this from the
# data keeps the model head and the targets consistent for any lineage file.
num_classes = len(lineage_to_id)

Hidden_size = 512

# Initialize the model
model = LineageClassifier(input_size=sequence_length, hidden_size=Hidden_size, num_classes=num_classes)
model = model.to(device)

# Set the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Set the number of epochs
num_epochs = 5

# Initialize variables for early stopping
# NOTE: these are placeholders only; the loop below does not implement early
# stopping yet (there is no validation pass to drive it).
best_loss = float('inf')
patience = 3
counter = 0

# Number of batches aggregated into each progress line.
report_every = 10

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    seen_samples = 0
    seen_batches = 0

    for batch_idx, (inputs, targets, lengths) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs, lengths)  # Pass lengths to the model

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        correct_predictions += (predicted == targets).sum().item()
        # Count the samples actually seen: the last batch of an epoch can be
        # smaller than batch_size, so "batches * batch size" would be wrong.
        seen_samples += targets.size(0)
        seen_batches += 1

        # Report every `report_every` batches, and once more at the end of the
        # epoch so the trailing batches are not silently dropped from the log.
        is_last_batch = batch_idx + 1 == len(train_loader)
        if seen_batches == report_every or is_last_batch:
            batch_loss = running_loss / seen_batches
            accuracy = 100 * correct_predictions / seen_samples
            print(f"Epoch [{epoch+1}/{num_epochs}] - Batch [{batch_idx+1}/{len(train_loader)}] - Loss: {batch_loss:.4f} - Accuracy: {accuracy:.2f}%")
            running_loss = 0.0
            correct_predictions = 0
            seen_samples = 0
            seen_batches = 0

#####################################################################################

# Step 7: Save the trained model
torch.save(model.state_dict(), "./lineage_classifier.pth")


              ID    lineage
0     "ON414702"  "BA.2.31"
1     "ON056475"  "BA.2.31"
2     "ON203732"  "BA.2.31"
3     "OW275801"  "BA.2.31"
4     "ON914730"  "BA.2.31"
...          ...        ...
4995  "OM122217"  "BA.1.17"
4996  "OV662626"  "BA.1.17"
4997  "OM200680"  "BA.1.17"
4998  "OW742696"  "BA.1.17"
4999  "OW120613"  "BA.1.17"

[5000 rows x 2 columns]
              ID                                       Enc_Sequence    lineage
0     "FR990446"  0000010000000101000001010000011111110110001101...  "B.1.258"
1     "FR991388"  1010111111001100010111110101010010101100000100...  "B.1.258"
2     "FR991439"  0000010100000111111101100011011101111110110010...  "B.1.258"
3     "HG994296"  0000000000000000000000000000000000000001010010...  "B.1.258"
4     "HG994312"  0000000000000000000000000000000000000000000000...  "B.1.258"
...          ...                                                ...        ...
4995  "OX273314"  0000000000000000000000000000000000000000000000...   "BA.5.1"
4996  "

In [34]:

#####################
# WORKING TEST
#####################
# Step 8: Test the model
def test_model(model, test_loader, device):
    model.eval()
    total_samples = 0
    correct_predictions = 0

    with torch.no_grad():
        for inputs, targets, lengths in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs, lengths)
            _, predicted = torch.max(outputs.data, 1)

            total_samples += targets.size(0)
            correct_predictions += (predicted == targets).sum().item()

    accuracy = 100 * correct_predictions / total_samples
    print(f"Test Accuracy: {accuracy:.2f}%")

# Call the test function: run the trained model over the test_loader batches
# on the selected device and print the resulting accuracy.
test_model(model, test_loader, device)

Test Accuracy: 76.30%


In [16]:

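#####################
# NOT WORKING RIGHT NOW
# Why: each batch from the loader is an (inputs, targets, lengths) triple (see
# the printed batch in this cell's output), but the loop below binds the whole
# triple to `inputs` and calls torch.tensor() on it, which is the most likely
# source of the TypeError shown. It also replaces the real targets with ones
# and calls model(inputs) without the lengths argument.
#####################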
# # Load the saved model state dict
# model = LineageClassifier(input_size=sequence_length, hidden_size=Hidden_size, num_classes=num_classes)
# model.load_state_dict(torch.load("./lineage_classifier.pth"))
# model = model.to(device)

# # Set the model to evaluation mode
# model.eval()

# # Create the test DataLoader
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# # Define evaluation metrics
# test_loss = 0.0
# correct_predictions = 0
# total_predictions = 0

# # Perform model evaluation
# with torch.no_grad():
#     for batch_idx, inputs in enumerate(test_loader):
#         print(inputs)
#         inputs = torch.tensor(inputs).to(device)
#         targets = torch.ones(inputs.size(0)).long().to(device)

#         # Forward pass
#         outputs = model(inputs)

#         # Compute the loss
#         loss = criterion(outputs, targets)
#         test_loss += loss.item()

#         # Calculate accuracy
#         _, predicted = torch.max(outputs.data, 1)
#         total_predictions += targets.size(0)
#         correct_predictions += (predicted == targets).sum().item()

# # Calculate average test loss and accuracy
# avg_test_loss = test_loss / len(test_loader)
# test_accuracy = 100.0 * correct_predictions / total_predictions

# # Print the test results
# print("Test Loss: {:.4f}".format(avg_test_loss))
# print("Test Accuracy: {:.2f}%".format(test_accuracy))


[tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.uint8), tensor([1, 1, 6, 7, 9, 2, 0, 5, 6, 0, 6, 0, 1, 1, 3, 2, 4, 6, 4, 1, 6, 6, 0, 6,
        9, 8, 7, 3, 9, 4, 5, 6, 9, 2, 6, 3, 1, 3, 8, 9, 2, 9, 7, 9, 8, 4, 3, 1,
        0, 5, 9, 8, 2, 1, 6, 3, 0, 9, 0, 4, 2, 0, 0, 9]), tensor([59780, 59780, 59740, 59416, 59654, 59416, 59326, 59538, 59688, 58596,
        59436, 59620, 59658, 59780, 59780, 59624, 59690, 59436, 58094, 59598,
        59688, 59688, 59752, 59436, 59746, 59794, 59792, 59780, 59540, 59638,
        58980, 59688, 59734, 59564, 59688, 59780, 59768, 59780, 59770, 59746,
        59690, 59242, 59554, 59740, 59794, 59262, 59780, 59592, 59700, 59534,
        59638, 59798, 59700, 59540, 57092, 59292, 59370, 59554, 59448, 59062,
        59734, 59448, 59448, 59746])]


TypeError: only integer tensors of a single element can be converted to an index