In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from sklearn.model_selection import train_test_split
from collections import defaultdict
import data.vectorisation as vectorisation 
import data.data_setup as data_setup
import model.transformer as transformer

In [None]:
# Configuration: maximum residues per sequence and the compute device (GPU if available)
maxlen_seq = 192
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the cleaned protein sequence dataset and tokenise it for the model.
# The eleven names below unpack vectorisation.tokenisation's return value, so their
# order must stay in sync with that function.
df = pd.read_csv('2018-06-06-ss.cleaned.csv')
(
    input_data, target_data_sst3, target_data_sst8,
    n_words, n_tags_sst3, n_tags_sst8,
    input_seqs, target_sst3_seqs, target_sst8_seqs,
    tokenizer_decoder_sst3, tokenizer_decoder_sst8,
) = vectorisation.tokenisation(df, maxlen_seq, device)



In [None]:
# Here we will visualize the accuracy of our model predictions based on a subset of 100 test data points

# Split every per-sample array in ONE call so the encoded inputs, the encoded SST3/SST8
# targets and the raw sequence strings all share exactly the same shuffled indices.
# (Separate calls only stay aligned by coincidence of seed and length; a single call
# also raises if the arrays disagree in length instead of silently misaligning them.)
(X_train, X_test,
 y_train_sst3, y_test_sst3,
 y_train_sst8, y_test_sst8,
 seq_train, seq_test,
 target_train_sst3, target_test_sst3,
 target_train_sst8, target_test_sst8) = train_test_split(
    input_data, target_data_sst3, target_data_sst8,
    input_seqs, target_sst3_seqs, target_sst8_seqs,
    test_size=0.25, random_state=0)

# Move the encoded inputs to the device (GPU if available)
X_train, X_test = X_train.clone().detach().to(device), X_test.clone().detach().to(device)

# Decode per-position score vectors back into a character string
def onehot_to_seq(oh_seq, index):
    """Decode a sequence of per-position score/one-hot vectors into a string.

    The argmax of each row is looked up in ``index``; ids that are missing from
    ``index`` decode as '-'. Decoding stops at the first row whose argmax is 0
    (presumably the padding id).

    Args:
        oh_seq: iterable of per-position score vectors.
        index: dict mapping token ids to characters.

    Returns:
        The decoded string ('' if the first row already decodes to 0).
    """
    decoded = []
    for scores in oh_seq:
        token_id = np.argmax(scores)
        if token_id == 0:
            break
        decoded.append(index.get(token_id, '-'))
    return ''.join(decoded)

# Per-residue accuracy (Q3 for SST3 targets, Q8 for SST8 targets) of a predicted string
def q3_acc_strings(y_true_str, y_pred_str):
    """Fraction of residues in ``y_true_str`` whose label is predicted correctly.

    Positions are compared one-to-one. The denominator is the length of the TRUE
    sequence, so a prediction that stops early (e.g. decoding hit padding) is
    penalised for the residues it did not predict instead of being scored only on
    its shorter prefix. Predicted characters beyond the true length are ignored.

    Args:
        y_true_str: true structure string.
        y_pred_str: predicted structure string.

    Returns:
        Accuracy in [0, 1]; 0.0 when ``y_true_str`` is empty.
    """
    if not y_true_str:
        return 0.0
    correct = sum(t == p for t, p in zip(y_true_str, y_pred_str))
    return correct / len(y_true_str)

# Display one prediction next to its target and report its per-residue accuracy
def plot_results(x, y, y_, reverse_decoder_index, seq_type):
    """Print the input, target and decoded prediction for one sequence.

    Args:
        x: input amino-acid sequence (printed as-is).
        y: true secondary-structure string.
        y_: per-position model scores for this sequence, decoded with onehot_to_seq.
        reverse_decoder_index: dict mapping token ids to structure characters.
        seq_type: 'SST3' or 'SST8'; also selects the metric label (Q3 vs Q8).
    """
    # Decode once and reuse the string for both the printout and the accuracy.
    predicted = str(onehot_to_seq(y_, reverse_decoder_index).upper())
    # Accuracy over 8 structure classes is conventionally called Q8, not Q3.
    metric = 'Q8' if seq_type == 'SST8' else 'Q3'
    print("---")
    print(f"Input: {str(x)}")
    print(f"Target ({seq_type}): " + str(y))
    print(f"Result ({seq_type}): " + predicted)
    print(f"{metric} Accuracy: {q3_acc_strings(y, predicted)}")

