In [2]:
from typing import Dict
import torch
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from torch_geometric.data import Data
from utils import read_pdb_to_dataframe


  from .autonotebook import tqdm as notebook_tqdm


# ideas put aside

In [None]:


class AminoAcidGraphSimple(Dataset):
    """Graph dataset of antibody CDR-region residues.

    Item ``index`` is built from ``dataset_dict[str(index)]``. The matching PDB
    structure is read and the CA atoms of its heavy and light chains are
    selected (chain ids are looked up in ``csv``). These are restricted to
    residues whose number lies in ``CDR_RESIDUE_NUMBERS``, and each selected
    residue becomes one graph node:

    * node features are slices of ``residue_embeddings[index]``;
    * node labels come from ``dataset_dict``;
    * an edge joins every ordered pair of distinct nodes whose CA atoms are
      closer than ``distance_cutoff``.

    Node order is all selected heavy-chain residues (structure row order)
    followed by all selected light-chain residues. The same order is used for
    ``x``, ``y`` and the coordinates that define ``edge_index``, so edge
    indices always refer to rows of ``x``/``y``.
    """

    # Residue-number windows used to select CDR-region residues. The default
    # structure directory is IMGT-numbered, so these are presumably IMGT
    # positions -- TODO confirm against the numbering scheme in use.
    CDR_RESIDUE_NUMBERS = (
        list(range(25, 40 + 1))
        + list(range(54, 67 + 1))
        + list(range(103, 119 + 1))
    )

    def __init__(
        self,
        dataset_dict: Dict,
        residue_embeddings: torch.Tensor,
        csv: pd.DataFrame,
        alpha: str = "4.5",
        *,
        structures_dir: str = "/home/gathenes/all_structures/imgt",
        distance_cutoff: float = 10.0,
        embedding_start: int = 2048,
    ):
        """Store the inputs; graphs are assembled lazily in ``__getitem__``.

        Args:
            dataset_dict: Maps ``str(index)`` to a record with a
                ``"pdb_code"`` entry and per-chain label lists under
                ``"H_id labels {alpha}"`` and ``"L_id labels {alpha}"``.
            residue_embeddings: Indexed by dataset index. Each
                ``residue_embeddings[index]`` is indexed as (row, feature);
                presumably one row per token -- TODO confirm layout.
            csv: Table with ``pdb``, ``Hchain`` and ``Lchain`` columns giving
                the heavy/light chain ids of each structure.
            alpha: Suffix selecting which label entries of ``dataset_dict``
                are used.
            structures_dir: Directory holding ``{pdb_code}.pdb`` files.
            distance_cutoff: CA-CA distance (in the PDB coordinate units)
                strictly below which two nodes are connected.
            embedding_start: First feature column kept from each embedding
                row; columns before it are dropped.
        """
        self.dataset_dict = dataset_dict
        self.residue_embeddings = residue_embeddings
        self.alpha = alpha
        self.csv = csv
        self.structures_dir = structures_dir
        self.distance_cutoff = distance_cutoff
        self.embedding_start = embedding_start

    def __len__(self):
        """Return the number of items (first dimension of the embeddings)."""
        return self.residue_embeddings.shape[0]

    def __getitem__(self, index):
        """Build the CDR residue graph for dataset item ``index``.

        Returns:
            A ``torch_geometric.data.Data`` with fields:

            * ``x``: one row of embedding features per node;
            * ``y``: the per-node labels, stacked along dim 0;
            * ``edge_index``: a long tensor of shape ``(2, num_edges)``.

        Raises:
            KeyError: if ``str(index)`` is not a key of ``dataset_dict``.
            IndexError: if the structure's ``pdb_code`` has no row in ``csv``.
            ValueError: if neither chain has a CA atom inside a CDR window.
        """
        entry = self.dataset_dict[str(index)]
        pdb = entry["pdb_code"]
        pdb_path = f"{self.structures_dir}/{pdb}.pdb"

        # Look the structure up once and read both chain ids from that row.
        chain_row = self.csv.query("pdb == @pdb").iloc[0]
        heavy_chain = chain_row["Hchain"]
        light_chain = chain_row["Lchain"]

        atoms = read_pdb_to_dataframe(pdb_path)
        ca_atoms = atoms[atoms["atom_name"] == "CA"]
        heavy_df = ca_atoms[ca_atoms["chain_id"] == heavy_chain]
        light_df = ca_atoms[ca_atoms["chain_id"] == light_chain]

        # 0-based positions (within each chain's CA rows) of the CDR residues.
        # Positions are used rather than a residue_number -> index map, so
        # that residues sharing a number (e.g. insertion codes) each keep
        # their own embedding row and label.
        heavy_pos = self._cdr_positions(heavy_df)
        light_pos = self._cdr_positions(light_df)
        if len(heavy_pos) + len(light_pos) == 0:
            raise ValueError(
                f"No CA atoms in the CDR windows for {pdb} "
                f"(chains {heavy_chain!r}, {light_chain!r})"
            )

        labels_heavy = torch.tensor(
            entry[f"H_id labels {self.alpha}"],
            dtype=torch.float32,
        )
        labels_light = torch.tensor(
            entry[f"L_id labels {self.alpha}"],
            dtype=torch.float32,
        )
        embedding = self.residue_embeddings[index]

        heavy_idx = torch.as_tensor(heavy_pos, dtype=torch.long)
        light_idx = torch.as_tensor(light_pos, dtype=torch.long)

        # Heavy residue i is read from embedding row 1 + i, and light residue
        # j from row len(labels_heavy) + 2 + j. This is presumably one leading
        # special token plus one separator token between the chains (TODO
        # confirm with the embedding pipeline). It also assumes
        # len(labels_heavy) equals the number of embedded heavy residues.
        rows = torch.cat([1 + heavy_idx, len(labels_heavy) + 2 + light_idx])
        x = embedding[rows, self.embedding_start:]
        y = torch.cat([labels_heavy[heavy_idx], labels_light[light_idx]])

        # Coordinates in exactly the node order used for x and y.
        nodes_df = pd.concat([heavy_df.iloc[heavy_pos], light_df.iloc[light_pos]])
        coords = nodes_df[["x_coord", "y_coord", "z_coord"]].astype(float).to_numpy()
        edge_index = self._radius_edges(coords, self.distance_cutoff)

        return Data(x=x, edge_index=edge_index, y=y)

    @classmethod
    def _cdr_positions(cls, chain_df: pd.DataFrame) -> np.ndarray:
        """Return 0-based row positions of ``chain_df`` whose residue number is in a CDR window."""
        in_cdr = chain_df["residue_number"].isin(cls.CDR_RESIDUE_NUMBERS)
        return np.flatnonzero(in_cdr.to_numpy())

    @staticmethod
    def _radius_edges(coords: np.ndarray, cutoff: float) -> torch.Tensor:
        """Return a (2, num_edges) long tensor joining distinct points closer than ``cutoff``."""
        distances = np.linalg.norm(
            coords[:, np.newaxis, :] - coords[np.newaxis, :, :], axis=-1
        )
        close = distances < cutoff
        # Exclude self-pairs by identity rather than by testing distance > 0.
        np.fill_diagonal(close, False)
        src, dst = np.nonzero(close)
        # np.stack keeps the (2, 0) shape even when no pair is within cutoff.
        return torch.as_tensor(np.stack([src, dst]), dtype=torch.long)
