In [1]:
from typing import Dict, Tuple
import torch
import numpy as np
from utils import format_pdb
from torch.utils.data import Dataset
import json
import networkx as nx
import pandas as pd
from torch_geometric.data import Data
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
embeddings = torch.load("/home/gathenes/paragraph_benchmark/241026_expanded/test_set/embeddings.pt", weights_only=True)
with open("/home/gathenes/paragraph_benchmark/241026_expanded/test_set/dict.json") as f :
    dataset_dict = json.load(f)
test_csv = pd.read_csv("/home/gathenes/paragraph_benchmark/expanded_dataset/test_set.csv")


In [3]:
class AminoAcidGraphEGNN(Dataset):
    def __init__(self, dataset_dict: Dict, residue_embeddings: torch.Tensor, csv: pd.DataFrame, alpha: str = "4.5"):
        self.dataset_dict = dataset_dict
        self.residue_embeddings = residue_embeddings
        self.alpha = alpha
        self.csv = csv

    def __len__(self):
        return self.residue_embeddings.shape[0]

    def __getitem__(self, index):
        # Load pdb code and chains
        pdb = self.dataset_dict[str(index)]["pdb_code"]
        pdb_path = f"/home/gathenes/all_structures/imgt/{pdb}.pdb"
        heavy_chain = self.csv.query("pdb == @pdb")["Hchain"].values[0]
        light_chain = self.csv.query("pdb == @pdb")["Lchain"].values[0]
        chains = [heavy_chain, light_chain]

        # Filter the PDB data
        df_pdb = format_pdb(pdb_path).query("Atom_Name == 'CA' and Chain.isin(@chains)")
        df_pdb["IMGT"] = df_pdb["Res_Num"].str.replace(r'[a-zA-Z]$', '', regex=True).astype(int)
        # Separate heavy and light chains
        heavy_df = df_pdb.query("Chain == @heavy_chain")
        light_df = df_pdb.query("Chain == @light_chain")
        print("-"*100)

        heavy_res_dict = {res_num: idx for idx, res_num in enumerate(heavy_df["Res_Num"])}
        light_res_dict = {res_num: idx for idx, res_num in enumerate(light_df["Res_Num"])}

        cdrs = list(range(25, 40 + 1)) + list(range(54, 67 + 1)) + list(range(103, 119 + 1))

        heavy_cdr = heavy_df.query("IMGT in @cdrs")
        print(heavy_cdr)
        light_cdr = light_df.query("IMGT in @cdrs")

        # Prepare labels and embeddings
        labels_heavy = torch.tensor(self.dataset_dict[str(index)][f"H_id labels {self.alpha}"], dtype=torch.float32)
        labels_light = torch.tensor(self.dataset_dict[str(index)][f"L_id labels {self.alpha}"], dtype=torch.float32)
        embedding = self.residue_embeddings[index]

        # Collect features (feats) and coordinates (coors)
        node_features = []
        node_coords = []
        node_labels = []
        # Heavy chain features and coordinates
        for i, res in enumerate(heavy_cdr["Res_Num"].tolist()):
            res_index = heavy_res_dict[res]
            node_features.append(embedding[1 + res_index][2048:])  # Assuming embeddings are (num_nodes, 2048)
            node_coords.append(heavy_cdr.iloc[i][["x", "y", "z"]].astype(float).values)
            node_labels.append(labels_heavy[res_index])

        # Light chain features and coordinates
        for i, res in enumerate(light_cdr["Res_Num"].tolist()):
            res_index = light_res_dict[res]
            node_features.append(embedding[len(labels_heavy) + 2 + res_index][2048:])
            node_coords.append(light_cdr.iloc[i][["x", "y", "z"]].astype(float).values)
            node_labels.append(labels_light[res_index])

        # Convert features and coordinates to tensors and reshape as required
        feats = torch.stack(node_features)  # Shape: (1, num_samples, num_feats)
        labels = torch.stack(node_labels)
        node_coords=np.array(node_coords)
        coors = torch.tensor(node_coords, dtype=torch.float32)  # Shape: (1, num_samples, 3)

        # Calculate pairwise distances for edges
        antibody_coords = np.array(node_coords)
        distances = np.linalg.norm(antibody_coords[:, np.newaxis] - antibody_coords, axis=2)

        # Generate the edges tensor as specified
        edges = torch.tensor(distances, dtype=torch.float32).unsqueeze(-1)  # Shape: (1, num_samples, num_samples, 1)

        return (feats, coors, edges),labels


In [16]:
dataset = AminoAcidGraphEGNN(dataset_dict=dataset_dict, residue_embeddings=embeddings, csv=test_csv)


In [17]:
from graph_model import EGNN_Model
from torch.utils.data import DataLoader
model = EGNN_Model(num_feats = 480,graph_hidden_layer_output_dims = [480]*6,linear_hidden_layer_output_dims = [10]*2)
dl=DataLoader(dataset, shuffle=True, batch_size=1)


In [20]:
import torch.nn as nn
criterion =nn.BCELoss()
for i,((feats, coors, edges),labels) in enumerate(dl):
    if i==0:
        break


----------------------------------------------------------------------------------------------------
     Res_Num        x        y        z