# Reverse mapping from token indices to characters for decoding predicted sequences
reverse_decoder_index_sst3 = {value: key for key, value in tokenizer_decoder_sst3.word_index.items()}
reverse_decoder_index_sst8 = {value: key for key, value in tokenizer_decoder_sst8.word_index.items()}

# Load model
embed_dim = 512             # Embedding dimension for token embeddings
num_heads = 16              # Number of attention heads in transformer
ff_dim = 2048               # Dimension of feedforward network within transformer
dropout = 0.2               # Dropout rate for regularization
num_encoder_layers = 6      # Number of stacked transformer encoder layers

model = transformer.TransformerModel(n_words, n_tags_sst3, n_tags_sst8, embed_dim,
                                     num_heads, ff_dim, maxlen_seq, num_encoder_layers, dropout)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Predict on a small subset of the training and test data.
# The subsets go into NEW variables: overwriting X_train / X_test here would silently
# truncate the data that the next cell evaluates (hidden notebook state).
N = 100  # ~10000 sequences in one pass ran out of GPU memory ("Tried to allocate 6.50 GiB")
X_train_subset = X_train[:N].to(device)
X_test_subset = X_test[:N].to(device)

# Generate predictions for SST3 and SST8 on training and testing data
with torch.no_grad():
    y_train_pred_sst3, y_train_pred_sst8 = model(X_train_subset)
    y_test_pred_sst3, y_test_pred_sst8 = model(X_test_subset)

# permute(1, 0, 2) makes [i] select one sequence (assumes the model returns
# position-major outputs -- confirm against TransformerModel); move to CPU once.
y_train_pred_sst3 = y_train_pred_sst3.permute(1, 0, 2).cpu().numpy()
y_train_pred_sst8 = y_train_pred_sst8.permute(1, 0, 2).cpu().numpy()
y_test_pred_sst3 = y_test_pred_sst3.permute(1, 0, 2).cpu().numpy()
y_test_pred_sst8 = y_test_pred_sst8.permute(1, 0, 2).cpu().numpy()


def _report_split(title, seqs, targets, preds, reverse_index, seq_type):
    """Print every prediction of one split and return the summed per-residue accuracy."""
    print(title)
    total = 0.0
    for i, pred in enumerate(preds):
        total += q3_acc_strings(targets[i], onehot_to_seq(pred, reverse_index).upper())
        plot_results(seqs[i], targets[i], pred, reverse_index, seq_type)
    return total


# Display each prediction and accumulate accuracy for training and testing, SST3 and SST8
q3_accuracy_train_sst3 = _report_split('Training SST3', seq_train, target_train_sst3,
                                       y_train_pred_sst3, reverse_decoder_index_sst3, 'SST3')
q3_accuracy_train_sst8 = _report_split('Training SST8', seq_train, target_train_sst8,
                                       y_train_pred_sst8, reverse_decoder_index_sst8, 'SST8')
q3_accuracy_test_sst3 = _report_split('Testing SST3', seq_test, target_test_sst3,
                                      y_test_pred_sst3, reverse_decoder_index_sst3, 'SST3')
q3_accuracy_test_sst8 = _report_split('Testing SST8', seq_test, target_test_sst8,
                                      y_test_pred_sst8, reverse_decoder_index_sst8, 'SST8')

# Average over the number of sequences actually predicted (a split may hold fewer than N).
n_train_eval = len(y_train_pred_sst3)
n_test_eval = len(y_test_pred_sst3)
print(f'\nAverage Q3 Accuracy on Training SST3: {q3_accuracy_train_sst3 / n_train_eval:.4f}')
print(f'Average Q8 Accuracy on Training SST8: {q3_accuracy_train_sst8 / n_train_eval:.4f}')
print(f'Average Q3 Accuracy on Testing SST3: {q3_accuracy_test_sst3 / n_test_eval:.4f}')
print(f'Average Q8 Accuracy on Testing SST8: {q3_accuracy_test_sst8 / n_test_eval:.4f}')


