# Predicting transmembrane protein topology from 3D structure

## **Introduction**







1. **Protein Topology and Its Importance**

Protein topology describes where each part of a protein is located relative to the cell. A protein can sit inside or outside a cell, but it can also span the cell membrane, as illustrated in the figure below. A transmembrane protein might serve as a receptor for cell-to-cell communication, while a free-moving protein outside or inside a cell might play a role in intracellular transportation [1].

2. **Advances in Protein Topology Prediction**

- Bernsel et al. (2009) utilized an ensemble of five different 1D sequence models for topology prediction [2].
- Dobson et al. (2015) improved upon this with a similar approach, but employing ten models [3].
- Recently, Hallgren et al. (2022) demonstrated significant performance improvements with the DeepTMHMM model, a neural network-based approach [4].

3. **Our Approach: A GNN-based Method**

However, existing models primarily rely on 1D protein sequence information for topology prediction. In contrast, our study introduces a Graph Neural Network (GNN)-based method that uses 3D structural information predicted by AlphaFold, the structure prediction model developed by DeepMind.

|                    |
|:-------------------:|
| <p align="center"><img src="image/figure1.png" alt="Image" width="500" height="250"></p> |



## **Dataset \& 3D protein structure**




#### **Protein Types and 3D Predictions Availability**
The dataset used in this study is the same as the one behind the DeepTMHMM model. The following table provides an overview of which protein types are included and whether their 3D predictions are available from the AlphaFold Protein Structure Database.


The 3D protein structures corresponding to the sequences in the table are predicted by AlphaFold and acquired from the AlphaFold Protein Structure Database [4,5]. These structures are stored in PDB files, each of which contains the 1D protein sequence, the atoms of the protein, the atom coordinates, and the prediction uncertainties.

> **Note**: TM: Transmembrane, SP: Signal Peptide. The numbers without parentheses indicate sequences with both 1D and 3D structures available.

| Protein Type | alpha TM | alpha SP+TM | beta barrel | Globular   | SP+Globular |
| ------------ | -------- | ----------- | ----------- | ---------- | ----------- |
| Amount       | 383 (387)| 102 (106)   | 82 (82)     | 1,982 (2,000) | 994 (1,000) |







#### **Graph Construction from PDB Files**

A graph in the form of an adjacency matrix can then be constructed from the atom coordinates (see block 3 in the following figure). The atoms include carbon (C), nitrogen (N), oxygen (O), and sulfur (S), while hydrogen (H) is omitted due to its relatively low importance. Data inspection showed that every residue (bound amino acid) starts with its nitrogen atom in the graph.


|                    |
|:-------------------:|
| <p align="center"><img src="image/figure2.png" alt="Image" width="500" height="250"></p> |
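
As a rough illustration of the graph construction described above, the sketch below drops hydrogen atoms and connects every pair of remaining atoms that lie within a distance cutoff. The helper name `build_adjacency`, the toy coordinates, and the 1.8 Å cutoff are assumptions made for this example; they are not taken from the project code.

```python
import math

def build_adjacency(atoms, cutoff=1.8):
    """Return (heavy_atoms, adjacency) for non-hydrogen atoms within `cutoff` Å.

    `atoms` is a list of (element, (x, y, z)) tuples. The cutoff is an
    illustrative assumption, not a value from the project.
    """
    heavy = [(element, xyz) for element, xyz in atoms if element != "H"]
    n = len(heavy)
    adjacency = [[0] * n for _ in range(n)]
    for a in range(n):
        for b in range(a + 1, n):
            if math.dist(heavy[a][1], heavy[b][1]) <= cutoff:
                adjacency[a][b] = adjacency[b][a] = 1
    return heavy, adjacency

# Toy backbone fragment (made-up coordinates); the residue starts with nitrogen
atoms = [
    ("N", (0.0, 0.0, 0.0)),
    ("H", (0.0, 1.0, 0.0)),   # hydrogen, omitted from the graph
    ("C", (1.5, 0.0, 0.0)),   # alpha-carbon
    ("C", (3.0, 0.0, 0.0)),   # carbonyl carbon
    ("O", (3.0, 1.2, 0.0)),
]
heavy, adjacency = build_adjacency(atoms)
for row in adjacency:
    print(row)
```

In practice a GNN library usually builds this neighbor structure itself (for example from a radius and a maximum number of neighbors), so the explicit matrix is shown only to make the idea concrete.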



## **Methodology**

The output of an *invariant* GNN stays the same when the input is transformed (for example, reordered or rotated), while the output of an *equivariant* GNN changes in the same way as the input. In this study, we applied state-of-the-art GNNs, namely **SchNet (invariant)**, **EGNN (equivariant)**, and **GCPNet (equivariant)**, to the topology prediction task. However, SchNet was the only GNN from which we obtained meaningful results.


The SchNet model with the major voting downstream task is assessed with 5-fold cross-validation (CV). The SchNet model with the $\alpha$-carbon downstream task, EGNN, and GCPNet each use a single CV fold and are compared to the SchNet model. For the 5-fold configuration, we used the same data splits as the DeepTMHMM study: the model was trained on the first three sets, validated on the fourth set, and tested on the fifth set. This process was repeated five times, with the model being validated and tested on different sets each time.
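
The rotation of the five partitions can be sketched as follows. Setup 1 (train on `cv0` to `cv2`, validate on `cv3`, test on `cv4`) matches the code further down; the exact rotation order used for setups 2 to 5 is an assumption of this sketch, as is the helper name `cv_setups`.

```python
def cv_setups(partitions):
    """Yield (train, validation, test) partition names for each rotation."""
    k = len(partitions)
    for shift in range(k):
        rotated = partitions[shift:] + partitions[:shift]
        # First k - 2 partitions for training, then one for validation, one for test
        yield rotated[:k - 2], rotated[k - 2], rotated[k - 1]

partitions = ["cv0", "cv1", "cv2", "cv3", "cv4"]
setups = list(cv_setups(partitions))
for number, (train, val, test) in enumerate(setups, start=1):
    print(f"setup {number}: train={train}, val={val}, test={test}")
```

Each partition is used exactly once for testing and once for validation, so every protein contributes to the test results of exactly one setup.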


 

### **SchNet Model**

#### **Model Architecture**

1. **Input and Embedding**: 

    This includes representing atoms and their features ($t_1, \dots, t_n$) as embedding vectors. The features of each atom can include its type, position, and other relevant properties. These features are subsequently transformed into a continuous vector space (denoted as 'embedding, k') for further processing.

2. **Message Passing Layers**:

     These layers are responsible for encoding interactions between atoms. In each layer (labeled as 'interaction, k'), the network updates the representation of each atom by aggregating information from its neighbors. This process allows the model to capture the local chemical environment of each atom within the molecule.

3. **Continuous Filter Convolution (cfconv)**: 

    Convolutional filters, learned during network training, are applied to interactions between atoms. The filters operate based on the distances between atoms, typically represented using radial basis functions (RBFs), and are used to model spatial relationships between atoms.

4. **Nonlinear Activation (Shifted Softplus)**: 

    After the convolutional step, a nonlinear activation function (shifted softplus) is applied to introduce nonlinearity, enabling the model to capture more complex patterns in the data.

5. **Atomic Level Layers**: 

    Following the convolutional and activation steps, atomic-level layers are applied. These are fully connected layers that operate independently on each atom, allowing the model to refine the representation of each atom's features.

6. **Sum Pooling**: 

    After processing features through multiple layers, the model aggregates the features of all atoms in the molecule through sum pooling. This operation generates a global representation of the molecule for further predictions.

7. **Feedforward Output Neural Network**: 

    The global representation is passed to a feedforward neural network to predict the target property.

8. **Parameter Sharing**: 

    Parameters are often shared between different parts of the network (e.g., weights in convolutional filters). This sharing allows the network to generalize better and reduces the number of parameters that need to be learned.


|                    |
|:-------------------:|
| <p align="center"><img src="image/schnet.png" alt="Image" width="500" height="250"></p> [5]|
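
Two of the building blocks listed above are simple enough to write out directly: the radial basis expansion of an interatomic distance that feeds the continuous-filter convolution, and the shifted softplus activation. The sketch below is a minimal plain-Python version; the grid of centers (0 Å to 5 Å in 0.1 Å steps) and `gamma = 10` are illustrative assumptions, not the settings of the trained model.

```python
import math

def shifted_softplus(x):
    """ssp(x) = ln(0.5 * exp(x) + 0.5); smooth, and zero at x = 0."""
    return math.log(0.5 * math.exp(x) + 0.5)

def rbf_expand(distance, centers, gamma=10.0):
    """Expand one distance (in Å) on a grid of Gaussian radial basis functions."""
    return [math.exp(-gamma * (distance - mu) ** 2) for mu in centers]

centers = [0.1 * k for k in range(51)]        # assumed grid: 0.0 Å to 5.0 Å
features = rbf_expand(1.5, centers)           # e.g. a typical bond length
peak = max(range(len(features)), key=features.__getitem__)
print(peak, shifted_softplus(0.0))
```

The expanded distance is what the filter-generating network sees, so nearby atoms with slightly different distances produce smoothly varying filters.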


#### **Model Parameters**

| Model Parameter                | Value                                  | Description                                           |
|--------------------------------|----------------------------------------|-------------------------------------------------------|
| Number of Layers               | 6                                      | Number of layers in the SchNet model.                 |
| Hidden Embedding Size          | 128                                    | Size of the hidden embedding for each atom.           |
| Number of Convolutional Filters| 128                                    | Number of convolutional filters used in the model.     |
| Maximum Number of Neighbor Nodes| 32                                    | Maximum number of neighbor nodes considered during message passing. |
| Dropout                        | 0.8                                    | Dropout rate applied before the fully connected layers. |
| Batch Size                     | 1 (chosen for best and stable performance) | Batch size used for training.                         |
| Optimizer                      | Adam                                   | Optimization algorithm used for training.             |
| Learning Rate                  | $3\cdot 10^{-4}$                       | Learning rate for training.                           |
| $L2$ Regularization            | $1\cdot 10^{-4}$                       | Strength of $L2$ regularization.                     |
| Learning Rate Scheduler        | Exponential with decay rate of 0.1     | Learning rate scheduler applied during training.      |
| StaticEmbedding length         | 30000                                  | Length of static embedding.                           |
| Gaussian filter during the training and validation | Kernel size = 29 and $\sigma$ = 5 | Parameters for Gaussian smoothing during training and validation. |


#### **Downstream Tasks**

For our classification task, predicting transmembrane protein topology, each residue is assigned to one of six classes: *Signal peptide (S)*, *Inside cytosol (I)*, *Alpha membrane part (M)*, *Beta membrane part (B)*, *Periplasm (P)*, and *Outside cell (O)*. We therefore added two fully connected layers after the SchNet model. The activation function is ReLU, and the output size is set to 6, corresponding to the number of classes. Furthermore, we applied Gaussian smoothing to the output (the `GaussianSmoothing` class is available via `from task import GaussianSmoothing`).

##### **Comparison with the Baseline Model**
The model performance was compared to a baseline classifier that predicts every test observation as belonging to the most frequent class in the training set. In all CV folds, the class inside cytosol (I) was the most frequent, and the class distribution was very similar across folds. The comparison was conducted using McNemar's test [6]. The figure below shows the distribution of the label classes in the first fold setup; the distributions in the other CV folds are very similar (not shown).


|                    |
|:-------------------:|
| <p align="center"><img src="image/setup1dis.png" alt="Image" width="500" height="333"></p> |
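
The baseline described above amounts to a few lines of code. The sketch below uses made-up residue labels; the function names are hypothetical and only illustrate the idea of always predicting the most frequent training class.

```python
from collections import Counter

def majority_class_baseline(train_labels):
    """Return the most frequent label in the training set."""
    return Counter(train_labels).most_common(1)[0][0]

def accuracy(predicted, actual):
    """Fraction of positions where the prediction equals the true label."""
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)

train_labels = list("IIIIIMMMOOSI")   # toy residue labels, 'I' dominates
test_labels = list("IIMMOI")
baseline_class = majority_class_baseline(train_labels)
baseline_pred = [baseline_class] * len(test_labels)
baseline_accuracy = accuracy(baseline_pred, test_labels)
print(baseline_class, baseline_accuracy)
```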






##### **Label Alignment for Major Voting**

Because the model's input comes from AlphaFold, which cannot predict protein structures with complete accuracy, there are disparities between the model's predicted labels (covering AlphaFold's entire sequence of predicted atoms) and the original labels (from the actual protein atomic sequences). One of the downstream tasks is therefore to align these labels.

The label alignment strategies are as follows:

1. **First Approach**: Keep the predicted labels static and process the real labels to facilitate the backward propagation process. This step is implemented in `from data_utils import DismatchIndexPadRawData`. The primary idea is to compare the two sets of labels one by one, identifying the index corresponding to the actual label. If the actual label is greater than the predicted label, the index corresponding to the real label is deleted. Conversely, if the actual label is smaller, the value of the predicted label is inserted at the index of the actual label. The type of transmembrane protein at that position is inferred based on the indices immediately preceding and following the true label.

2. **Second Approach**: Maintain the actual labels unchanged and process the predicted labels to achieve atomic-level accuracy. Following the indices obtained from the first approach, if the actual label is larger than the predicted label, the corresponding value is inserted into the predicted label. If the actual label is smaller, the corresponding value in the predicted label is removed.

3. **Third Approach**: Based on the second strategy, the predicted labels are segmented according to the length of atoms within each amino acid (proteins are composed of amino acids, which in turn consist of multiple atoms). After segmentation, a majority voting technique is used to determine the category of each atom within the predicted amino acid. The transmembrane category of the amino acid is decided by the majority vote, achieving residue-level accuracy (amino acid level). This part of the process is covered in `from task import MapAtomNode`.
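
The voting step of the third approach can be sketched in isolation. This is not the `MapAtomNode` implementation; the function name `vote_residue_labels` and the toy inputs are assumptions, and the sketch presumes the atom-level labels have already been aligned as described in the first two approaches.

```python
from collections import Counter

def vote_residue_labels(atom_labels, atoms_per_residue):
    """Collapse per-atom labels into one label per residue by majority vote.

    Ties go to the label that appears first within the residue.
    """
    residue_labels = []
    start = 0
    for count in atoms_per_residue:
        segment = atom_labels[start:start + count]
        residue_labels.append(Counter(segment).most_common(1)[0][0])
        start += count
    return residue_labels

# Three residues with 5, 8 and 6 heavy atoms; a few atoms are mislabelled
atom_labels = list("IIIIM" "MMMMIMMM" "OOMOOO")
atoms_per_residue = [5, 8, 6]
residue_labels = vote_residue_labels(atom_labels, atoms_per_residue)
print(residue_labels)
```

Voting makes the residue-level prediction robust to a few stray atom-level errors, which is the motivation for evaluating at residue level rather than atom level.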

##### **Label Alignment for $\alpha$-carbon**
1. **First Approach**: This part is the same as in Label Alignment for Major Voting

2. **Second Approach**: This part is the same as in Label Alignment for Major Voting


3. **Third Approach**: Another approach is to train the GNN using the $\alpha$-carbon embedding from each residue [7]. An $\alpha$-carbon is the central atom linking an amino group and a carboxyl group within an amino acid [1]. In this way, the output based on the $\alpha$-carbons has the same length as the protein sequence. However, we only achieved similar or worse performance with $\alpha$-carbons compared to the aforementioned major voting approach.




In [None]:
from task import CreateDataLabel,MapAtomNode,node_accuracy
from schnet import SchNetModel
import math      # used by GaussianSmoothing
import numbers   # used by GaussianSmoothing
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter1d
import json
import pandas as pd

#### **Model Training and Model Validation**

##### **Results from SCHNET Model**

**Major Voting (5-fold cross-validation)**



|                    |                    |
|:-------------------:|:-------------------:|
| <img src="image/schnet_loss.png" alt="Image 1" width="400" height="200"> | <img src="image/schnet_acc.png" alt="Image 2" width="400" height="200"> |



1. **Result:**

- The residue-level accuracies are evaluated by directly comparing the predictions to the class labels. All the loss and accuracy curves follow similar trends, indicating that the variability across the CV setups is small.

- The abrupt jump around epoch 100 is caused by the exponentially decaying learning rate scheduler. Note that the training for CV setup 5 was stopped early around epoch 90.

**$\alpha$-carbon (1-fold cross-validation)**


|                    |                    |
|:-------------------:|:-------------------:|
| <img src="image/schnetCA_loss.png" alt="Image 1" width="400" height="200"> | <img src="image/schnetCA_acc.png" alt="Image 2" width="400" height="200"> |


 1. **Result:**
    
- While the loss curves converge at a lower value compared to the major voting approach, the actual correctly predicted residue labels are not better. However, this method might still be useful in improving the overall topological prediction.

##### **Training Techniques and Strategies**

**1. Early Stopping and Learning Rate Scheduling**

In [None]:
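# EarlyStopping:
class EarlyStopper:
    """Stop training when the validation loss stops improving."""
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience      # number of worse epochs tolerated
        self.min_delta = min_delta    # margin above the best loss that counts as worse
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Return True once `patience` epochs since the last improvement were worse
        than the best loss by at least `min_delta`."""
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and reset the counter
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss >= (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Learning Rate Scheduling: the learning rate is multiplied by gamma at every scheduler.step()
# (the optimizer is created in the training setup cell)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.05)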
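
To see how the stopping rule behaves, the sketch below runs the `EarlyStopper` logic on a made-up sequence of validation losses. The class is repeated so the example runs on its own; the loss values and the `patience=3` setting are assumptions for illustration only.

```python
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss >= (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Made-up validation losses: improvement until epoch 4, then a slow rise
losses = [0.90, 0.70, 0.65, 0.66, 0.64, 0.67, 0.68, 0.69]
stopper = EarlyStopper(patience=3, min_delta=0.001)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.early_stop(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.min_validation_loss)
```

The single worse epoch at index 3 does not stop training because the improvement at index 4 resets the counter; only three worse epochs after the best loss trigger the stop.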

**2. Gaussian Smoothing**

In [None]:
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
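
For intuition, the 1D case of this smoothing can be written without PyTorch. The sketch below builds the same normalised Gaussian kernel (kernel size 29 and $\sigma$ = 5, as in the parameter table) and applies it to one score sequence with reflect padding of 14 on each side, so the output keeps the input length. The function names and the toy score sequence are assumptions for illustration.

```python
import math

def gaussian_kernel(kernel_size=29, sigma=5.0):
    """Normalised 1D Gaussian kernel whose weights sum to 1."""
    mean = (kernel_size - 1) / 2
    weights = [math.exp(-((x - mean) / sigma) ** 2 / 2) for x in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]

def smooth(values, kernel):
    """Smooth a 1D sequence; reflect padding keeps the output length unchanged."""
    half = len(kernel) // 2
    # Reflect padding mirrors the sequence around its ends without repeating the edge value
    padded = values[half:0:-1] + values + values[-2:-half - 2:-1]
    return [
        sum(w * padded[start + offset] for offset, w in enumerate(kernel))
        for start in range(len(values))
    ]

kernel = gaussian_kernel()
scores = [0.0] * 40
scores[20] = 1.0                      # a single isolated spike in one class score
smoothed = smooth(scores, kernel)
print(len(smoothed), max(range(40), key=smoothed.__getitem__))
```

An isolated spike is spread over its neighbours, which discourages single-atom label flips along the sequence. The sequence must be longer than the padding (14 atoms here) for reflect padding to be valid.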

##### **Major Voting (5-fold cross-validation)**

The code below shows how to run *setup1* only; setup2 to setup5 are run in the same way.



In [None]:
def cv_data_generator(cv0,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label):
  """Collect data, labels and metadata for the proteins in one CV partition.

  cv0 is the list of partition entries (dicts with an 'id' key). cv_list is
  not used; it is kept so that existing calls keep working.
  """
  cv0_data = []
  cv0_label = []
  cv0_batchname = []
  cv0_dismatch_index_pred = {}
  cv0_dismatch_index_type = {}
  cv0_real_node_label = []
  cv0_rows = []
  # The dict keys are assumed to follow the same order as total_data
  pred_keys = list(total_dismatch_index_pred)
  type_keys = list(total_dismatch_index_type)
  for entry in cv0:
    name = entry['id'].lower()
    df_index = total_df.index[total_df['uniprot_id_low'] == name].tolist()
    # Skip proteins that are missing from the DataFrame or from the loaded data
    if not df_index or [name] not in total_batchname:
      continue
    # The labels are aligned with the batch names
    idx = total_batchname.index([name])
    cv0_batchname.append([df_index[0].lower()])
    cv0_data.append(total_data[idx])
    cv0_label.append(total_label[idx])
    cv0_dismatch_index_pred[pred_keys[idx]] = total_dismatch_index_pred[pred_keys[idx]]
    cv0_dismatch_index_type[type_keys[idx]] = total_dismatch_index_type[type_keys[idx]]
    cv0_real_node_label.append(total_real_node_label[idx])
    cv0_rows.append(total_df.loc[[df_index[0]]])
  cv0_df = pd.concat(cv0_rows) if cv0_rows else pd.DataFrame(columns=total_df.columns)
  return cv0_data, cv0_label, cv0_batchname, cv0_dismatch_index_pred, cv0_dismatch_index_type, cv0_real_node_label, cv0_df

In [None]:
batch_size = 1
raw_data_name = "DeepTMHMM.3line"
path ='/work3/s194408/Project/'
processor = CreateDataLabel(path,batch_size =batch_size,raw_data_name=raw_data_name)
# processor.initialization()# split and download trian/val/test just once
train_data,train_lable, train_batchname, train_max_len,train_dismatch_index_pred,train_dismatch_index_type,train_real_node_label,df_train = processor.datalabelgenerator('train')
val_data,val_lable, val_batchname, val_max_len,val_dismatch_index_pred,val_dismatch_index_type,val_real_node_label,df_val = processor.datalabelgenerator('val')
test_SP_TM_data,test_SP_TM_lable, test_SP_TM_batchname, test_SP_TM_max_len,test_SP_TM_dismatch_index_pred,test_SP_TM_dismatch_index_type,test_SP_TM_real_node_label,df_test_SP_TM = processor.datalabelgenerator('test_SP_TM')
test_TM_data,test_TM_lable, test_TM_batchname, test_TM_max_len,test_TM_dismatch_index_pred,test_TM_dismatch_index_type,test_TM_real_node_label,df_test_TM = processor.datalabelgenerator('test_TM')
test_BETA_data,test_BETA_lable, test_BETA_batchname, test_BETA_max_len,test_BETA_dismatch_index_pred,test_BETA_dismatch_index_type,test_BETA_real_node_label,df_test_BETA = processor.datalabelgenerator('test_BETA')
# Load the cross-validation partitions; the file is closed automatically
with open('/work3/s194408/Project/dataset/tmp/DeepTMHMM.partitions.json') as f:
  cv_data = json.load(f)
cv0 = cv_data['cv0']
cv1 = cv_data['cv1']
cv2 = cv_data['cv2']
cv3 = cv_data['cv3']
cv4 = cv_data['cv4']
# Group the data together
total_data = train_data.copy()
total_label = train_lable.copy()
total_batchname = train_batchname.copy()
total_max_len = train_max_len + val_max_len + test_SP_TM_max_len + test_TM_max_len + test_BETA_max_len
total_dismatch_index_pred = train_dismatch_index_pred.copy()
total_dismatch_index_type = train_dismatch_index_type.copy()
total_real_node_label = train_real_node_label.copy()
frames = [df_train, df_val, df_test_SP_TM, df_test_TM, df_test_BETA]
total_df = pd.concat(frames)
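# Append the validation and test splits (the objects themselves, not their names)
extra_splits = [
  (val_data, val_lable, val_batchname, val_dismatch_index_pred, val_dismatch_index_type, val_real_node_label),
  (test_SP_TM_data, test_SP_TM_lable, test_SP_TM_batchname, test_SP_TM_dismatch_index_pred, test_SP_TM_dismatch_index_type, test_SP_TM_real_node_label),
  (test_TM_data, test_TM_lable, test_TM_batchname, test_TM_dismatch_index_pred, test_TM_dismatch_index_type, test_TM_real_node_label),
  (test_BETA_data, test_BETA_lable, test_BETA_batchname, test_BETA_dismatch_index_pred, test_BETA_dismatch_index_type, test_BETA_real_node_label),
]
for data, label, batchname, dismatch_pred, dismatch_type, real_node_label in extra_splits:
  total_data.extend(data)
  total_label.extend(label)
  total_batchname.extend(batchname)
  total_dismatch_index_pred.update(dismatch_pred)
  total_dismatch_index_type.update(dismatch_type)
  total_real_node_label.extend(real_node_label)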
cv_list = ['cv0', 'cv1', 'cv2', 'cv3', 'cv4'] 
cv0_data, cv0_label, cv0_batchname, cv0_dismatch_index_pred, cv0_dismatch_index_type, cv0_real_node_label, cv0_df=cv_data_generator(cv0,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label)
cv1_data, cv1_label, cv1_batchname, cv1_dismatch_index_pred, cv1_dismatch_index_type, cv1_real_node_label, cv1_df=cv_data_generator(cv1,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label)
cv2_data, cv2_label, cv2_batchname, cv2_dismatch_index_pred, cv2_dismatch_index_type, cv2_real_node_label, cv2_df=cv_data_generator(cv2,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label)
cv3_data, cv3_label, cv3_batchname, cv3_dismatch_index_pred, cv3_dismatch_index_type, cv3_real_node_label, cv3_df=cv_data_generator(cv3,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label)
setup1_test_data, setup1_test_label, setup1_test_batchname, setup1_test_dismatch_index_pred, setup1_test_dismatch_index_type, setup1_test_real_node_label, setup1_test_df=cv_data_generator(cv4,cv_list,total_df,total_data,total_label,total_batchname,total_dismatch_index_pred,total_dismatch_index_type,total_real_node_label)
# CV setup 1
#cv0, cv1, cv2 for train, cv3 for validation, cv4 for test
setup1_train_data = cv0_data.copy()
setup1_train_label = cv0_label.copy()
setup1_train_batchname = cv0_batchname.copy()
setup1_train_dismatch_index_pred = cv0_dismatch_index_pred.copy()
setup1_train_dismatch_index_type = cv0_dismatch_index_type.copy()
setup1_train_real_node_label = cv0_real_node_label.copy()
setup1_train_df = [cv0_df, cv1_df, cv2_df]
setup1_train_df = pd.concat(setup1_train_df)
# Add cv1 and cv2 to the training set (cv0 was copied above)
for data, label, batchname, dismatch_pred, dismatch_type, real_node_label in [
  (cv1_data, cv1_label, cv1_batchname, cv1_dismatch_index_pred, cv1_dismatch_index_type, cv1_real_node_label),
  (cv2_data, cv2_label, cv2_batchname, cv2_dismatch_index_pred, cv2_dismatch_index_type, cv2_real_node_label),
]:
  setup1_train_data.extend(data)
  setup1_train_label.extend(label)
  setup1_train_batchname.extend(batchname)
  setup1_train_dismatch_index_pred.update(dismatch_pred)
  setup1_train_dismatch_index_type.update(dismatch_type)
  setup1_train_real_node_label.extend(real_node_label)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# put the model on the GPU if one is available
model = SchNetModel(hidden_channels=128, out_dim=6, max_len=30000, max_num_neighbors=16).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0004,weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.05) # Learning schedule added

In [None]:
total_epochs=100
draw_num = 1
global_step = 0
# setup 1
# cv0, cv1, cv2 for train, cv3 for validation, cv4 for test
# the training data is called setup1_train_data
early_stopper = EarlyStopper(patience=5, min_delta=0.001) # stop after 5 epochs without improvement
epoch_atom_level_accuracy_record_train = []
epoch_loss_record_train=[]
epoch_node_level_accuracy_record_train = []
epoch_atom_level_accuracy_record_val = []
epoch_loss_record_val = []
epoch_node_level_accuracy_record_val = []
epochs = []
for epoch in range(total_epochs):
    epochs.append(epoch)
    epoch_atom_level_accuracy_train = []
    epoch_loss_train=[]
    epoch_node_level_accuracy_train = []
    # train
    model.train()  # switch back to training mode (model.eval() is set during validation)
    for i, data in enumerate(setup1_train_data):  
        global_step += 1 
        optimizer.zero_grad()  
        outputs = model(data.to(device))   # put batch data in GPU get logits
        prediction = outputs["node_embedding"]  
        real_label = torch.argmax(torch.tensor(setup1_train_label[i]), dim=1).to(device) # put label in GPU  
           
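        # Apply differentiable Gaussian smoothing along the atom axis before backpropagation
        # (kernel size 29 and sigma 5, as in the parameter table), so that gradients reach the model
        smoothing = GaussianSmoothing(6, 29, 5, 1).to(device)
        padded = F.pad(prediction.T.unsqueeze(0), (14, 14), mode='reflect')  # shape (1, 6, N + 28)
        prediction_Gauss = smoothing(padded).squeeze(0).T  # back to shape (N, 6)

        loss = criterion(prediction_Gauss, real_label)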
        loss.backward()     
        optimizer.step()    
        #calulate atom-level accuracy and node-level accuracy
        _, predicted = torch.max(prediction_Gauss.to(device), 1)
        correct = (predicted == real_label).sum().item()
        total = real_label.size(0)
        atom_level_accuracy =  correct / total
        # below is operated under CPU node
        processor = MapAtomNode(predicted.cpu(),setup1_train_batchname[i],setup1_train_dismatch_index_pred,setup1_train_dismatch_index_type,setup1_train_df)
        train_predict_node_label = processor.map_atom_node() 
        node_level_accuracy = node_accuracy(train_predict_node_label,setup1_train_real_node_label[i])
        epoch_loss_train.append(loss.item())
        epoch_atom_level_accuracy_train.append(atom_level_accuracy)
        epoch_node_level_accuracy_train.append(node_level_accuracy)    
    epoch_loss_record_train.append(np.mean(epoch_loss_train))
    epoch_atom_level_accuracy_record_train.append(np.mean(epoch_atom_level_accuracy_train))
    epoch_node_level_accuracy_record_train.append(np.mean(epoch_node_level_accuracy_train))    
    # val
    model.eval()  
    with torch.no_grad():  
        epoch_atom_level_accuracy_val = []
        epoch_loss_val = []
        epoch_node_level_accuracy_val = []
        for i, data in enumerate(cv3_data):  
            outputs = model(data.to(device))
            prediction = outputs["node_embedding"]
            real_label = torch.argmax(torch.tensor(cv3_label[i]), dim=1).to(device)
            # Apply Gaussian smoothing along the atom axis
            # (kernel size 29 and sigma 5, as in the parameter table)
            smoothing = GaussianSmoothing(6, 29, 5, 1).to(device)
            padded = F.pad(prediction.T.unsqueeze(0), (14, 14), mode='reflect')  # shape (1, 6, N + 28)
            prediction_Gauss = smoothing(padded).squeeze(0).T  # back to shape (N, 6)
            loss = criterion(prediction_Gauss, real_label)
            #_, predicted = torch.max(prediction, 1)
            _, predicted = torch.max(prediction_Gauss.to(device), 1)
            correct = (predicted == real_label).sum().item()
            total = real_label.size(0)
            atom_level_accuracy = correct / total
            processor = MapAtomNode(predicted.cpu(), cv3_batchname[i], cv3_dismatch_index_pred, cv3_dismatch_index_type, cv3_df)
            val_predict_node_label = processor.map_atom_node()
            node_level_accuracy = node_accuracy(val_predict_node_label, cv3_real_node_label[i])
            epoch_loss_val.append(loss.item())
            epoch_atom_level_accuracy_val.append(atom_level_accuracy)
            epoch_node_level_accuracy_val.append(node_level_accuracy)
        epoch_loss_record_val.append(np.mean(epoch_loss_val))
        epoch_atom_level_accuracy_record_val.append(np.mean(epoch_atom_level_accuracy_val))
        epoch_node_level_accuracy_record_val.append(np.mean(epoch_node_level_accuracy_val))
        # early_stopping needs the validation loss to check if it has decresed, 
        # and if it has, it will make a checkpoint of the current model
        if early_stopper.early_stop(np.mean(epoch_loss_val)):             
            break
        if global_step >= 55000:
            scheduler.step() # apply learning schedule
torch.save(model.state_dict(), '/work3/s194408/Project/result/CV_setup1_5_neighbors_Gauss.pth')

##### **$\alpha$-carbon (1-fold cross-validation)**
To use the $\alpha$-carbon method, replace the `MapAtomNode` part with the following code.


In [None]:
# Node-level accuracy: keep only the alpha-carbon prediction of each residue
j = 0  # index of the first atom of the current residue
CA_pred_all=[]
for k in range(len(setup_val_total_atoms_length[i])):
    index_last = int(setup_val_total_atoms_length[i][k]) + int(j)
    part_pred = predicted[j:index_last]
    CA_pred = [part_pred[index] for index in setup_val_CA_index_list[i][k]]
    CA_pred_all.extend(CA_pred)
    j = index_last  # the next residue starts where this one ends

tensor_label = torch.tensor(setup_val_real_node_label[i], dtype=torch.float32).to(device)
CA_pred_all= [t.unsqueeze(0) for t in CA_pred_all]
CA_pred_all = torch.cat(CA_pred_all, dim=0)
node_correct = (CA_pred_all == tensor_label).sum().item()
node_total = CA_pred_all.size(0)
node_level_accuracy =  node_correct / node_total

#### **Model Test and Result**

This part covers **Major Voting (5-fold cross-validation)**; the results for **$\alpha$-carbon (1-fold cross-validation)** are presented in the report.


##### **1. Topological evaluation**

The criteria for two protein topologies to be considered identical are defined according to [4] as follows:

- The N-terminal topology must be the same.
- The predicted labels and the ground truth labels must overlap with 5 residues for α-helices and 2 for β-strands (sub-parts of β-barrel).
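
The two criteria above can be sketched as a small check on per-residue label strings. This is an illustration, not the evaluation code behind the reported tables: the function names are hypothetical, and requiring the same number of membrane segments in both sequences is an additional assumption of the sketch.

```python
from itertools import groupby

def membrane_segments(labels, membrane_label):
    """Return (start, end) pairs (end exclusive) for runs of `membrane_label`."""
    segments, position = [], 0
    for label, run in groupby(labels):
        length = len(list(run))
        if label == membrane_label:
            segments.append((position, position + length))
        position += length
    return segments

def same_topology(predicted, actual, membrane_label="M", min_overlap=5):
    """Use membrane_label="B" and min_overlap=2 for beta strands."""
    if predicted[0] != actual[0]:              # N-terminal topology must match
        return False
    pred_segments = membrane_segments(predicted, membrane_label)
    true_segments = membrane_segments(actual, membrane_label)
    if len(pred_segments) != len(true_segments):   # assumption of this sketch
        return False
    for (p_start, p_end), (t_start, t_end) in zip(pred_segments, true_segments):
        if min(p_end, t_end) - max(p_start, t_start) < min_overlap:
            return False
    return True

actual = "I" * 4 + "M" * 10 + "O" * 5
shifted = "I" * 6 + "M" * 10 + "O" * 3     # helix shifted by 2 residues
too_far = "I" * 11 + "M" * 6 + "O" * 2     # only 3 residues overlap
print(same_topology(shifted, actual), same_topology(too_far, actual))
```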


The test results are gathered, and the topological predictions are evaluated using the above-mentioned criteria for the protein types alpha TM, alpha SP+TM, and beta barrel.
In the ***Left Table***, it can be seen that only very few topologies have been predicted for *Globular* and *SP+Globular* by the trained SchNet models. However, the average number of correctly predicted residues per protein is still much higher for SchNet than for the baseline. The average number of correctly predicted residues per protein is also calculated from DeepTMHMM predictions for reference. Note that the overlap criterion is set to 5 by default for the topological evaluation of Globular and SP+Globular, as these proteins contain both α-helices and β-sheets. It can be observed that neither the trained SchNet models nor the baseline have predicted any topology. However, the total numbers of matched residue predictions are still much higher for SchNet than for the baseline.


##### **2. Comparison between baseline and SchNet using McNemar's test**

To assess the validity of our model, we compared the baseline with our trained SchNet models using McNemar's test. The ***Right Table*** provides the comparison results for each of the 5-fold cross-validation (CV) setups. The value 0 consistently falls outside the 95% confidence intervals (CIs), and the p-values are consistently highly significant. This means that the SchNet models are better at topological prediction than the baseline.

Finally, it is noteworthy that we initially used a hold-out approach to train and validate the GNN model on only Globular and SP+Globular proteins and tested on the remaining three protein types. However, the result was even worse





|<img src="image/table1.png" alt="Image 1" width="400" height="200"> | <img src="image/table3.png" alt="Image 2" width="400" height="200">|



The following code shows how to run the *setup1* test; the tests for setup2 through setup5 are run in the same way.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SchNetModel(hidden_channels=128, out_dim=6, max_len=30000, max_num_neighbors=32).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0003,weight_decay=1e-4)
# map_location=device lets the checkpoint load on CPU-only machines as well
check_point = torch.load('/content/drive/MyDrive/02456 Deep learning/Model evaluation/schnet_majorvoting/models/CV_setup1_Gauss_result.pth', map_location=device)
model.load_state_dict(check_point)
setup1_test_predict_node_label_lis = []
check_zero = []
test_node_acc_list = []
test_node_acc_binary_list = []
test_atom_acc_list = []
test_atom_correct = []
test_atom_total = []
baseline_atom_correct = []
baseline_atom_acc_list = []
baseline_node_acc_list = []
model.eval()
with torch.no_grad():
  for i, data in enumerate(setup1_test_data):
      outputs = model(data.to(device))
      prediction = outputs["node_embedding"]
      real_label = torch.argmax(torch.tensor(setup1_test_label[i]), dim=1).to(device)
      smoothing = GaussianSmoothing(6, 29, 5, 1)
      predicted = torch.reshape(prediction.to('cpu'), (1,prediction.shape[1], prediction.shape[0]))
      predicted = F.pad(predicted, (14, 14), mode='reflect')
      predicted = smoothing(predicted)
      prediction_Gauss = torch.reshape(predicted, (prediction.shape[0], prediction.shape[1]))
      loss = criterion(prediction_Gauss.to(device), real_label)
      _, predicted = torch.max(prediction_Gauss.to(device), 1)
      correct = (predicted == real_label).sum().item()
      total = real_label.size(0)
      atom_level_accuracy = correct / total
      test_atom_acc_list.append(atom_level_accuracy)
      test_atom_correct.append(correct)
      test_atom_total.append(total)
      # Baseline atom accuracy
      baseline_atom = torch.zeros_like(real_label) # note that in this case the most frequent class is class 0
      baseline_correct = (baseline_atom == real_label).sum().item()
      baseline_atom_correct.append(baseline_correct)
      processor = MapAtomNode(predicted.cpu(), setup1_test_batchname[i], setup1_test_dismatch_index_pred, setup1_test_dismatch_index_type, setup1_test_df)
      test_predict_node_label = processor.map_atom_node()
      setup1_test_predict_node_label_lis.append(test_predict_node_label)
      accuracy_list = [1 if x == y else 0 for x, y in zip(test_predict_node_label, setup1_test_real_node_label[i])]
      test_node_acc_binary_list += accuracy_list
      node_level_accuracy = node_accuracy(test_predict_node_label, setup1_test_real_node_label[i])
      test_node_acc_list.append(node_level_accuracy)

      # Baseline node accuracy
      baseline_node_label = np.zeros_like(np.array(setup1_test_real_node_label[i])) # note that in this case the most frequent class is class 0
      baseline_accuracy = [1 if x == y else 0 for x, y in zip(baseline_node_label, setup1_test_real_node_label[i])]
      baseline_node_acc_list.append(baseline_accuracy)

In [None]:
from metrics_utils import label_list_to_topology,is_topologies_equal
print(len(setup1_test_predict_node_label_lis))
print(len(setup1_test_real_node_label))
final_node_acc = sum(test_node_acc_list)/len(test_node_acc_list)
print("Node acc:", final_node_acc)
final_node_binary_acc = sum(test_node_acc_binary_list)/len(test_node_acc_binary_list)
print("Node binary acc:", final_node_binary_acc)
final_atom_acc = sum(test_atom_acc_list)/len(test_atom_acc_list)
print("Avg atom acc:", final_atom_acc)
total_atom_acc = sum(test_atom_correct)/sum(test_atom_total)  # pooled over all atoms
print("Total atom acc:", total_atom_acc)
test_resul = []
test_baseline = []
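for i in range(0, len(setup1_test_predict_node_label_lis)):
  topo_A = label_list_to_topology(setup1_test_predict_node_label_lis[i])
  topo_B = label_list_to_topology(setup1_test_real_node_label[i])
  # Baseline topology: every residue is assigned the most frequent class (class 0).
  # It is compared against the ground truth, not against the model prediction.
  topo_baseline = label_list_to_topology([0] * len(setup1_test_real_node_label[i]))
  test_resul.append(is_topologies_equal(topo_A, topo_B, 5))
  test_baseline.append(is_topologies_equal(topo_baseline, topo_B, 5))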
print("Correct topology:", sum(test_resul)/len(test_resul))
print("Correct topology baseline:", sum(test_baseline)/len(test_baseline))
baseline_all = []
for i in range(0, len(baseline_node_acc_list)):
  baseline_all += baseline_node_acc_list[i]
final_node_acc = sum(baseline_all )/len(baseline_all)
print("Baseline Node acc:", final_node_acc)
final_atom_acc = sum(baseline_atom_correct)/sum(test_atom_total)
print("Baseline Atom acc:", final_atom_acc)

In [None]:
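# The call below uses the (y_true, yhatA, yhatB, alpha) signature and the
# (thetahat, CI, p) return value of `mcnemar` from the DTU `toolbox_02450` package;
# statsmodels' `mcnemar` instead expects a 2x2 contingency table.
from toolbox_02450 import mcnemar
# Compute the Jeffreys interval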
baseline_node_acc_array = np.array(baseline_all)
test_node_acc_array = np.array(test_node_acc_binary_list)
alpha = 0.05
[thetahat, CI, p] = mcnemar(np.ones_like(baseline_node_acc_array), baseline_node_acc_array, test_node_acc_array, alpha=alpha)
print("theta = theta_A-theta_B point estimate", thetahat, " CI: ", CI, "p-value", p)
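
To make the comparison above easier to follow, here is a minimal, standard-library sketch of an exact McNemar test on paired correctness indicators. It is not the routine called above: the function name `mcnemar_exact` and the toy data are assumptions for illustration, and the sketch returns only the point estimate and a two-sided p-value, not the confidence interval.

```python
from math import comb

def mcnemar_exact(correct_a, correct_b):
    """Exact two-sided McNemar test on paired 0/1 correctness indicators."""
    n12 = sum(1 for a, b in zip(correct_a, correct_b) if a == 1 and b == 0)  # A right, B wrong
    n21 = sum(1 for a, b in zip(correct_a, correct_b) if a == 0 and b == 1)  # A wrong, B right
    n = len(correct_a)
    theta_hat = (n12 - n21) / n  # estimated accuracy difference, A minus B
    m = n12 + n21
    if m == 0:
        return theta_hat, 1.0  # no discordant pairs: no evidence of a difference
    k = min(n12, n21)
    # Under H0 the discordant pairs split 50/50, so n12 ~ Binomial(m, 0.5).
    tail = sum(comb(m, i) for i in range(k + 1)) / 2 ** m
    return theta_hat, min(1.0, 2 * tail)

baseline = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0]  # toy per-residue correctness of the baseline
model    = [1, 1, 1, 1, 1, 1, 1, 0, 1, 1]  # toy per-residue correctness of the model
theta_hat, p_value = mcnemar_exact(baseline, model)
print(theta_hat, p_value)
```

Only the discordant pairs (residues where exactly one of the two predictors is correct) carry information about which predictor is better, which is why residues that both get right or both get wrong do not enter the p-value.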

###  **EGNN Model**

#### **Model Architecture**

The core advantage of EGNN (E(n)-Equivariant Graph Neural Network) is that its layers are equivariant to rotations, translations, and reflections of the input coordinates and to permutations of the nodes. The model's output therefore stays consistent regardless of how the structure is oriented in space or in which order the nodes are listed, which is particularly important for learning from 3D protein structures. The following figure shows the basic workflow.


1. **Graph Construction**
  - Node Initialization: Node features are initialized through an embedding layer (EMB). This involves transforming the raw data (the 3D structure of the protein) using `torch_geometric` into a feature representation that is suitable for neural network processing.

  - Edge Initialization: Similarly, edge features are initialized and transformed via `torch_geometric`, defining the relationships and interactions between nodes.

2. **EGNN: Node Feature Update**
  - Aggregation: Each node's features are aggregated by considering the features of adjacent nodes, facilitating the propagation of information across the nodes.

  - Transformation: The aggregated features undergo further processing through a transformation layer (TRANS) to learn complex feature representations.

  - Metric: The similarity between nodes, post feature update, is calculated and used to guide the classification task for protein transmembrane domains.

3. **EGNN: Edge Feature Update**  
  - Following the node feature update, edge features are also updated to capture the refined relationships between nodes.

4. **Query Node Label Prediction**
  - The model concludes by predicting labels for query nodes, an essential step for tasks such as protein classification.
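
As a small illustration of the invariance idea, the sketch below checks that pairwise squared distances between atoms do not change when a structure is rotated and translated. Squared distances are the geometric input to the messages in the original EGNN formulation; whether `EGNNModel` here uses exactly this feature is an assumption, and the helper names and coordinates are made up for the example.

```python
import math

def squared_distances(coords):
    """All pairwise squared distances between 3D points."""
    return [
        [sum((a - b) ** 2 for a, b in zip(p, q)) for q in coords]
        for p in coords
    ]

def rotate_z_and_shift(coords, angle, shift):
    """Rotate points about the z-axis by `angle` radians, then translate by `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [
        (c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])
        for x, y, z in coords
    ]

atoms = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (1.5, 1.2, 0.3)]  # toy coordinates
moved = rotate_z_and_shift(atoms, angle=0.7, shift=(4.0, -2.0, 1.0))

before = squared_distances(atoms)
after = squared_distances(moved)
same = all(
    math.isclose(b, a, abs_tol=1e-9)
    for row_b, row_a in zip(before, after)
    for b, a in zip(row_b, row_a)
)
print(same)
```

Because the features fed to the network are unchanged by such rigid motions, the predicted per-node labels do not depend on the arbitrary orientation of the AlphaFold coordinates.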



|                    |
|:-------------------:|
| <p align="center"><img src="image/egnn.png" alt="Image" width="500" height="250"></p>|



#### **Downstream Tasks**
For our classification task predicting transmembrane proteins, we have a total of six categories. Thus, we have added two fully connected layers after the EGNN model. The activation function employed is ReLU, and the output size is set to 6, corresponding to the number of classification categories. Furthermore, we have applied Gaussian Smoothing to the output (the `GaussianSmoothing` part is included in `from task import GaussianSmoothing`).
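
The effect of the smoothing step can be sketched without PyTorch. The example below smooths a single class-score track along the atom axis with a normalised Gaussian kernel after reflect padding. Reading the arguments of `GaussianSmoothing(6, 29, 5)` as six channels, kernel size 29, and sigma 5 is an assumption (it is consistent with the padding of 14 used in the notebook), and the helper names are hypothetical.

```python
import math

def gaussian_kernel(kernel_size=29, sigma=5.0):
    """Normalised 1D Gaussian weights centred on the middle tap."""
    half = kernel_size // 2
    weights = [math.exp(-0.5 * ((i - half) / sigma) ** 2) for i in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]

def reflect_pad(values, pad):
    """Mirror the sequence at both ends without repeating the edge value."""
    left = values[1:pad + 1][::-1]
    right = values[-pad - 1:-1][::-1]
    return left + values + right

def smooth(values, kernel_size=29, sigma=5.0):
    """Smooth one class-score track along the atom axis, keeping its length."""
    kernel = gaussian_kernel(kernel_size, sigma)
    padded = reflect_pad(values, kernel_size // 2)
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(values))
    ]

scores = [0.0] * 20 + [1.0] + [0.0] * 20   # a single isolated spike
smoothed = smooth(scores)
print(len(smoothed), round(max(smoothed), 3))
```

An isolated spike is spread over its neighbours, so a single atom with a deviating score is less likely to flip the predicted class of a whole segment.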

#### **Label Alignment for Downstream Tasks**

Because the model's input comes from the AlphaFold model developed by DeepMind, which cannot predict protein structures with complete accuracy, there are disparities between the model's predicted labels (AlphaFold's entire sequence of predicted atoms) and the original labels (from the actual protein atomic sequences). One of the downstream tasks is, therefore, to align these labels.

The label alignment strategies are as follows:

1. **First Approach**: Keep the predicted labels static and process the real labels to facilitate the backward propagation process. This step is implemented in `DismatchIndexPadRawData` (`from data_utils import DismatchIndexPadRawData`). The primary idea is to compare the two sets of labels one by one, identifying the index corresponding to the actual label. If the actual label is greater than the predicted label, the index corresponding to the real label is deleted. Conversely, if the actual label is smaller, the value of the predicted label is inserted at the index of the actual label. The type of transmembrane protein at that position is inferred based on the indices immediately preceding and following the true label.

2. **Second Approach**: Maintain the actual labels unchanged and process the predicted labels to achieve atomic-level accuracy. Following the indices obtained from the first approach, if the actual label is larger than the predicted label, the corresponding value is inserted into the predicted label. If the actual label is smaller, the corresponding value in the predicted label is removed.

3. **Third Approach**: Based on the second strategy, the predicted labels are segmented according to the number of atoms within each amino acid (proteins are composed of amino acids, which in turn consist of multiple atoms). After segmentation, a majority vote over the atoms of each amino acid determines the transmembrane category of that amino acid, giving residue-level (amino-acid-level) accuracy. This part of the process is covered in `MapAtomNode` (`from task import MapAtomNode`).
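
The majority vote in the third approach can be sketched as follows. This is a simplified stand-in for `MapAtomNode`, not its actual implementation: the function name `atoms_to_residues` and the toy inputs are assumptions, and the sketch presumes the atom labels are already aligned so that each residue's atoms are contiguous.

```python
from collections import Counter

def atoms_to_residues(atom_labels, atoms_per_residue):
    """Majority-vote the atom-level predictions of each residue into one residue label."""
    residue_labels, start = [], 0
    for count in atoms_per_residue:
        votes = Counter(atom_labels[start:start + count])
        residue_labels.append(votes.most_common(1)[0][0])  # ties go to the label seen first
        start += count
    return residue_labels

atom_labels = [0, 0, 1, 0,   1, 1, 1, 0, 1,   2, 2, 2]  # toy atom-level predictions
atoms_per_residue = [4, 5, 3]                           # toy atom counts per residue
print(atoms_to_residues(atom_labels, atoms_per_residue))  # [0, 1, 2]
```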




#### **Model Parameter**

Due to the variable lengths of the primary structures' atomic sequences in proteins, we employed the StaticEmbedding method for embedding node features. Additionally, we set the input length to `20000 * batch size`.


| Parameter Name              | Value       | Description                               |
| --------------------------- | ----------- | ----------------------------------------- |
| `num_layers`                | 5           | Number of layers in the model.           |
| `emb_dim`                   | 128         | Embedding dimension for the model.       |
| `Activation function to use`| relu        | Activation function used in the model.   |
| `norm`                      | layer       | Normalization method used (e.g., "layer").|
| `aggr`                      | mean        | Aggregation method (e.g., "mean").       |
| `learning rate`             | 0.01        | Learning rate for training.              |
| `weight_decay`              | 1e-4        | Weight decay regularization term.        |
| `batch size`                | 1           | Batch size for training.                 |
| `pool`                      | mean        | Pooling method used (e.g., "mean").      |
| `residual`                  | True        | Whether residual connections are used.   |
| `dropout`                   | 0.1         | Dropout rate for regularization.         |




In [None]:
from task import CreateDataBeforeBatch,TMPDataset,CreateLable,MapAtomNode,node_accuracy,GaussianSmoothing,batchdata
from data_utils import ProcessRawData,ParseStructure
from torch.utils.data import DataLoader
from test import TMPTest
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from egnnmodel import EGNNModel
import numpy as np
import wandb
import seaborn as sns

#### **Model Training and Model Validation**

<p align="center">
  <img src="image/egnn_atom.png" width="400" height="280" alt="EGNN atom-level accuracy">
  <img src="image/egnn_residual.png" width="400" height="280" alt="EGNN residue-level accuracy">
  <img src="image/egnn_loss.png" width="400" height="280" alt="EGNN loss over epochs">
</p>


1. **Atom-Level Accuracy:**

    - The training accuracy at the atom level remains consistently high across all epochs, as indicated by the flat line in the first graph.
    - The validation accuracy at the atom level is also stable.
    - The closeness of the training and validation accuracy lines suggests that the model generalizes well, with no apparent overfitting or underfitting.
2. **Residue-Level Accuracy:**

    - Similar to atom-level accuracy, the training accuracy at the residue level is high and stable.
    - The validation accuracy is also quite stable, and the lines for training and validation accuracy nearly overlap, indicating good generalization at the residue level too.

3. **Loss Over Epochs:**

    - The training loss drops sharply in the initial epochs and then levels off, indicating that the model quickly learns patterns from the training data.
    - The validation loss decreases along with the training loss but also stabilizes, suggesting that there are no significant improvements in the model after a certain point.
    - The proximity of training and validation loss indicates that the model is not overfitting on the training data and performs similarly on unseen data.

In summary, this model demonstrates stable accuracy in both training and validation at the atom and residue levels, with no significant signs of overfitting. The loss metrics confirm this, showing a good fit of the model after the initial learning phase without further improvement. The performance at both the atom and residue levels suggests that the model has generalized well from the training to the validation data.

Further assessment is required to test the model and draw conclusions, which will be presented in the following section.


In [None]:
# Initialization: split the data into setup1-5, download the PDB files, and parse them
file_name = "DeepTMHMM.3line"
path='/work3/s230027/DL/codebase/'
processor = ProcessRawData(path,file_name)
processor.run() # split the data and download the PDB files
processor = ParseStructure(path)
processor.run() # parse the PDB files and store the results

In [None]:
file_name = "DeepTMHMM.3line"
path='/work3/s230027/DL/codebase/'
batch_size=100
setup = 'setup1' # choose crossvalidation (total 5)
processsor= CreateDataBeforeBatch(path)
train_data_dict_before_batch,val_data_dict_before_batch,test_data_dict_before_batch=processsor.get_data(setup)

## dataloader for processing label 
train_dataset = TMPDataset(train_data_dict_before_batch)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=lambda x: x,pin_memory=True)

val_dataset = TMPDataset(val_data_dict_before_batch)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,collate_fn=lambda x: x,pin_memory=True)

In [None]:
# gather all training labels up front to increase GPU utilization
train_residual_level_label={}
train_atom_levl_label = {}
train_dismatch_index_pred ={}
train_dismatch_index_type ={}
for data_batch in train_data_loader:
    batchname=[data_batch[num]['name'] for num in range(len(data_batch))]
    labelprocessor=CreateLable(batchname,data_batch,path,file_name)
    atom_level_label_dict,redidual_level_label_dict,dismatch_index_pred,dismatch_index_type,df_train,_,_=labelprocessor.labeldispatcher(setup,subset='train')
    train_atom_levl_label.update(atom_level_label_dict) 
    train_residual_level_label.update(redidual_level_label_dict) 
    train_dismatch_index_pred.update(dismatch_index_pred)
    train_dismatch_index_type.update(dismatch_index_type)
# gather all validation labels up front
val_residual_level_label={}
val_atom_levl_label = {}
val_dismatch_index_pred ={}
val_dismatch_index_type ={}
for data_batch in val_data_loader:
    batchname=[data_batch[num]['name'] for num in range(len(data_batch))]
    labelprocessor=CreateLable(batchname,data_batch,path,file_name)
    atom_level_label_dict,redidual_level_label_dict,dismatch_index_pred,dismatch_index_type,_,df_val,_=labelprocessor.labeldispatcher(setup,subset='val')
    val_atom_levl_label.update(atom_level_label_dict) 
    val_residual_level_label.update(redidual_level_label_dict) 
    val_dismatch_index_pred.update(dismatch_index_pred)
    val_dismatch_index_type.update(dismatch_index_type)

In [None]:
# dataloader for model 
batch_size=1
train_dataset = TMPDataset(train_data_dict_before_batch)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=lambda x: x,pin_memory=True)

val_dataset = TMPDataset(val_data_dict_before_batch)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,collate_fn=lambda x: x,pin_memory=True)

In [None]:
# model 
max_len= 20000*batch_size
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EGNNModel(out_dim=6,max_len=max_len,num_layers=5,emb_dim=128,residual=True,dropout=0.1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-4)

In [None]:
total_epochs=50

global_step = 0

epoch_atom_level_accuracy_record_train = []
epoch_loss_record_train=[]
epoch_residual_level_accuracy_record_train = []
epoch_atom_level_accuracy_record_val = []
epoch_loss_record_val = []
epoch_residual_level_accuracy_record_val = []
smoothing = GaussianSmoothing(6, 29, 5)
for epoch in range(total_epochs):
     epoch_atom_level_accuracy_train = []
     epoch_loss_train=[]
     epoch_residual_level_accuracy_train = []
     # train
     model.train()  # re-enable dropout; model.eval() is set during validation below
     for data_batch in train_data_loader:
          global_step += 1 
          batchname=[data_batch[num]['name'] for num in range(len(data_batch))]
          label_part = [value.unsqueeze(0) for name in batchname for value in train_atom_levl_label[name].to_dense()]
          atom_levl_label = torch.cat(label_part).to(device)
          residual_level_label = [value for name in batchname for value in train_residual_level_label[name]]
          data =batchdata(data_batch) 
          optimizer.zero_grad()  
          outputs = model(data.to(device)) 
          prediction = outputs["node_embedding"] 
          predicted = torch.reshape(prediction.to('cpu'), (1,prediction.shape[1], prediction.shape[0]))
          predicted = F.pad(predicted, (14, 14), mode='reflect')
          predicted = smoothing(predicted)
          prediction_Gauss = torch.reshape(predicted, (prediction.shape[0], prediction.shape[1]))
          loss = criterion(prediction_Gauss.to(device), atom_levl_label)
          loss.backward()
          optimizer.step() 

          # calculate atom-level accuracy and node-level accuracy
          _, predicted = torch.max(prediction_Gauss.to(device), 1) 
          correct = (predicted == atom_levl_label ).sum().item()
          total = atom_levl_label.size(0)
          atom_level_accuracy =  correct / total

          processor = MapAtomNode(predicted.cpu(),batchname,train_dismatch_index_pred,train_dismatch_index_type,df_train)
          train_predict_node_label = processor.map_atom_node() 
          residual_level_accuracy = node_accuracy(train_predict_node_label,residual_level_label)
          epoch_loss_train.append(loss.item())
          epoch_atom_level_accuracy_train.append(atom_level_accuracy)
          epoch_residual_level_accuracy_train.append(residual_level_accuracy)
     epoch_loss_record_train.append(np.mean(epoch_loss_train))
     epoch_atom_level_accuracy_record_train.append(np.mean(epoch_atom_level_accuracy_train))
     epoch_residual_level_accuracy_record_train.append(np.mean(epoch_residual_level_accuracy_train))
    

     # val
     model.eval()  
     with torch.no_grad():  
          epoch_atom_level_accuracy_val = []
          epoch_loss_val = []
          epoch_residual_level_accuracy_val = []
          for data_batch in val_data_loader:
               batchname=[data_batch[num]['name'] for num in range(len(data_batch))]
               label_part = [value.unsqueeze(0) for name in batchname for value in val_atom_levl_label[name].to_dense()]
               atom_levl_label = torch.cat(label_part).to(device)
               residual_level_label = [value for name in batchname for value in val_residual_level_label[name]]
               data =batchdata(data_batch)
               outputs = model(data.to(device)) 
               prediction = outputs["node_embedding"] 
               predicted = torch.reshape(prediction.to('cpu'), (1,prediction.shape[1], prediction.shape[0]))
               predicted = F.pad(predicted, (14, 14), mode='reflect')
               predicted = smoothing(predicted)
               prediction_Gauss = torch.reshape(predicted, (prediction.shape[0], prediction.shape[1]))
               loss = criterion(prediction_Gauss.to(device), atom_levl_label)
               _, predicted = torch.max(prediction_Gauss.to(device), 1) 
               correct = (predicted == atom_levl_label ).sum().item()
               total = atom_levl_label.size(0)
               atom_level_accuracy =  correct / total
               processor = MapAtomNode(predicted.cpu(),batchname,val_dismatch_index_pred,val_dismatch_index_type,df_val)
               val_predict_node_label = processor.map_atom_node() 
               residual_level_accuracy = node_accuracy(val_predict_node_label,residual_level_label)
               epoch_loss_val.append(loss.item())
               epoch_atom_level_accuracy_val.append(atom_level_accuracy)
               epoch_residual_level_accuracy_val.append(residual_level_accuracy)
          epoch_loss_record_val.append(np.mean(epoch_loss_val))
          epoch_atom_level_accuracy_record_val.append(np.mean(epoch_atom_level_accuracy_val))
          epoch_residual_level_accuracy_record_val.append(np.mean(epoch_residual_level_accuracy_val))
print("Finished training.")
torch.save(model.state_dict(), '/work3/s230027/DL/result/egnn/egnn_model_size1_epoch50.pth')

#### **Model Test and Result**

The test data contain five protein types: TM, SP_TM, BETA, Globular, and Signal. We focus on TM, SP_TM, and BETA and evaluate whether the predicted topology is correct. The number of correctly predicted topologies is 0. Possible reasons are:

1. Insufficient classification performance for specific protein types.

2. Model structure issues:

    - The model structure may be ill-suited to processing certain types of protein sequences, which could negatively impact the model's performance.


In [2]:
import pickle
path = '/work3/s230027/DL/codebase/dataset/parse raw data/cv4.pickle'
with open(path, 'rb') as file:
    cv4 = pickle.load(file)

TM_name_list = cv4[cv4['protein_type'] == 'TM']['uniprot_id_low'].tolist()
BETA_name_list = cv4[cv4['protein_type'] == 'BETA']['uniprot_id_low'].tolist()
SP_TM_name_list = cv4[cv4['protein_type'] == 'SP+TM']['uniprot_id_low'].tolist()


file_name = "DeepTMHMM.3line"
path='/work3/s230027/DL/codebase/'
batch_size=100
setup = 'setup1' # choose crossvalidation (total 5)
processsor= CreateDataBeforeBatch(path)
train_data_dict_before_batch,val_data_dict_before_batch,test_data_dict_before_batch=processsor.get_data(setup)

TM_test={}
for name in TM_name_list:
    TM_test[name]=test_data_dict_before_batch[name]

BETA_test={}
for name in BETA_name_list:
    BETA_test[name]=test_data_dict_before_batch[name]

SP_TM_test={}
for name in SP_TM_name_list:
    SP_TM_test[name] = test_data_dict_before_batch[name]


file_name = "DeepTMHMM.3line"
path='/work3/s230027/DL/codebase/'
modelpath = '/work3/s230027/DL/result/egnn/egnn_model_size1_epoch100.pth'
batch_size=1

In [None]:
##TM
processor=TMPTest(TM_test,file_name,path,batch_size,5,setup='setup1',modelpath=modelpath)
processor.printresult()
##SP_TM
processor=TMPTest(SP_TM_test,file_name,path,batch_size,5,setup='setup1',modelpath=modelpath)
processor.printresult()
##BETA
processor=TMPTest(BETA_test,file_name,path,batch_size,3,setup='setup1',modelpath=modelpath)
processor.printresult()

###  **GCPNET Model**

#### **Model Architecture**


1. **Graph Definition**: The model begins by defining the input data as a graph using `torch_geometric`. This includes transforming raw data (the 3D structure of proteins) into a graph format where nodes represent atoms or residues and edges represent the connections or relationships between these nodes.

2. **Geometry-Complete Graph Convolution with GCPNet**: The model employs a series of graph convolution processes (GCP) to update the features of nodes and edges.
   - **Node Tensors $H$**: Node features are processed through multiple GCP layers, resulting in transformed node tensors.

   - **Edge Tensors $E$**: Edge features are similarly updated through GCP layers.

   - **Frames $F$**: Additional geometric information might be processed through frame messages, utilizing the relative positions or orientations of atoms in 3D space.

3. **GCPNet Convolution**: The updated node and edge tensors are passed through GCPNet convolutional layers. Features and geometric information are merged to refine the graph representation further.

4. **Output Graph $g^L$**: The output is a graph transformed by GCPNet, encoding features and structural information suitable for downstream tasks.



|                    |
|:-------------------:|
| <p align="center"><img src="image/gcpnet.png" alt="Image" width="500" height="250"></p> [8]|




#### **Model Parameter**

| Parameter                | Value   | Description                          |
|--------------------------|---------|--------------------------------------|
| `num_layers`             | 4       | Number of layers in the model        |
| `emb_dim`                | 64      | Embedding dimension                  |
| `node_s_emb_dim`         | 64      | Node scalar embedding dimension      |
| `node_v_emb_dim`         | 8       | Node vector embedding dimension      |
| `edge_s_emb_dim`         | 16      | Edge scalar embedding dimension      |
| `edge_v_emb_dim`         | 2       | Edge vector embedding dimension      |
| `r_max`                  | 10.0    | Maximum radius for interactions      |
| `num_rbf`                | 8       | Number of radial basis functions     |
| `activation`             | 'silu'  | Activation function                  |
| `pool`                   | 'sum'   | Pooling method                       |
| `learning rate`          | 0.01    | Learning rate for training           |
| `weight_decay`           | 1e-4    | Weight decay for regularization      |
| `batch size`             | 1       | Batch size for training              |

In [None]:
# model parameters
num_layers = 4
emb_dim = 64
node_s_emb_dim = emb_dim
node_v_emb_dim = 8
edge_s_emb_dim = 16
edge_v_emb_dim = 2
r_max = 10.0
num_rbf = 8
activation = 'silu'
pool = 'sum'
module_cfg = OmegaConf.create({
    'norm_pos_diff': True,
    'scalar_gate': 0,
    'vector_gate': True,
    'scalar_nonlinearity': activation,
    'vector_nonlinearity': activation,
    'nonlinearities': [activation, activation],
    'r_max': r_max,
    'num_rbf': num_rbf,
    'bottleneck': 2,
    'vector_linear': True,
    'vector_identity': True,
    'default_bottleneck': 2,
    'predict_node_positions': True,
    'predict_node_rep': True,
    'node_positions_weight': 1.0,
    'update_positions_with_vector_sum': False,
    'enable_e3_equivariance': False,
    'pool': pool,
})
# model_cfg 
model_cfg = OmegaConf.create({
    'h_input_dim': 1,  
    'chi_input_dim': 2,     
    'e_input_dim': 9, 
    'xi_input_dim': 1, 
    'h_hidden_dim': node_s_emb_dim,
    'chi_hidden_dim': node_v_emb_dim,
    'e_hidden_dim': edge_s_emb_dim,
    'xi_hidden_dim': edge_v_emb_dim,
    'num_layers': num_layers,
    'dropout': 0.0,
})

# layer_cfg 
layer_cfg = OmegaConf.create({
    'pre_norm': False,
    'use_gcp_norm': True,
    'use_gcp_dropout': True,
    'use_scalar_message_attention': True,
    'num_feedforward_layers': 2,
    'dropout': 0.0,
    'nonlinearity_slope': 1e-2,
    'mp_cfg': {
        'edge_encoder': False,
        'edge_gate': False,
        'num_message_layers': 4,
        'message_residual': 0,
        'message_ff_multiplier': 1,
        'self_message': True,
    },
})

#### **Downstream Tasks**

The downstream tasks of the GCPNET model, including the label alignment strategies, are the same as for the EGNN model.

For our classification task predicting transmembrane proteins, we have a total of six categories. Thus, we have added two fully connected layers after the GCPNET model. The activation function employed is ReLU, and the output size is set to 6, corresponding to the number of classification categories. Furthermore, we have applied Gaussian Smoothing to the output (the `GaussianSmoothing` part is included in `from task import GaussianSmoothing`).

##### Label Alignment for Downstream Tasks

Given that the model's input originates from the Alphafold model developed by DeepMind—which cannot predict protein structures with complete accuracy—there are disparities between the model's predicted labels (Alphafold's entire sequence of predicted atoms) and the original labels (from the actual protein atomic sequences). One of the downstream tasks is, therefore, to align these labels.

The label alignment strategies are as follows:

1. **First Approach**: Keep the predicted labels fixed and adjust the real labels so that the loss can be computed for backpropagation. This step is implemented in `DismatchIndexPadRawData` (`from data_utils import DismatchIndexPadRawData`). The idea is to compare the two sets of labels one by one and identify the mismatching indices. If the actual label is greater than the predicted label, the entry at that index of the real labels is deleted. Conversely, if the actual label is smaller, the value of the predicted label is inserted at that index of the actual labels. The transmembrane type at that position is inferred from the entries immediately preceding and following it in the true labels.

2. **Second Approach**: Keep the actual labels unchanged and adjust the predicted labels to obtain atom-level accuracy. Using the indices obtained in the first approach, if the actual label is larger than the predicted label, the corresponding value is inserted into the predicted labels. If the actual label is smaller, the corresponding value is removed from the predicted labels.

3. **Third Approach**: Building on the second approach, the predicted labels are segmented according to the number of atoms in each amino acid (proteins are composed of amino acids, which in turn consist of multiple atoms). Within each segment, a majority vote over the atom-level categories decides the transmembrane category of the amino acid, which yields residue-level (amino-acid-level) accuracy. This step is implemented in `MapAtomNode` (`from task import MapAtomNode`).
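
The majority vote in the third approach can be sketched as follows. This is a minimal illustration under stated assumptions, not the code of `MapAtomNode`: the function name `majority_vote_per_residue` and its two arguments (a flat list of atom-level classes and the number of atoms in each residue) are hypothetical. Ties are resolved in favour of the class that appears first within the residue, which is one possible choice among several.

```python
from collections import Counter
from typing import List, Sequence


def majority_vote_per_residue(
    atom_labels: Sequence[int], atoms_per_residue: Sequence[int]
) -> List[int]:
    """Collapse atom-level class predictions into one class per residue."""
    if sum(atoms_per_residue) != len(atom_labels):
        raise ValueError("atom counts do not add up to the number of atom labels")

    residue_labels = []
    start = 0
    for n_atoms in atoms_per_residue:
        if n_atoms <= 0:
            raise ValueError("every residue needs at least one atom")
        segment = atom_labels[start:start + n_atoms]
        # most_common(1) returns the most frequent class; on a tie it keeps
        # the class that was encountered first in the segment.
        residue_labels.append(Counter(segment).most_common(1)[0][0])
        start += n_atoms
    return residue_labels


# Three residues with 3, 4 and 3 atoms respectively.
atom_pred = [0, 0, 1, 2, 2, 2, 2, 1, 3, 3]
print(majority_vote_per_residue(atom_pred, [3, 4, 3]))  # [0, 2, 3]
```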





In [None]:
from task import (
    CreateDataBeforeBatch,
    TMPDataset,
    CreateLable,
    MapAtomNode,
    node_accuracy,
    ProcessBatch,
    GaussianSmoothing,
)
from data_utils import ProcessRawData,ParseStructure
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from gcpnet import GCPNetModel
from omegaconf import OmegaConf
import numpy as np

#### **Model Training and Model Validation**

<p align="center">
  <img src="image/gcpnet_atom.png" width="400" height="280" alt="GCPNet atom-level accuracy">
  <img src="image/gcpnet_residual.png" width="400" height="280" alt="GCPNet residue-level accuracy">
  <img src="image/gcpnet_loss.png" width="400" height="280" alt="GCPNet loss">
</p>




1. **Atom-Level Accuracy:**

    - Throughout training, the atom-level training accuracy is very stable, with almost no fluctuation.
    - The validation accuracy is equally stable and very close to the training accuracy curve, suggesting that the model performs similarly on unseen data and generalizes from the training set to the validation set.
    - The closeness of the training and validation accuracy curves indicates no sign of overfitting at the atom level.
    - The accuracy of the model is slightly above 60.6% and stays at that level after about 50k steps, indicating that further training does not significantly improve it.

2. **Residue-Level Accuracy:**

    - Similar to the atom level, the training accuracy at the residue level also shows a high degree of stability, suggesting that the model learns well from the data at the residue level.
    - The validation accuracy almost overlaps with the training accuracy, indicating the model also generalizes well at the residue level, with good predictive performance on unknown data.
    - The consistency of the two curves further proves the robustness and reliability of the model at the residue level.
    - The model's accuracy is slightly above 60.7%, and it remains consistent after about 50k steps, indicating that further training does not significantly improve accuracy.

3. **Loss Over Epochs:**

    - The training loss drops sharply in the initial epochs and then levels off, signifying rapid progress in the initial stages of model learning. The stable trend of the training loss indicates that the model quickly reaches a point of minimized loss and does not change significantly in subsequent training.
    - The validation loss curve indicates that the model performs similarly on the validation set and the training set, with a low and stable loss suggesting good predictive ability on new data.
    - The close proximity of the training and validation loss curves implies that the model is not overfitting on the training data and generalizes well to unseen data.

In summary, the GCPNet model is stable and behaves similarly on training and validation data at both the atom and residue levels. The loss curves indicate that the model learns quickly in the early stages of training and maintains that performance throughout. Accuracy plateaus slightly above 60.6% at the atom level and 60.7% at the residue level on both sets.

Further assessment on the test set is required before drawing conclusions; it is presented in the following section.



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GCPNetModel(
    num_layers=num_layers,
    node_s_emb_dim=node_s_emb_dim,
    node_v_emb_dim=node_v_emb_dim,
    edge_s_emb_dim=edge_s_emb_dim,
    edge_v_emb_dim=edge_v_emb_dim,
    r_max=r_max,
    num_rbf=num_rbf,
    activation=activation,
    pool=pool,
    module_cfg=module_cfg,
    model_cfg=model_cfg,
    layer_cfg=layer_cfg
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-4)

In [None]:
total_epochs = 20
draw_num = 1
global_step = 0

# Per-epoch mean loss and accuracy for the training and validation sets
epoch_atom_level_accuracy_record_train = []
epoch_loss_record_train = []
epoch_residual_level_accuracy_record_train = []
epoch_atom_level_accuracy_record_val = []
epoch_loss_record_val = []
epoch_residual_level_accuracy_record_val = []

smoothing = GaussianSmoothing(6, 29, 5)


def smooth_logits(prediction):
    """Smooth (num_atoms, 6) logits along the atom axis; returns (num_atoms, 6) on `device`."""
    # Transpose to (1, 6, num_atoms); a plain reshape would mix classes and atoms.
    x = prediction.to('cpu').T.unsqueeze(0)
    # Reflect-pad 14 positions per side so the output keeps num_atoms positions.
    x = F.pad(x, (14, 14), mode='reflect')
    x = smoothing(x)
    return x.squeeze(0).T.to(device)


for epoch in range(total_epochs):
    epoch_atom_level_accuracy_train = []
    epoch_loss_train = []
    epoch_residual_level_accuracy_train = []

    # train
    model.train()  # validation switches to eval mode, so switch back every epoch
    for data_batch in train_data_loader:
        global_step += 1
        batchname = [data_batch[num]['name'] for num in range(len(data_batch))]
        label_part = [value.unsqueeze(0) for name in batchname for value in train_atom_levl_label[name].to_dense()]
        atom_levl_label = torch.cat(label_part).to(device)
        residual_level_label = [value for name in batchname for value in train_residual_level_label[name]]

        batchprocessor = ProcessBatch()
        data = batchprocessor.batchdata(data_batch)
        optimizer.zero_grad()
        outputs = model(data.to(device))
        prediction = outputs["node_embedding"]  # (num_atoms, 6) logits

        prediction_Gauss = smooth_logits(prediction)
        loss = criterion(prediction_Gauss, atom_levl_label)
        loss.backward()
        optimizer.step()

        # calculate atom-level accuracy and residue-level accuracy
        _, predicted = torch.max(prediction_Gauss, 1)
        correct = (predicted == atom_levl_label).sum().item()
        total = atom_levl_label.size(0)
        atom_level_accuracy = correct / total

        processor = MapAtomNode(predicted.cpu(), batchname, train_dismatch_index_pred, train_dismatch_index_type, df_train)
        train_predict_node_label = processor.map_atom_node()
        residual_level_accuracy = node_accuracy(train_predict_node_label, residual_level_label)

        epoch_loss_train.append(loss.item())
        epoch_atom_level_accuracy_train.append(atom_level_accuracy)
        epoch_residual_level_accuracy_train.append(residual_level_accuracy)

    epoch_loss_record_train.append(np.mean(epoch_loss_train))
    epoch_atom_level_accuracy_record_train.append(np.mean(epoch_atom_level_accuracy_train))
    epoch_residual_level_accuracy_record_train.append(np.mean(epoch_residual_level_accuracy_train))

    # val
    model.eval()
    with torch.no_grad():
        epoch_atom_level_accuracy_val = []
        epoch_loss_val = []
        epoch_residual_level_accuracy_val = []

        for data_batch in val_data_loader:
            batchname = [data_batch[num]['name'] for num in range(len(data_batch))]
            label_part = [value.unsqueeze(0) for name in batchname for value in val_atom_levl_label[name].to_dense()]
            atom_levl_label = torch.cat(label_part).to(device)
            residual_level_label = [value for name in batchname for value in val_residual_level_label[name]]

            batchprocessor = ProcessBatch()
            data = batchprocessor.batchdata(data_batch)
            outputs = model(data.to(device))
            prediction = outputs["node_embedding"]

            prediction_Gauss = smooth_logits(prediction)
            loss = criterion(prediction_Gauss, atom_levl_label)

            _, predicted = torch.max(prediction_Gauss, 1)
            correct = (predicted == atom_levl_label).sum().item()
            total = atom_levl_label.size(0)
            atom_level_accuracy = correct / total

            processor = MapAtomNode(predicted.cpu(), batchname, val_dismatch_index_pred, val_dismatch_index_type, df_val)
            val_predict_node_label = processor.map_atom_node()
            residual_level_accuracy = node_accuracy(val_predict_node_label, residual_level_label)

            epoch_loss_val.append(loss.item())
            epoch_atom_level_accuracy_val.append(atom_level_accuracy)
            epoch_residual_level_accuracy_val.append(residual_level_accuracy)

        epoch_loss_record_val.append(np.mean(epoch_loss_val))
        epoch_atom_level_accuracy_record_val.append(np.mean(epoch_atom_level_accuracy_val))
        epoch_residual_level_accuracy_record_val.append(np.mean(epoch_residual_level_accuracy_val))

print("Finished training.")

node_acc_results = np.concatenate([ [np.array(epoch_residual_level_accuracy_record_train)], [np.array(epoch_residual_level_accuracy_record_val)] ])
np.savetxt("/work3/s230027/DL/result/gcpnet/CVsetup1_residual_acc_results.csv", node_acc_results, delimiter=',', comments="", fmt='%s')
loss_results = np.concatenate([[np.array(epoch_loss_record_train)], [np.array(epoch_loss_record_val)] ])
np.savetxt("/work3/s230027/DL/result/gcpnet/CVsetup1_loss_results.csv", loss_results, delimiter=',', comments="", fmt='%s')
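
The training loop smooths the per-atom logits with `GaussianSmoothing(6, 29, 5)` after padding 14 positions on each side. The implementation of that class is not shown in this notebook, so the sketch below is only an assumed stand-in: it reads the three arguments as number of channels, kernel size and sigma, and applies a normalized Gaussian kernel per class with a depthwise 1D convolution. The helper names `gaussian_kernel1d` and `smooth_logits_sketch` are hypothetical. The last lines also check, on a toy tensor, that moving logits from `(N, C)` to `(1, C, N)` with `torch.reshape` is not the same as transposing them, which matters because the convolution runs along the last axis.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel1d(kernel_size: int, sigma: float) -> torch.Tensor:
    """Normalized 1D Gaussian kernel of shape (kernel_size,)."""
    x = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def smooth_logits_sketch(logits: torch.Tensor, kernel_size: int = 29, sigma: float = 5.0) -> torch.Tensor:
    """Smooth (N, C) logits along the atom axis, one class at a time."""
    num_classes = logits.shape[1]
    pad = kernel_size // 2
    x = logits.T.unsqueeze(0)                      # (1, C, N)
    x = F.pad(x, (pad, pad), mode="reflect")       # needs N > pad
    weight = gaussian_kernel1d(kernel_size, sigma).repeat(num_classes, 1, 1)  # (C, 1, K)
    x = F.conv1d(x, weight, groups=num_classes)    # depthwise: classes stay separate
    return x.squeeze(0).T                          # back to (N, C)


# Toy logits: 40 atoms, every row equal to [0, 1, 2, 3, 4, 5].
toy_logits = torch.arange(6, dtype=torch.float32).repeat(40, 1)
smoothed = smooth_logits_sketch(toy_logits)

# Each class is constant along the atom axis, so smoothing leaves it unchanged.
unchanged = torch.allclose(smoothed, toy_logits, atol=1e-5)

# reshape reinterprets memory order, whereas the transpose swaps the axes.
same_layout = torch.equal(toy_logits.reshape(1, 6, 40), toy_logits.T.unsqueeze(0))
print(unchanged, same_layout)  # True False
```

With a kernel size of 29 and sigma of 5, each output position is a weighted average over its 14 neighbours on either side, so isolated single-atom class flips tend to be suppressed.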

#### **Model Test and Result**


1. **Performance on TM, SP_TM, BETA Proteins:**
   - The model achieved a correct-topology score of 0 for TM, SP_TM, and BETA proteins. This points to a significant limitation in the model's ability to predict the topology of these protein types. It may indicate that the model has not learned the features needed to distinguish the correct topologies of these proteins, or that the current model architecture and training setup do not match the complexity of these protein structures.

2. **Implications for Model Improvement:**
   - The model's current performance highlights the existence of substantial room for improvement, particularly for TM, SP_TM, and BETA proteins. Potential areas of improvement may include revisiting the feature extraction process to ensure that crucial information for topology prediction is not lost.
   - Another avenue for improvement could involve incorporating additional data or utilizing data augmentation techniques to enhance the model's ability to generalize from training to unseen data.
   - Consideration of more complex or specialized model architectures may also be beneficial in capturing the intricacies of protein structure prediction.

3. **Others**
   - The test accuracy results from the EGNN model are identical to those from the GCPNet model. We infer that this is because the same Gaussian smoothing is applied to the outputs of both models.






In [3]:
file_name = "DeepTMHMM.3line"
path='/work3/s230027/DL/codebase/'
modelpath='/work3/s230027/DL/result/gcpnet/CVsetup1_model_major_voting_size1_epoch20.pth'
batch_size=1

In [None]:
##TM
processor=TMPTest(TM_test,file_name,path,batch_size,5,setup='setup1',modelpath=modelpath)
processor.printresult()
##SP_TM
processor=TMPTest(SP_TM_test,file_name,path,batch_size,5,setup='setup1',modelpath=modelpath)
processor.printresult()
##BETA
processor=TMPTest(BETA_test,file_name,path,batch_size,3,setup='setup1',modelpath=modelpath)
processor.printresult()

### **Discussion**

The comparison between the trained SchNet models and the baseline model shows that SchNet models are better at determining the residue-wise topology classes. Nonetheless, the overall topological prediction is far from satisfactory. 

One of the main reasons for this could be that the model was trained from scratch without any pre-trained weights. Current state-of-the-art models such as DeepTMHMM and TMbed are based on sophisticated protein language models (pLMs), which are pre-trained on billions of protein sequences.

Furthermore, time and computational constraints prevented us from experimenting with other GNNs that might be better suited to the topology task. In addition, the alpha TM, alpha SP+TM and beta barrel protein types are far less numerous than Globular and SP+Globular, as shown in the section Protein Types and 3D Predictions Availability.

This data imbalance has most likely contributed to the model's failure to predict the topologies of these three protein types. Finally, it is important to emphasize that the 3D structures predicted by Alphafold also contain uncertainties, which inevitably limit the performance of models trained on AlphaDB.

### **Conclusion**

In this study, we have trained state-of-the-art GNNs such as SchNet to perform the topology prediction task. We devised a majority voting method to infer residue-wise topology from the atom-level topology predictions. The performance of the final SchNet models is still far from competitive with protein models such as DeepTMHMM and TMbed. Nonetheless, as more pre-trained GNNs relevant to the topology task become available, performance is likely to improve in the near future.

---

[1] Bruce Alberts, Dennis Bray, Karen Hopkin, Alexander D Johnson, Julian Lewis, Martin Raff, Keith Roberts, and Peter Walter, Essential cell biology, Garland Science, 2015.

[2] Andreas Bernsel, Håkan Viklund, Aron Hennerdal, and Arne Elofsson, “TOPCONS: consensus prediction of membrane protein topology,” Nucleic Acids Research, vol. 37, no. suppl 2, pp. W465–W468, 2009.

[3] László Dobson, István Reményi, and Gábor E Tusnády, “CCTOP: a consensus constrained topology prediction web server,” Nucleic Acids Research, vol. 43, no. W1, pp. W408–W412, 2015.

[4] John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Žídek, Anna Potapenko, et al., “Highly accurate protein structure prediction with AlphaFold,” Nature, vol. 596, no. 7873, pp. 583–589, 2021.

[5] Tue Herlau, Mikkel N Schmidt, and Morten Mørup, “Introduction to machine learning and data mining,” Lecture notes of the course of the same name given at DTU (Technical University of Denmark), 2022.

[6] K. T. Schütt, H. E. Sauceda, P.-J. Kindermans, A. Tkatchenko, and K.-R. Müller, “SchNet – A deep learning architecture for molecules and materials,” Journal of Chemical Physics, vol. 148, 241722, 2018. https://doi.org/10.1063/1.5019779

[7] Anonymous, “Evaluating representation learning on the protein structure universe,” in Submitted to The Twelfth International Conference on Learning Representations, 2023, under review.

[8] J. Airas, X. Ding, and B. Zhang, “Transferable implicit solvation via contrastive learning of graph neural networks,” ACS Central Science, advance online publication, 2023. https://doi.org/10.1021/acscentsci.3c01160