852        1  -25.073  -10.125   28.509
861        2  -21.689   -8.487   29.109
868        3  -19.617   -6.048   27.062
877        4  -15.873   -5.450   27.060
885        5  -14.336   -2.233   25.771
...      ...      ...      ...      ...
2514     224  -20.594   15.715   -0.402
2523     225  -19.889   14.090   -3.768
2531     226  -20.921   16.349   -6.664
2538     227  -20.948   15.780  -10.477
2545     228  -18.058   17.025  -12.625

[223 rows x 4 columns]
----------------------------------------------------------------------------------------------------
tensor([[ 0.0921,  0.1575,  0.1343,  ...,  0.4358,  0.1856, -0.2667],
        [-0.1375,  0.2959, -0.0080,  ...,  0.2209, -0.5704, -0.3282],
        [-0.0821,  0.0025,  0.1603,  ...,  0.0450, -0.3970,  0.2823],
        ...,
        [ 0.0764, -0.0536,  0.0724,  ...,  0.0000,  0.0000,  0.0000],
 

In [63]:
import torch.optim as optim
from torch_geometric.data import DataLoader
from tqdm import tqdm
torch.nn as nn
# Assuming 'dataset' is your graph dataset
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion =nn.BCELoss()

def train(model, dataloader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in tqdm(dataloader):
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            out = torch.sigmoid(out.squeeze(-1))
            loss = criterion(out, data.y)  # Assuming `data.y` are node labels
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train(model, dataloader, optimizer, criterion)


SyntaxError: invalid syntax (1787613487.py, line 4)

# ideas put aside

In [None]:
from torch_geometric.data import Data
import networkx as nx

class AminoAcidGraphSimple(Dataset):
    def __init__(
        self,
        dataset_dict: Dict,
        residue_embeddings: torch.Tensor,
        csv: pd.DataFrame,
        alpha: str = "4.5",
    ):
        self.dataset_dict = dataset_dict
        self.residue_embeddings = residue_embeddings
        self.alpha = alpha
        self.csv = csv

    def __len__(self):
        return self.residue_embeddings.shape[0]

    def __getitem__(self, index):
        pdb = self.dataset_dict[str(index)]["pdb_code"]
        pdb_path = f"/home/gathenes/all_structures/imgt/{pdb}.pdb"
        heavy_chain = self.csv.query("pdb == @pdb")["Hchain"].values[0]
        light_chain = self.csv.query("pdb == @pdb")["Lchain"].values[0]
        chains = [heavy_chain, light_chain]

        # Filter the PDB data
        df_pdb = format_pdb(pdb_path).query("Atom_Name == 'CA' and Chain.isin(@chains)")
        # Separate heavy and light chains
        heavy_df = df_pdb.query("Chain == @heavy_chain")
        light_df = df_pdb.query("Chain == @light_chain")
        heavy_res_dict = {
            res_num: idx for idx, res_num in enumerate(heavy_df["Res_Num"])
        }
        light_res_dict = {
            res_num: idx for idx, res_num in enumerate(light_df["Res_Num"])
        }

        df_pdb = pd.concat([heavy_df, light_df])
        df_pdb["IMGT"] = (
            df_pdb["Res_Num"].str.replace(r"[a-zA-Z]$", "", regex=True).astype(int)
        )
        cdrs = (
            list(range(25, 40 + 1))
            + list(range(54, 67 + 1))
            + list(range(103, 119 + 1))
        )
        df_pdb = df_pdb.query("IMGT in @cdrs")

        # Prepare labels and embeddings
        labels_heavy = torch.tensor(
            self.dataset_dict[str(index)][f"H_id labels {self.alpha}"],
            dtype=torch.float32,
        )
        labels_light = torch.tensor(
            self.dataset_dict[str(index)][f"L_id labels {self.alpha}"],
            dtype=torch.float32,
        )
        embedding = self.residue_embeddings[index]

        # Collect features and labels for the graph nodes
        node_features = []
        node_labels = []

        # Heavy chain nodes
        for i, res in enumerate(heavy_df["Res_Num"].tolist()):
            res_index = heavy_res_dict[res]
            node_features.append(
                embedding[1 + res_index][2048:]
            )  # Assuming embeddings are (num_nodes, 2048)
            node_labels.append(labels_heavy[res_index])

        # Light chain nodes
        for i, res in enumerate(light_df["Res_Num"].tolist()):
            res_index = light_res_dict[res]
            node_features.append(embedding[len(labels_heavy) + 2 + res_index][2048:])
            node_labels.append(labels_light[res_index])

        # Convert features and labels to tensors
        x = torch.stack(node_features)  # Shape: (num_nodes, 1024)
        y = torch.stack(node_labels)  # Shape: (num_nodes, 3)

        # Create edges based on 3D distances
        antibody_coords = df_pdb[["x", "y", "z"]].astype(float).values
        distances = np.linalg.norm(
            antibody_coords[:, np.newaxis] - antibody_coords, axis=2
        )
        antibody_indices, neighbor_indices = np.where(
            (distances < 10) & (distances > 0)
        )
        edges = [(i, j) for i, j in zip(antibody_indices, neighbor_indices)]

        # Prepare edge index for PyTorch Geometric
        edge_index = (
            torch.tensor(edges, dtype=torch.long).t().contiguous()
        )  # Shape: (2, num_edges)

        # Return graph data and labels
        torch_graph = Data(x=x, edge_index=edge_index, y=y)
        return torch_graph