Training SST3
---
Input: MNYPVNPDLMPALMAVFQHVRTRIQSELDCQRLDLTPPDVHVLKLIDEQRGLNLQDLGRQMCRDKALITRKIRELEGRNLVRRERNPSDQRSFQLFLTDEGLAIHQHAEAIMSRVHDELFAPLTPVEQATLVHLLDQCLAAQPLEDI
Target (SST3): CCCCCCCCHHHHHHHHHHHHHHHHHHHHHHCCCCCCHHHHHHHHHHHCCCCEEHHHHHHCCCCCHHHHHHHHHHHHHCCCEEEEECCCCCCCEEEEECHHHHHHHHHHHHHHHHHHHHHHCCCCHHHHHHHHHHHHHHCCCCCCCCC
Result (SST3): CCCCCCCCHHHHHHHHHHHHHHHHHHHHHHCCCCCCHHHHHHHHHHHHCCCCEHHHHHHHHCCCHHHHHHHHHHHHHCCCEEEECCCCCCCCEEEEECHHHHHHHHHHHHHHHHHHHHHHCCCCHHHHHHHHHHHHHHHCCCCCCCC
Q3 Accuracy: 0.9591836734693877
---
Input: MIQPQTYLEVADNTGARKIMCIRVLKGSNAKYATVGDVIVASVKEAIPRGAVKEGDVVKAVVVRTKKEIKRPDGSAIRFDDNAAVIINNQLEPRGTRVFGPVARELREKGFMKIVSLAPEVL
Target (SST3): CECCCCEEEECCCCCECCEEEEEECCCCCCCCECCCCEEEEEECCECCCCCCCCCCEEEEEEEECCCCEECCCCCEEEECCCEEEEECCCCCECCCCCCCCECHHHHHHCCHHHHCCCCCEC
Result (SST3): CECCCCEEEECECCCEEEEEEEEECCCCCCCCECCCCEEEEEEEEECCCCCCCCCCEEEEEEEECCCCEECCCCCEEEECCCEEEEECCCCCECCCCECCCECCHHHHHCCHHHHHHCCCEC
Q3 Accuracy: 0.9262295081967213
---
Input: SLLEFGKMILEETGKL

In [None]:
# Here we will visualize the accuracy of our model predictions based on different intervals of sequence lengths

# Create mappings for token indices to characters for decoding SST3 and SST8 predictions
reverse_decoder_index_sst3 = {value: key for key, value in tokenizer_decoder_sst3.word_index.items()}
reverse_decoder_index_sst8 = {value: key for key, value in tokenizer_decoder_sst8.word_index.items()}

# Sequence length intervals (inclusive on both ends) for accuracy calculations and comparisons
intervals = [(0, 30), (31, 60), (61, 90), (91, 120), (121, 150), (151, 180), (181, 192)]
q3_acc_by_length_sst3_train = defaultdict(list)
q3_acc_by_length_sst8_train = defaultdict(list)
q3_acc_by_length_sst3_test = defaultdict(list)
q3_acc_by_length_sst8_test = defaultdict(list)

# Predicting the whole split in one forward pass exhausts GPU memory, so run the model
# in fixed-size chunks. Lower PRED_BATCH_SIZE if CUDA still reports out-of-memory.
PRED_BATCH_SIZE = 500


def _predict_in_batches(model, X, batch_size=PRED_BATCH_SIZE):
    """Run ``model`` over ``X`` chunk by chunk and return its (SST3, SST8) outputs.

    Each chunk's outputs are moved to the CPU immediately, so the GPU only ever holds
    one chunk. The chunks are concatenated once at the end along dim 1, the axis the
    per-sample outputs are stacked on (presumably position-major outputs -- confirm
    against TransformerModel).
    """
    sst3_chunks, sst8_chunks = [], []
    with torch.no_grad():
        for start in range(0, len(X), batch_size):
            stop = min(start + batch_size, len(X))
            out_sst3, out_sst8 = model(X[start:stop].to(device))
            sst3_chunks.append(out_sst3.cpu())
            sst8_chunks.append(out_sst8.cpu())
            print('n: ', start, ' N: ', stop)
    return torch.cat(sst3_chunks, 1), torch.cat(sst8_chunks, 1)


