In [326]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import torchvision

import torchvision.transforms as transforms

from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt

import numpy as np

import math

In [327]:
torch.cuda.is_available()

True

## Attention Scoring (Scaled Dot Product)

In [328]:
class ScaledDotProductAttention(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        
        # self.dropout = nn.Dropout(dropout)
        
    def forward(self, queries, keys, values):
        
        weights = torch.bmm(queries, torch.transpose(keys, 1, 2)) / math.sqrt(queries.shape[-1])
        
        self.attn_weights = F.softmax(weights, dim = -1)
        
        scores = torch.bmm(self.attn_weights, values)
        
        return scores    
    

## Multi Heads Attention

In [329]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, dimensions, num_heads, bias = False):
        
        super().__init__()
        
        self.num_heads = num_heads
        
        self.attention = ScaledDotProductAttention()
        
        self.queries_weights = nn.LazyLinear(dimensions, bias = bias)
        
        self.keys_weights = nn.LazyLinear(dimensions, bias = bias)
        
        self.values_weights = nn.LazyLinear(dimensions, bias = bias)
        
        self.heads_weights = nn.LazyLinear(dimensions, bias = bias)
        
    def reshape_multi_heads(self, inputs):
        
        # batch size, number of queries/keys/values, number of heads, dimensions / number of heads
        
        # 64 x 28 x 4 x 7
        
        inputs = inputs.reshape(inputs.shape[0], inputs.shape[1], self.num_heads, -1)
        
        # batch size, number of heads, number of queries/keys/values, dimensions / number of heads
        
        # 64 x 4 x 28 x 7
        
        inputs = inputs.permute(0, 2, 1, 3)
        
        # batch size x number of heads, number of queries/keys/values, dimensions / number of heads
        
        # 256 x 28 x 7 
        
        inputs = inputs.reshape(-1, inputs.shape[2], inputs.shape[3])
        
        return inputs
    
    def reshape_output(self, outputs):
        
        # batch size, number of heads, number of queries/keys/values, dimensions / number of heads
        
        # 64 x 4 x 28 x 7
        
        outputs = outputs.reshape(-1, self.num_heads, outputs.shape[1], outputs.shape[2])
        
        # batch size, number of queries/keys/values, number of heads, dimensions / number of heads
        
        # 64 x 28 x 4 x 7
        
        outputs = outputs.permute(0, 2, 1, 3)
        
        # batch size, number of queries/keys/values, dimensions
        
        # 64 x 28 x 28
        
        outputs = outputs.reshape(outputs.shape[0], outputs.shape[1], -1)
        
        return outputs
        
    def forward(self, queries, keys, values):
        
        queries = self.reshape_multi_heads(self.queries_weights(queries))
        
        keys = self.reshape_multi_heads(self.keys_weights(keys))
        
        values = self.reshape_multi_heads(self.values_weights(values))
        
        outputs = self.attention(queries, keys, values)

        outputs = self.reshape_output(outputs)
        
        outputs = self.heads_weights(outputs)
        
        return outputs

## Initialize Latent Array

In [330]:
class LatentArray(nn.Module):
    
    def __init__(self, latent_size, latent_dim):
        
        super().__init__()
        
        self.latent_size = latent_size
        
        self.latent_dim = latent_dim
        
    def forward(self, inputs):
        
        latent_array = nn.Parameter(torch.rand(self.latent_size, self.latent_dim))
        
        torch.nn.init.trunc_normal_(latent_array, mean = 0.0, std = 0.02, a = -2.0, b = 2.0)
        
        batch_latent_array = latent_array.repeat(inputs.shape[0], 1, 1)
        
        return batch_latent_array   

# Perceiver Block (1 layer of Encoder)

## Residual Connection + Normalization

In [331]:
class AddNorm(nn.Module):
    
    def __init__(self, norm_shape):
        
        super().__init__()
        
        self.norm = nn.LayerNorm(norm_shape)
        
    def forward(self, inputs, outputs):
        
        return self.norm(outputs + inputs)

## Fully Connected FeedForward Neural Network Layer

