# Use the model checkpoint to get the latest model and the misclassified instances

In [26]:
import re
import networkx as nx
from community import community_louvain
from collections import defaultdict
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F


In [40]:

# Number of samples per batch handed to the DataLoader.
batch_size = 1024


# epoch 4, accuracy 0.3099
# epoch 7, accuracy 0.4104
# epoch 12, accuracy 0.5020
# epoch 68, accuracy 0.7192
# epoch 26, accuracy 0.6027
# epoch 95, accuracy 0.7410



# 10k 10k
# epoch 62, accuracy 0.8551
#threshold,sample_size,epoch_to_load = 10000,10000,62


In [46]:




# Month range that is embedded in the sampled-data file name.
start_month = '2020-07'
end_month = '2021-05'

# NOTE(review): threshold and sample_size are not assigned in the visible
# cells above (the only assignment is commented out) - presumably they come
# from earlier notebook state; confirm before running top to bottom.
data_path = f'data/genomeACTGbases_sample_{threshold}_{sample_size}_{start_month}---{end_month}.csv'
print(data_path)

# Load the sampled sequences and their labels into a DataFrame.
data = pd.read_csv(data_path)




class GenomeDataset(Dataset):
    """Dataset that one-hot encodes sequences lazily, one item at a time.

    Args:
        sequences: Indexable collection of sequences (iterables of symbols).
        labels: Indexable collection of labels, aligned with ``sequences``.
        base_to_index: Mapping from symbol to its one-hot column index.
        label_to_index: Optional mapping from string label to class index.
            When omitted (or empty) it is derived from ``labels``.
    """

    def __init__(self, sequences, labels, base_to_index, label_to_index=None):
        self.sequences = sequences
        self.labels = labels
        self.base_to_index = base_to_index
        # A falsy mapping (None or empty) falls back to one built from labels.
        self.label_to_index = label_to_index or self._generate_label_to_index()

    def _generate_label_to_index(self):
        """Return ``{label: index}`` over the distinct labels in sorted order."""
        # Sorting makes the mapping independent of the order labels appear in.
        unique_labels = sorted(set(self.labels))
        return {label: idx for idx, label in enumerate(unique_labels)}

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(encoded_sequence, label_tensor)`` for item ``idx``.

        String labels are translated through ``label_to_index``; any other
        label value is used as the class index unchanged.
        """
        sequence = self.sequences[idx]
        label = self.labels[idx]
        encoded_sequence = self.one_hot_encode(sequence)
        label_index = self.label_to_index[label] if isinstance(label, str) else label
        return encoded_sequence, torch.tensor(label_index, dtype=torch.long)

    def one_hot_encode(self, sequence):
        """Return a ``(len(sequence), len(base_to_index))`` float32 one-hot tensor.

        Symbols missing from ``base_to_index`` are left as all-zero rows.
        """
        encoded = torch.zeros((len(sequence), len(self.base_to_index)), dtype=torch.float32)

        # Collect the coordinates first and write them with one indexed
        # assignment; a tensor write per symbol is very slow on long sequences.
        rows = []
        cols = []
        for position, base in enumerate(sequence):
            if base in self.base_to_index:
                rows.append(position)
                cols.append(self.base_to_index[base])

        if rows:
            encoded[torch.tensor(rows), torch.tensor(cols)] = 1
        return encoded


# NOTE(review): threshold, sample_size, n_filter and epoch_to_load are not
# assigned in the visible cells above this point (n_filter is only set in a
# later cell) - presumably they come from earlier notebook state; confirm.
model_name = f'CNNv5_{threshold}_{sample_size}'
checkpoint_path = f'models/MCC_{model_name}_{n_filter}/MCC_{model_name}_epoch{epoch_to_load}.pth'

# map_location forces every tensor in the checkpoint onto the CPU.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

print('checkpoint_path',checkpoint_path)
print('model_name',model_name)

# Column of `data` that holds the class label.
target = 'region'

# Symbols that get their own one-hot column; indices follow sorted order.
bases = {'*', '-', 'A', 'C', 'G', 'T'}
base_to_index = dict(zip(sorted(bases), range(len(bases))))

# Hold out 40% of the rows, then cut that part in half: 60/20/20 split.
X_train, X_temp, y_train, y_temp = train_test_split(
    data['sequence'], data[target], test_size=0.4, random_state=42
)
X_test, X_validation, y_test, y_validation = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

# NOTE(review): no label_to_index is passed, so the mapping is derived from
# the test labels only. If a class were missing from the test split the
# indices would shift - verify it matches the mapping used for training.
test_dataset = GenomeDataset(X_test.tolist(), y_test.tolist(), base_to_index)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

data/genomeACTGbases_sample_3000_3000_2020-07---2021-05.csv
models/MCC_CNNv5_3000_3000_LOO_England_64/MCC_CNNv5_3000_3000_LOO_England_epoch41.pth
CNNv5_3000_3000_LOO_England


In [47]:


    
    
    
class CNNModelV5(nn.Module):
    """Six Conv1d + BatchNorm1d + ReLU stages, global max pool, dropout, linear head.

    Args:
        input_channels: Number of channels of the input to the first convolution.
        n_filter: Number of output channels of every convolution.
        n_class: Number of outputs of the final linear layer.
    """

    def __init__(self, input_channels, n_filter, n_class):
        super().__init__()
        # Kernel size of conv1..conv6. Layers are registered as conv<i>/bn<i>
        # in this order so the attribute names and state_dict keys stay fixed.
        in_channels = input_channels
        for stage, kernel_size in enumerate((3, 4, 5, 3, 3, 3), start=1):
            setattr(self, f'conv{stage}', nn.Conv1d(in_channels, n_filter, kernel_size=kernel_size))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(n_filter))
            in_channels = n_filter
        # Output size 1 collapses the length dimension to a single value per channel.
        self.maxpool = nn.AdaptiveMaxPool1d(1)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(n_filter, n_class)

    def forward(self, x):
        """Run ``x`` through all stages and return the linear layer's output."""
        for stage in range(1, 7):
            conv = getattr(self, f'conv{stage}')
            norm = getattr(self, f'bn{stage}')
            x = F.relu(norm(conv(x)))
        pooled = self.maxpool(x)
        # Flatten everything after the batch dimension before the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(self.dropout(flat))
      
  
      