model.eval()

print('=========TRAIN=========')
y_train_pred_sst3, y_train_pred_sst8 = _predict_in_batches(model, X_train)

print('=========TEST=========')
y_test_pred_sst3, y_test_pred_sst8 = _predict_in_batches(model, X_test)

# Rearrange so [i] selects one sequence: (sequence, position, tag)
y_train_pred_sst3 = y_train_pred_sst3.permute(1, 0, 2).contiguous()
y_train_pred_sst8 = y_train_pred_sst8.permute(1, 0, 2).contiguous()
y_test_pred_sst3 = y_test_pred_sst3.permute(1, 0, 2).contiguous()
y_test_pred_sst8 = y_test_pred_sst8.permute(1, 0, 2).contiguous()

# The previous batching loop left N == len(X_test); keep that value for any code that reads N.
N = len(X_test)

def get_interval(length, bins=None):
    """Return the first (start, end) interval that contains ``length``, or None.

    Args:
        length: sequence length to classify.
        bins: iterable of inclusive (start, end) pairs; when omitted, the
            module-level ``intervals`` list is used.

    Returns:
        ``(start, end)`` for the first pair with ``start <= length <= end``;
        None if no pair matches.
    """
    if bins is None:
        bins = intervals
    return next(((start, end) for start, end in bins if start <= length <= end), None)

# Converts one-hot encoded sequences back to the original sequence using an index mapping
def onehot_to_seq(oh_seq, index):
    """Decode a sequence of per-position score/one-hot vectors into a string.

    The argmax of each row is looked up in ``index``. Ids that ``index`` does not
    contain decode as '-' instead of raising KeyError. Decoding stops at the first
    row whose argmax is 0 (presumably the padding id).

    Args:
        oh_seq: iterable of per-position score vectors.
        index: dict mapping token ids to characters.

    Returns:
        The decoded string.
    """
    s = ''
    for o in oh_seq:
        i = np.argmax(o)
        if i == 0:
            break
        s += index.get(i, '-')
    return s

# Per-residue accuracy (Q3 for SST3 targets, Q8 for SST8 targets) of a predicted string
def q3_acc_strings(y_true_str, y_pred_str):
    """Fraction of residues in ``y_true_str`` whose label is predicted correctly.

    Positions are compared one-to-one. The denominator is the length of the TRUE
    sequence, so a prediction that stops early (e.g. decoding hit padding) is
    penalised for the residues it did not predict instead of being scored only on
    its shorter prefix. Predicted characters beyond the true length are ignored.

    Args:
        y_true_str: true structure string.
        y_pred_str: predicted structure string.

    Returns:
        Accuracy in [0, 1]; 0.0 when ``y_true_str`` is empty.
    """
    if not y_true_str:
        return 0.0
    correct = sum(t == p for t, p in zip(y_true_str, y_pred_str))
    return correct / len(y_true_str)

# Bucket each sequence's per-residue accuracy (Q3 for SST3, Q8 for SST8) by input length.
def _accumulate_by_length(seqs, targets, preds, reverse_index, acc_by_length):
    """Append each predicted sequence's accuracy to the list for its length interval.

    Args:
        seqs: input amino-acid strings; their lengths select the interval.
        targets: true structure strings, aligned with ``seqs``.
        preds: per-sequence model outputs indexable as ``preds[i]`` (torch tensors).
        reverse_index: dict mapping token ids to structure characters.
        acc_by_length: defaultdict(list) keyed by interval; mutated in place.
    """
    # Iterate over the sequences that actually have predictions rather than a loop
    # counter left over from another cell (which covered only len(X_test) sequences).
    for i in range(min(len(seqs), len(preds))):
        interval = get_interval(len(seqs[i]))
        if interval:
            predicted = onehot_to_seq(preds[i].cpu().numpy(), reverse_index).upper()
            acc_by_length[interval].append(q3_acc_strings(targets[i], predicted))