In [332]:
class FFNLayer(nn.Module):
    
    def __init__(self, hidden_size, output_size):
        
        super().__init__()
        
        self.layer1 = nn.LazyLinear(hidden_size)
        
        self.layer2 = nn.LazyLinear(output_size)
        
    def forward(self, inputs):
        
        first_outputs = F.relu(self.layer1(inputs))
        
        final_outputs = self.layer2(first_outputs)
        
        return final_outputs

In [333]:
%run Data_Processing.ipynb

Mapping files: 100%|█████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 69.81it/s]
Mapping files: 100%|█████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 58.85it/s]


The proportion of 0's in the genomic dataset is 33.34 %
The proportion of 1's in the genomic dataset is 33.34 %
The proportion of 2's in the genomic dataset is 33.32 %
The proportion of NaN's in the genomic dataset is 0.00 %
The proportion of having gout is 0.41


Mapping files: 100%|█████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 47.65it/s]
Mapping files: 100%|█████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 46.90it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['encoded_sex'] = np.select([df['sex'] == "Male" , df['sex'] == "Female"], [1, 2])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['encoded_sex'] = np.select([df['sex'] == "Male" , df['sex'] == "Female"], [1, 2])


0.6565656565656566
              precision    recall  f1-score   support

       False     0.6874    0.7312    0.7087      1131
        True     0.6088    0.5571    0.5818       849

    accuracy                         0.6566      1980
   macro avg     0.6481    0.6442    0.6452      1980
weighted avg     0.6537    0.6566    0.6543      1980



STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


## One Hot Encoding of the tokenised SNV data

In [334]:
class One_Hot_Encoding(nn.Module):
    
    def __init__(self, num_classes: int):
        
        super().__init__()
        
        self.num_classes = num_classes
        
    def forward(self, input_ts):
        
        return F.one_hot(input_ts.long(), num_classes = self.num_classes)

## Positional Encoding

In [335]:
class Positional_Encoding(nn.Module):
    
    def __init__(self, pos_dim, dropout):
        
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        self.pos_dim = pos_dim
        
    def forward(self, input_ts):
        
        num_len = input_ts.shape[1]
        
        #num_dim = input_ts.shape[2]
        
        P = torch.zeros(1, num_len, self.pos_dim)
        
        X = (torch.arange(num_len, dtype=torch.float32).reshape(-1, 1) / 
        torch.pow(10000, torch.arange(0, self.pos_dim, 2, dtype = torch.float32) / self.pos_dim))
        
        P[:, :, 0::2] = torch.sin(X)
        
        if (self.pos_dim % 2) == 1:
            
            P[:, :, 1::2] = torch.cos(X[:, :-1])
            
        else:    
            
            P[:, :, 1::2] = torch.cos(X)
        
        P = torch.repeat_interleave(P, repeats = input_ts.shape[0], dim = 0)
        
        return self.dropout(P[:, :input_ts.shape[1], :].to(input_ts.device))

## Processing the raw datasets

In [336]:
class prepare_dataset(nn.Module):
    
    def __init__(self, num_classes: int, phenos_or_not: bool, pos_dim: int, dropout):
        
        super().__init__()
        
        self.one_hot_encoding = One_Hot_Encoding(num_classes)
        
        self.positional_encoding = Positional_Encoding(pos_dim, dropout)
        
        self.num_classes = num_classes
        
        self.phenos_or_not = phenos_or_not
        
    def forward(self, snv_ts, phenos_ts):
        
        # Process tokenised SNV data
        
        encoded_snv_ts = self.one_hot_encoding(snv_ts)
        
        pos_encoded_snv_ts = self.positional_encoding(encoded_snv_ts)
        
        if self.phenos_or_not:
            
            # Process phenotypes data
            
            phenos_ts = phenos_ts.unsqueeze(0)
            
            phenos_ts = torch.repeat_interleave(phenos_ts, repeats = snv_ts.shape[1], dim = 1)
            
            phenos_ts = phenos_ts.reshape(-1, snv_ts.shape[1], phenos_ts.shape[2])
            
            # Join encoded SNV and phenotypes datasets
            
            complete_ts = torch.cat((encoded_snv_ts, pos_encoded_snv_ts, phenos_ts), dim = 2)
            
        else:
            
            complete_ts = torch.cat((encoded_snv_ts, pos_encoded_snv_ts), dim = 2)
            
            # complete_ts = snv_ts.to(torch.float32)
            
        return complete_ts
            