# One output per distinct value in the 'region' column of the loaded data.
n_class = len(set(data['region']))
# Channels fed to the first convolution.
input_channels = 6
# Number of filters used by every convolution.
n_filter = 64
cnn_model = CNNModelV5(input_channels, n_filter=n_filter, n_class=n_class)


class CNNModel(nn.Module):
    def __init__(self, cnn_model):
        super(CNNModel, self).__init__()
        self.cnn_model = cnn_model

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.cnn_model(x)

        return x


In [48]:
# Wrap the CNN so inputs are permuted before reaching it, restore the weights
# stored under 'model_state_dict' in the loaded checkpoint, then switch the
# model to evaluation mode.
model = CNNModel(cnn_model) 
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


CNNModel(
  (cnn_model): CNNModelV5(
    (conv1): Conv1d(6, 64, kernel_size=(3,), stride=(1,))
    (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(64, 64, kernel_size=(4,), stride=(1,))
    (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv1d(64, 64, kernel_size=(5,), stride=(1,))
    (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv4): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
    (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv5): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
    (bn5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv6): Conv1d(64, 64, kernel_size=(3,), stride=(1,))
    (bn6): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (maxpool): AdaptiveMaxPool1d(output_size=1)
    (dropout): D

In [49]:
# Disabled cell: the code below is wrapped in a string literal, so it is not
# executed. It would compute the mean cross-entropy loss per batch and the
# accuracy over test_loader.
'''
criterion = torch.nn.CrossEntropyLoss()

test_loss = 0.0
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        # No need to move inputs and labels to device (GPU or CPU)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

avg_test_loss = test_loss / len(test_loader)
test_accuracy = correct_predictions / total_samples

print(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')
'''

"\ncriterion = torch.nn.CrossEntropyLoss()\n\ntest_loss = 0.0\ncorrect_predictions = 0\ntotal_samples = 0\n\nwith torch.no_grad():\n    for inputs, labels in test_loader:\n        # No need to move inputs and labels to device (GPU or CPU)\n        outputs = model(inputs)\n        loss = criterion(outputs, labels)\n        test_loss += loss.item()\n        _, predicted = torch.max(outputs, 1)\n        correct_predictions += (predicted == labels).sum().item()\n        total_samples += labels.size(0)\n\navg_test_loss = test_loss / len(test_loader)\ntest_accuracy = correct_predictions / total_samples\n\nprint(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')\n"

In [50]:
# Invert the dataset's label -> index mapping so class indices can be
# reported by name.
index_to_label = {index: label for label, index in test_dataset.label_to_index.items()}
index_to_label

{0: 'Australia',
 1: 'Belgium',
 2: 'Canada',
 3: 'Denmark',
 4: 'France',
 5: 'Germany',
 6: 'Iceland',
 7: 'India',
 8: 'Italy',
 9: 'Japan',
 10: 'Luxembourg',
 11: 'Netherlands',
 12: 'Portugal',
 13: 'Scotland',
 14: 'SouthAfrica',
 15: 'Spain',
 16: 'Switzerland',
 17: 'USA',
 18: 'Wales'}

# hard and soft classification

In [51]:
# Hard-classification pass over the test set: accuracy, the list of
# misclassified samples, and a count-based confusion matrix.

# One (predicted_name, true_name) tuple per misclassified sample.
incorrectly_classified = []

total_samples = 0
correct_predictions = 0

# confusion_matrix[p, t] = number of samples of true class t predicted as p.
# NOTE(review): every sample adds 1 here (hard counts), yet this matrix is
# later saved under a file name starting with "MCC_softmax_" - confirm that
# accumulating softmax probabilities was not the intent.
confusion_matrix = torch.zeros(n_class, n_class)

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(probabilities, 1)

        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

        # Boolean-mask selection always yields 1-D tensors. The previous
        # `.nonzero().squeeze()` collapsed to a 0-d tensor when a batch had
        # exactly one mistake, and iterating a 0-d tensor raises TypeError.
        wrong = predicted != labels
        for pred_idx, true_idx in zip(predicted[wrong].tolist(), labels[wrong].tolist()):
            incorrectly_classified.append((index_to_label[pred_idx], index_to_label[true_idx]))

        # Add 1 to every (predicted, true) cell in one call; accumulate=True
        # sums the contributions of repeated index pairs.
        confusion_matrix.index_put_(
            (predicted, labels),
            torch.ones(labels.size(0), dtype=confusion_matrix.dtype),
            accumulate=True,
        )

prediction_accuracy = correct_predictions / total_samples

print(f'Test set prediction accuracy: {prediction_accuracy:.4f}')


Test set prediction accuracy: 0.7278


In [52]:
# Disabled cell: the code below is wrapped in a string literal, so it is not
# executed. Unlike the active evaluation cell, this variant adds each
# sample's softmax probabilities (not a count of 1) into confusion_matrix,
# indexed as [predicted_label][true_label].
'''
incorrectly_classified = []

total_samples = 0
correct_predictions = 0

# Soft classification
confusion_matrix = torch.zeros(n_class, n_class)


with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        
        # Hard classification
        # Convert numerical labels to country names
        predicted_countries = [index_to_label[idx.item()] for idx in predicted]
        true_countries = [index_to_label[idx.item()] for idx in labels]

        total_samples += len(labels)
        correct_predictions += (predicted == labels).sum().item()

        for pred_country, true_country in zip(predicted_countries, true_countries):
            if pred_country != true_country:
                incorrectly_classified.append([pred_country, true_country])
        
        # Soft classification
        probabilities = torch.softmax(outputs, dim=1)  # Get softmax probabilities
        
        for idx, true_label in enumerate(labels):
            true_label = true_label.item()
            for predicted_label, softmax_score in enumerate(probabilities[idx]):
                # Add the softmax score to the corresponding matrix cell
                confusion_matrix[predicted_label][true_label] += softmax_score
  
prediction_accuracy = correct_predictions / total_samples

print(f'Test set prediction accuracy: {prediction_accuracy:.4f}')
'''

"\n# Initialize lists to store incorrectly classified data points\nincorrectly_classified = []\n\n# Initialize variables to track correct predictions\ntotal_samples = 0\ncorrect_predictions = 0\n\n# Soft classification\nconfusion_matrix = torch.zeros(n_class, n_class)\n\n\n# Evaluate the model and gather results\nwith torch.no_grad():\n    for inputs, labels in test_loader:\n        outputs = model(inputs)\n        _, predicted = torch.max(outputs, 1)\n        \n        # Hard classification\n        # Convert numerical labels to country names\n        predicted_countries = [index_to_label[idx.item()] for idx in predicted]\n        true_countries = [index_to_label[idx.item()] for idx in labels]\n        # Update correct predictions count\n        total_samples += len(labels)\n        correct_predictions += (predicted == labels).sum().item()\n        # Gather incorrectly classified data points\n        for pred_country, true_country in zip(predicted_countries, true_countries):\n        

In [53]:
import pandas as pd

# Each entry of incorrectly_classified is (predicted_name, true_name), so the
# first column is the prediction and the second the ground truth. Naming them
# the other way round would label both crosstab axes incorrectly.
misclassified_df = pd.DataFrame(incorrectly_classified, columns=['Predicted', 'Actual'])

# Count the mistakes per class pair: rows are the predicted class, columns
# are the actual class.
misclassification_crosstab = pd.crosstab(
    index=misclassified_df['Predicted'],
    columns=misclassified_df['Actual'],
    rownames=['Predicted'],
    colnames=['Actual'],
)

print("Misclassification Crosstab:")
print(misclassification_crosstab)


Misclassification Crosstab:
Predicted    Australia  Belgium  Canada  Denmark  France  Germany  Iceland  \
Actual                                                                       
Australia            0        1       0        0       1        0        0   
Belgium              2        0      46       42      10       32       14   
Canada               5       36       0       20      12       14        3   
Denmark              0       39       2        0      44       10       22   
France               3        6       5        3       0       17        2   
Germany              8        7       4        0      14        0        6   
Iceland              0        3       4        2       0        7        0   
India                6        5      11        0      13       16        2   
Italy                4       10       1        0      24       24        1   
Japan                0        2       1        2       0        1        1   
Luxembourg           0        0     

In [54]:

# Show the output path for the crosstab pickle; the name embeds the model
# name, the accuracy rounded to 4 decimals, and the month range.
print(f'crosstable/MCC_{model_name}_accuracy{prediction_accuracy:.4f}_{start_month}---{end_month}.pkl')

crosstable/MCC_CNNv5_3000_3000_LOO_England_accuracy0.7278_2020-07---2021-05.pkl


In [55]:
# Persist the misclassification crosstab as a pickle; the file name embeds
# the model name, the accuracy rounded to 4 decimals, and the month range.
misclassification_crosstab.to_pickle(f'crosstable/MCC_{model_name}_accuracy{prediction_accuracy:.4f}_{start_month}---{end_month}.pkl')


# Soft misclassification

In [56]:
import pandas as pd

# Label both axes with the class names, taken in the mapping's key order.
# NOTE(review): key order equals index order only when the mapping was built
# by enumerating the labels in order - confirm if a custom label_to_index is
# ever supplied.
class_labels = list(test_dataset.label_to_index)

# Rows are predicted classes and columns are true classes, following how
# confusion_matrix is filled (confusion_matrix[predicted, true]).
# NOTE(review): the file name says "softmax", but the active evaluation cell
# adds 1 per sample (hard counts) - confirm which content is expected here.
soft_misclassified_df = pd.DataFrame(
    confusion_matrix.numpy(),
    index=class_labels,
    columns=class_labels,
)
soft_misclassified_df.to_pickle(f'crosstable/MCC_softmax_{model_name}_accuracy{prediction_accuracy:.4f}_{start_month}---{end_month}.pkl')