_accumulate_by_length(seq_train, target_train_sst3, y_train_pred_sst3,
                      reverse_decoder_index_sst3, q3_acc_by_length_sst3_train)
_accumulate_by_length(seq_train, target_train_sst8, y_train_pred_sst8,
                      reverse_decoder_index_sst8, q3_acc_by_length_sst8_train)
_accumulate_by_length(seq_test, target_test_sst3, y_test_pred_sst3,
                      reverse_decoder_index_sst3, q3_acc_by_length_sst3_test)
_accumulate_by_length(seq_test, target_test_sst8, y_test_pred_sst8,
                      reverse_decoder_index_sst8, q3_acc_by_length_sst8_test)

# Average accuracy per interval (None when an interval has no sequences)
train_sst3_avg_ls = []
train_sst8_avg_ls = []
test_sst3_avg_ls = []
test_sst8_avg_ls = []
splits = [
    ('Training SST3', 'Q3', q3_acc_by_length_sst3_train, train_sst3_avg_ls),
    ('Training SST8', 'Q8', q3_acc_by_length_sst8_train, train_sst8_avg_ls),
    ('Testing SST3', 'Q3', q3_acc_by_length_sst3_test, test_sst3_avg_ls),
    ('Testing SST8', 'Q8', q3_acc_by_length_sst8_test, test_sst8_avg_ls),
]
for interval in intervals:
    print(f"Interval {interval}:")
    for label, metric, acc_by_length, avg_ls in splits:
        scores = acc_by_length[interval]
        avg = np.mean(scores) if scores else None
        avg_ls.append(avg)
        # Compare against None: an interval whose average is exactly 0.0 still has data.
        if avg is not None:
            print(f"  {label} Avg {metric} Accuracy: {avg:.4f}")
        else:
            print(f"  No {label} data")

# Plot the frequency of input peptide sequence lengths (bin edges derived from `intervals`)
bins = [intervals[0][0]] + [end for _, end in intervals]
# .copy() makes this an independent frame, so adding a column does not write into a slice of df.
df_trimmed_plot = df[(df.len <= maxlen_seq) & (~df.has_nonstd_aa)].copy()
df_trimmed_plot['length_bin'] = pd.cut(df_trimmed_plot['len'], bins=bins)
fig_hist, ax_hist = plt.subplots()
df_trimmed_plot['length_bin'].value_counts(sort=False).plot(kind='bar', edgecolor='black', ax=ax_hist)
ax_hist.set(xlabel='Length of Input Peptide Sequence', ylabel='Frequency',
            title='Input Sequence Length Distribution')
ax_hist.tick_params(axis='x', labelrotation=45)
fig_hist.savefig("pred_input_hist_plot.png")
plt.show()

# Plot training and testing accuracy by sequence length interval
intervals_graph = [str(interval) for interval in intervals]
fig, axs = plt.subplots(2, 2, figsize=(15, 8))
panels = [
    (axs[0, 0], train_sst3_avg_ls, 'skyblue', 'Training SST3 Avg Q3 Accuracy'),
    (axs[0, 1], train_sst8_avg_ls, 'orange', 'Training SST8 Avg Q8 Accuracy'),
    (axs[1, 0], test_sst3_avg_ls, 'skyblue', 'Testing SST3 Avg Q3 Accuracy'),
    (axs[1, 1], test_sst8_avg_ls, 'orange', 'Testing SST8 Avg Q8 Accuracy'),
]
for ax, averages, colour, title in panels:
    # Empty intervals are None; plot them as NaN (no bar) instead of passing None to bar().
    heights = [np.nan if avg is None else avg for avg in averages]
    ax.bar(intervals_graph, heights, color=colour)
    ax.set_title(title)
    ax.set(xlabel='Sequence Length Interval', ylabel='Average Accuracy')
    ax.label_outer()

# Save the final plot
fig.savefig("Train_Test_Accuracy.png")

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.46 GiB. GPU 