## Perceiver Block

In [337]:
class PerceiverLayer(nn.Module):
    
    def __init__(self, dimensions, num_heads, hidden_ffn_size, cross_attention_modules, self_attention_modules):
        
        super().__init__()
        
        self.attention = MultiHeadAttention(dimensions, num_heads)
        
        self.addnorm = AddNorm(dimensions)
        
        self.ffn = FFNLayer(hidden_ffn_size, dimensions)
        
        #self.cross_modules = nn.Sequential()
        
        #self.self_modules = nn.Sequential()
        
        self.cross_attn_mods = []
        
        self.self_attn_mods = []
        
        for i in range(cross_attention_modules):
            
            cross_mod = MultiHeadAttention(dimensions, num_heads, bias = False)
            
            self.cross_attn_mods.append(cross_mod)
            
        self.cross_attn_mods = nn.ModuleList(self.cross_attn_mods)
            
        for i in range(self_attention_modules):
            
            self_mod = MultiHeadAttention(dimensions, num_heads, bias = False)
            
            self.self_attn_mods.append(self_mod)
            
        self.self_attn_mods = nn.ModuleList(self.self_attn_mods)
        
        #for i in range(cross_attention_modules):
        #    self.cross_modules.add_module("cross attention"+str(i), MultiHeadAttention(dimensions, num_heads, bias = False))
            
        #for i in range(self_attention_modules):
        #   self.self_modules.add_module("self attention"+str(i), MultiHeadAttention(dimensions, num_heads, bias = False))
        
    def forward(self, latent_inputs, inputs):
        
        dup_inputs = latent_inputs.detach().clone()
        
        for cross_attention in self.cross_attn_mods:
            
            latent_inputs = cross_attention(latent_inputs, inputs, inputs)
            
        for self_attention in self.self_attn_mods:
            
            latent_inputs = self_attention(latent_inputs, latent_inputs, latent_inputs)
        
        outputs_norm = self.addnorm(dup_inputs, latent_inputs)
        
        outputs_ffn = self.ffn(outputs_norm)
        
        outputs_final = self.addnorm(outputs_norm, outputs_ffn)
        
        return outputs_final

## Perceiver Model

In [338]:
class Perceiver(nn.Module):
    
    def __init__(self, dimensions, num_heads, hidden_ffn_size, num_blocks, latent_size, latent_dim,
                cross_attention_modules, self_attention_modules):
        
        super().__init__()
        
        self.latent = LatentArray(latent_size, latent_dim)
        
        self.softmax = nn.Softmax(dim=1)
        
        #self.blocks = nn.Sequential()
        
        self.ffn = FFNLayer(hidden_ffn_size, 2)
        
        self.blocks = []
        
        for i in range(num_blocks):
            
            perceiver_block = PerceiverLayer(dimensions, num_heads, hidden_ffn_size,
                                             cross_attention_modules, self_attention_modules)
            
            self.blocks.append(perceiver_block)
            
        self.blocks = nn.ModuleList(self.blocks)
        
        #for i in range(num_blocks):
        #    self.blocks.add_module("block"+str(i), PerceiverLayer(dimensions, num_heads, hidden_ffn_size,
        #                                                         cross_attention_modules, self_attention_modules))
        
    def forward(self, inputs):
        
        # Initialize Latent Array
        
        latent_inputs = self.latent(inputs).to("cuda:0")
        
        for perceiver_block in self.blocks:
            
            inputs = perceiver_block(latent_inputs, inputs)
            
        outputs = inputs.reshape(inputs.shape[0], -1)
            
        # outputs = F.relu(inputs)
        
        outputs_ffn = F.relu(self.ffn(outputs))
        
        outputs_final = self.softmax(outputs_ffn)
              
        return outputs_final

In [339]:
pros = prepare_dataset(num_classes = 3, phenos_or_not = False, pos_dim = 10, dropout = 0.2)

In [340]:
train_df = pros(tr_snv, tr_phenos)

In [342]:
#train_df = train_df.unsqueeze(2)

In [343]:
train_df.shape

torch.Size([9899, 2000, 13])

In [344]:
test_df = pros(te_snv, te_phenos)

In [345]:
#test_df = test_df.unsqueeze(2)

In [346]:
test_df.shape

torch.Size([101, 2000, 13])

In [347]:
tr_gout.shape

torch.Size([9899])

In [353]:
te_gout.shape

torch.Size([101])

In [354]:
# Merge data and labels

train_ts = TensorDataset(train_df, tr_gout)

test_ts = TensorDataset(test_df, te_gout)

## Test Run the Perceiver Model

In [355]:
# Loading training and testing datasets

train_dataloader = DataLoader(train_ts, batch_size = 10, shuffle = True)

#label_tr_dataloader = torch.utils.data.DataLoader(tr_gout, batch_size = 10, shuffle = False)

test_dataloader = DataLoader(test_ts, batch_size = 4, shuffle = False)

#label_te_dataloader = torch.utils.data.DataLoader(te_gout, batch_size = 4, shuffle = False)

In [356]:
model = Perceiver(dimensions = 6, num_heads = 2, hidden_ffn_size = 32, num_blocks = 2, latent_size = 5,
                 latent_dim = 6, self_attention_modules = 1, cross_attention_modules = 1)



In [357]:
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5, amsgrad = True)

epochs = 5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [358]:
model = model.to(device)

In [359]:
n_total_steps = len(train_dataloader)

In [360]:
import time

start_time = time.time()

In [None]:
model.train()

for epoch in range(epochs):
    
    for i, (seqs, labels) in enumerate(train_dataloader):
        
        # labels = torch.unsqueeze(labels, 1)
        
        #images = torch.unsqueeze(images, 1)
        
        seqs = seqs.to(device)
        
        labels = labels.to(device)
        
        outputs = model(seqs)
        
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        
        loss.backward()
        
        optimizer.step()
        
        if (i+1) % 200 == 0:
            print (f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.3f}')

Epoch [1/5], Step [200/990], Loss: 0.693
Epoch [1/5], Step [400/990], Loss: 0.693
Epoch [1/5], Step [600/990], Loss: 0.693
Epoch [1/5], Step [800/990], Loss: 0.693
Epoch [2/5], Step [200/990], Loss: 0.693
Epoch [2/5], Step [400/990], Loss: 0.693
Epoch [2/5], Step [600/990], Loss: 0.698
Epoch [2/5], Step [800/990], Loss: 0.693
Epoch [3/5], Step [200/990], Loss: 0.693
Epoch [3/5], Step [400/990], Loss: 0.693
Epoch [3/5], Step [600/990], Loss: 0.693
Epoch [3/5], Step [800/990], Loss: 0.693


In [None]:
print("--- %s seconds ---" % (time.time() - start_time))

In [None]:
classes = ('0', '1')

In [None]:
model.eval()

# Evaluate the trained model performance

with torch.no_grad():
    
    batch_size = 1
    
    num_correct_preds = 0
    
    num_total = len(test_df)
    
    num_correct_per_label = [0] * len(classes)
    
    num_total_per_label = [0] * len(classes)
    
    for i, (seqs, labels) in enumerate(test_dataloader):
        
        #images = torch.unsqueeze(images, 1)
        
        seqs = seqs.to(device)
        
        labels = labels.to(device)
        
        outputs = model(seqs)
        
        # Return the value with the highest probability score
        
        _, pred_values = torch.max(outputs, 1)
        
        num_correct_preds += (pred_values == labels).sum().item()        
    
        for i in range(batch_size):
            
            label = labels[i]
            
            pred_val = pred_values[i]
            
            num_total_per_label[label] += 1 
            
            if label == pred_val:
                
                num_correct_per_label[label] += 1
                
    # Calculate Overall Accuracy
    
    overall_accuracy = 100.0 * num_correct_preds / num_total
    
    print(f'Overall accuracy of the model: {overall_accuracy:.3f} %')
    
    # Calculate Accuracy per Label
    
    for i in range(len(classes)):
        
        accuracy_per_label = 100.0 * num_correct_per_label[i] / num_total_per_label[i]
        
        print(f'Accuracy of Label {classes[i]} : {accuracy_per_label:.3f} %')