### 버전 기록
- Model Baseline 코드 (2024.12.30)
- Drug Graph, Drug Graph Embedding Block 추가 (2025.01.01)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from itertools import product
import numpy as np

#
import random
from torch_geometric.data import Data

In [3]:
# Problem dimensions. The sample space is every (cell line, drug) pair:
# 1280 cell lines x 83 drugs = 106240 samples.
num_pathways, num_genes, num_drugs, num_substructures = 245, 231, 83, 201

In [23]:
# num_cell_lines = 10  # Reduced number for simplicity
# num_pathways = 5
# num_genes = 15
# num_drugs = 7  # Reduced number for simplicity
# num_substructures = 10

In [5]:
# Use the first GPU when CUDA is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"CUDA is available: {device.type == 'cuda'}")

CUDA is available: True


## 데이터

### (1) Gene Embedding Dictionary 
Cell Line → gene embeddings (num_pathways, num_genes)

In [6]:
# Load the per-cell-line gene embeddings (keys such as 'DATA.683665') as float32 tensors.
# The `with` block closes the NpzFile even if the conversion below raises.
with np.load('../0. input/gene_embeddings.npz') as saved_embeddings:
    gene_embeddings = {
        key: torch.tensor(saved_embeddings[key], dtype=torch.float32)
        for key in saved_embeddings.keys()
    }

# print(gene_embeddings['DATA.683665'].shape)
len(gene_embeddings)

# gene_embeddings = {
#     f"DATA.{i}": torch.rand(num_pathways, num_genes, dtype=torch.float32)
#     for i in range(num_cell_lines)
# }

# gene_embeddings['DATA.1'].shape


1280

### (2) Drug Embedding Dictionary
CID (str) → drug embeddings (num_substructures)

In [7]:
# Load the per-drug substructure vectors (keyed by CID string) as float32 tensors.
# The `with` block closes the NpzFile even if the conversion below raises.
with np.load('../0. input/0_drug_embeddings.npz') as saved_embeddings:
    drug_embeddings = {
        key: torch.tensor(saved_embeddings[key], dtype=torch.float32)
        for key in saved_embeddings.keys()
    }

len(drug_embeddings)

# drug_embeddings = {
#     f"{i}": torch.rand(num_substructures, dtype=torch.float32)
#     for i in range(num_drugs)
# }

# drug_embeddings['1'].shape


83

### (3) Drug Graph Dictionary
CID → drug graph (PyTorch Geometric `Data` object)

In [8]:
# The file holds pickled PyTorch Geometric `Data` objects, which the weights-only
# unpickler rejects. Since PyTorch 2.6 torch.load defaults to weights_only=True, so
# the flag is passed explicitly (this also silences the FutureWarning on older
# versions). This runs the full pickle loader: only use it on trusted files.
drug_graph_dict = torch.load('../0. input/0_drug_graph_dict.pt', weights_only=False)

print(len(drug_graph_dict))

# # Create dummy data for drug_graphs
# drug_graph_dict = {
#     f"{i}": Data(
#         x=torch.rand(num_substructures, 5),  # Node features
#         edge_index=torch.randint(0, num_substructures, (2, 12)),  # Random edges
#         global_ids=list(range(num_substructures))
#     )
#     for i in range(num_drugs)
# }

# print(drug_graph_dict['1'])


83


  drug_graph_dict = torch.load('../0. input/0_drug_graph_dict.pt')


In [9]:
# weights_only=False keeps the loading behavior this notebook relies on (PyTorch >= 2.6
# defaults it to True) and silences the FutureWarning. It runs the full pickle loader,
# so only use it on trusted files.
data = torch.load('data.pt', weights_only=False)
# gene_embeddings = data['gene_embeddings']  # Shape: [1280, 245, 231]
gene_adjacencies = data['gene_adjacencies']  # Shape: [245, 231, 231]
# substructure_embeddings = data['substructure_embeddings']  # Shape: [69, 245, 193] (drug_num x num_pathways x num_substructures)

# gene_adjacencies = torch.rand(num_pathways, num_genes, num_genes)
# substructure_embeddings = {
#     drug: torch.rand(10) for drug in drug_embeddings.keys()
# }

# Every (cell line, drug) combination becomes one sample.
cell_lines = list(gene_embeddings.keys())
drugs = list(drug_embeddings.keys())
sample_indices = list(product(cell_lines, drugs))

# Name -> integer index lookups for cell lines and drugs.
# NOTE(review): drug_mapping is keyed by drug_graph_dict, while `drugs` comes from
# drug_embeddings. Both hold 83 entries here; confirm their key orders agree before
# using these indices together.
cell_line_mapping = {key: idx for idx, key in enumerate(gene_embeddings.keys())}
drug_mapping = {cid: idx for idx, cid in enumerate(drug_graph_dict.keys())}

  data = torch.load('data.pt')


### (4) Labels

In [14]:
# Labels are keyed by (cell_line_id, drug_id). weights_only=False keeps the
# pre-PyTorch-2.6 loading behavior; only use it on trusted files.
labels_dict = torch.load('../0. input/0_drug_label_dict.pt', weights_only=False)

print(len(labels_dict))

106240


  labels_dict = torch.load('../0. input/0_drug_label_dict.pt')


## Dataset

In [15]:
class DrugResponseDataset(Dataset):
    """Map-style dataset that yields one (cell line, drug) pair per index.

    Each item bundles the cell line's gene embedding, the shared gene adjacency
    tensor, the drug's substructure vector tiled once per pathway, the drug's
    graph, and the response label for the pair.
    """

    def __init__(self, gene_embeddings, gene_adjacencies, substructure_embeddings, drug_graphs, labels, sample_indices):
        """
        Args:
            gene_embeddings (dict): {cell_line_id: Tensor [num_pathways, num_genes]}.
            gene_adjacencies (Tensor): Adjacency tensor shared by every sample and
                returned unchanged with each item.
            substructure_embeddings (dict): {drug_id: Tensor [num_substructures]}.
            drug_graphs (dict): {drug_id: PyTorch Geometric Data object}.
            labels (dict): {(cell_line_id, drug_id): label}.
            sample_indices (list): [(cell_line_id, drug_id)] pairs to iterate over.
        """
        self.gene_embeddings = gene_embeddings  # {cell_line_id: [num_pathways, num_genes]}
        self.gene_adjacencies = gene_adjacencies
        self.drug_graphs = drug_graphs  # {drug_id: Data}
        self.substructure_embeddings = substructure_embeddings  # {drug_id: [num_substructures]}
        self.labels = labels  # {(cell_line_id, drug_id): label}
        self.sample_indices = sample_indices  # [(cell_line_id, drug_id)]

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        cell_line_id, drug_id = self.sample_indices[idx]

        # Gene embeddings for the cell line: [num_pathways, num_genes]
        gene_embedding = self.gene_embeddings[cell_line_id]

        # Tile the drug's substructure vector once per pathway so its pathway axis
        # lines up with gene_embedding's. The count comes from the data instead of
        # a hard-coded 245.
        n_pathways = gene_embedding.size(0)
        substructure_embedding = self.substructure_embeddings[drug_id].repeat(n_pathways, 1)  # [num_pathways, num_substructures]
        drug_graph = self.drug_graphs[drug_id]

        # Label for the (cell line, drug) pair
        label = self.labels[cell_line_id, drug_id]

        return {
            'gene_embedding': gene_embedding,                  # [num_pathways, num_genes]
            'gene_adj': self.gene_adjacencies,                 # shared by all samples
            'substructure_embedding': substructure_embedding,  # [num_pathways, num_substructures]
            'drug_graph': drug_graph,                          # PyTorch Geometric Data object
            'label': label,
        }

In [16]:
from torch_geometric.data import Batch

def collate_fn(batch):
    """Merge a list of dataset items into one mini-batch dict.

    Dense per-sample tensors are stacked along a new leading batch axis. The
    per-sample drug graphs are merged by ``Batch.from_data_list`` into a single
    PyTorch Geometric ``Batch``. The labels become a float32 tensor.
    """
    def gather(field):
        return [item[field] for item in batch]

    return {
        'gene_embeddings': torch.stack(gather('gene_embedding')),                 # [B, num_pathways, num_genes]
        'gene_adjacencies': torch.stack(gather('gene_adj')),                      # [B, *gene_adj.shape]
        'substructure_embeddings': torch.stack(gather('substructure_embedding')), # [B, num_pathways, num_substructures]
        'drug_graphs': Batch.from_data_list(gather('drug_graph')),                # PyG Batch (not a dense tensor)
        'labels': torch.tensor(gather('label'), dtype=torch.float32),             # [B]
    }

In [17]:
# Dataset 초기화
dataset = DrugResponseDataset(
    gene_embeddings=gene_embeddings,
    gene_adjacencies=gene_adjacencies,
    substructure_embeddings=drug_embeddings,
    drug_graphs=drug_graph_dict,
    labels=labels_dict,
    sample_indices=sample_indices,
)

# DataLoader 초기화
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [18]:
# 데이터 확인
for batch in data_loader:
    print(f"Batch gene embeddings: {batch['gene_embeddings'].shape}")  # [8, 245, 231]
    print(f"Batch gene adjacencies: {batch['gene_adjacencies'].shape}")  # [8, 245, 231, 231]
    print(f"Batch substructure embeddings: {batch['substructure_embeddings'].shape}")  # [8, 245, 193]
    print(f"Batch drug graphs: {batch['drug_graphs']}")  # [8, 245, 193, 193]
    print(f"Batch labels: {batch['labels'].shape}")  # [8]
    break

Batch gene embeddings: torch.Size([8, 245, 231])
Batch gene adjacencies: torch.Size([8, 245, 231, 231])
Batch substructure embeddings: torch.Size([8, 245, 201])
Batch drug graphs: DataBatch(x=[56, 128], edge_index=[2, 51], drug=[8], node_names=[8], global_ids=[8], batch=[56], ptr=[9])
Batch labels: torch.Size([8])


## (0) Embedding Layer
### - GeneEmbeddingLayer : FloatTensor -> Linear
### - SubstructureEmbeddingLayer : IntTensor -> nn.Embedding

In [34]:
# Width of the learned gene / substructure embeddings and of the GCN hidden layers.
GENE_EMBEDDING_DIM = SUBSTRUCTURE_EMBEDDING_DIM = HIDDEN_DIM = 128
FINAL_DIM = 64  # width of the first fully connected layer of the prediction head
OUTPUT_DIM = 1  # one sigmoid output per (cell line, drug) pair

In [35]:
class GeneEmbeddingLayer(nn.Module):
    """Lift every scalar gene value to a learned vector with a shared Linear(1, D).

    In)  [..., NUM_GENES]                 e.g. [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES]
    Out) [..., NUM_GENES, embedding_dim]  e.g. [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES, embedding_dim]
    """

    def __init__(self, num_genes, embedding_dim=GENE_EMBEDDING_DIM):
        super(GeneEmbeddingLayer, self).__init__()
        self.linear = nn.Linear(1, embedding_dim)
        self.num_genes = num_genes  # kept for reference; forward infers shapes from its input

    def forward(self, gene_values):
        # Add a trailing feature axis of size 1; nn.Linear then maps each scalar
        # independently. The output keeps the input's leading shape. Previously the
        # result was reshaped with the module-level num_pathways / num_genes /
        # GENE_EMBEDDING_DIM globals, which broke for any other embedding_dim or shape.
        return self.linear(gene_values.unsqueeze(-1))

class SubstructureEmbeddingLayer(nn.Module):
    """Look up a learned vector for every integer index with nn.Embedding.

    In)  integer tensor of indices in [0, num_substructures), any shape
         (the model passes [BATCH_SIZE, NUM_PATHWAYS, NUM_SUBSTRUCTURES])
    Out) the input shape plus a trailing embedding_dim axis
    """

    def __init__(self, num_substructures, embedding_dim=SUBSTRUCTURE_EMBEDDING_DIM):
        super().__init__()
        # Lookup table of shape [num_substructures, embedding_dim].
        self.embedding = nn.Embedding(num_embeddings=num_substructures, embedding_dim=embedding_dim)

    def forward(self, substructure_indices):
        return self.embedding(substructure_indices)

## (1) CrossAttention

In [36]:
class CrossAttention(nn.Module):
    """Single-head cross-attention: queries attend over a separate key/value sequence.

    Keys and values are projected from ``key_embeddings`` into ``query_dim``, so the
    output has the same feature width as the queries.

    NOTE(review): scores are not divided by sqrt(query_dim) as in standard scaled
    dot-product attention; with wide embeddings the softmax may saturate.
    """

    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, query_dim)
        self.key_layer = nn.Linear(key_dim, query_dim)
        self.value_layer = nn.Linear(key_dim, query_dim)

    def forward(self, query_embeddings, key_embeddings):
        """
        Args:
            query_embeddings: [..., Lq, query_dim]
            key_embeddings:   [..., Lk, key_dim]
        Returns:
            [..., Lq, query_dim]
        """
        q = self.query_layer(query_embeddings)
        k = self.key_layer(key_embeddings)
        v = self.value_layer(key_embeddings)

        weights = torch.softmax(q @ k.transpose(-2, -1), dim=-1)  # [..., Lq, Lk]
        return weights @ v



## (2) Graph Embedding

In [37]:
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphEmbedding(nn.Module):
    """Two-layer GCN over a batch of dense-adjacency graphs, mean-pooled per graph.

    forward(embeddings, adj_matrices):
        embeddings:   [BATCH, NUM_NODES, input_dim] node features, one graph per row
        adj_matrices: [BATCH, NUM_NODES, NUM_NODES]; each non-zero entry (i, j)
                      becomes an edge_index column (i, j)
        returns       [BATCH, hidden_dim]
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, embeddings, adj_matrices):
        # One PyG graph per sample, then merge them into a disjoint-union Batch.
        graphs = [
            Data(x=embeddings[b], edge_index=adj_matrices[b].nonzero(as_tuple=False).t())
            for b in range(embeddings.size(0))
        ]
        merged = Batch.from_data_list(graphs)

        hidden = torch.relu(self.conv1(merged.x, merged.edge_index))
        hidden = self.conv2(hidden, merged.edge_index)

        # Graph-level embedding: average the node states of each graph.
        return global_mean_pool(hidden, merged.batch)
    

class DrugGraphEmbedding(nn.Module):
    """Two-layer GCN over batched drug graphs whose node features come from
    learned substructure embeddings.

    Node ``k`` of drug graph ``b`` carries ``global_ids[k]``, a substructure index.
    Its feature vector is row ``global_ids[k]`` of sample ``b``'s pathway-averaged
    substructure embedding. Ids >= NUM_SUBSTRUCTURES get a zero vector.
    """

    def __init__(self, input_dim, hidden_dim):
        super(DrugGraphEmbedding, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, drug_graphs, drug_graph_embedding):
        """
        Args:
            drug_graphs (Batch): Batched PyTorch Geometric graphs, one per sample. Each
                graph exposes ``global_ids`` with one entry per node.
            drug_graph_embedding (Tensor): [BATCH_SIZE, NUM_PATHWAYS, NUM_SUBSTRUCTURES, EMBEDDING_DIM]

        Returns:
            Tensor: [BATCH_SIZE, HIDDEN_DIM] mean-pooled graph embeddings.

        Raises:
            ValueError: if a graph's ``global_ids`` length differs from its node count.

        Note: ``drug_graphs.x`` is overwritten in place with the gathered node features.
        """
        # Collapse the pathway axis: [BATCH_SIZE, NUM_SUBSTRUCTURES, EMBEDDING_DIM]
        substructure_table = torch.mean(drug_graph_embedding, dim=1)
        batch_size, num_substructures, embedding_dim = substructure_table.shape
        device = substructure_table.device

        # Node count of every graph, computed once instead of a torch.where scan per sample.
        node_counts = torch.bincount(drug_graphs.batch, minlength=batch_size).tolist()

        all_node_features = []
        for batch_idx in range(batch_size):
            raw_ids = drug_graphs[batch_idx].global_ids
            num_nodes = node_counts[batch_idx]
            # Explicit check (an `assert` would be stripped under `python -O`).
            if len(raw_ids) != num_nodes:
                raise ValueError("Mismatch between global IDs and node indices length")

            global_ids = torch.as_tensor(raw_ids, dtype=torch.long, device=device)
            node_features = torch.zeros((num_nodes, embedding_dim), device=device)
            # Vectorized gather replacing the per-node Python loop. Out-of-range ids
            # keep their zero row, exactly as before.
            in_range = global_ids < num_substructures
            node_features[in_range] = substructure_table[batch_idx, global_ids[in_range]].to(node_features.dtype)
            all_node_features.append(node_features)

        # Graphs are laid out contiguously in batch order, so concatenation matches node order.
        drug_graphs.x = torch.cat(all_node_features, dim=0)  # [TOTAL_NUM_NODES, EMBEDDING_DIM]

        x = F.relu(self.conv1(drug_graphs.x, drug_graphs.edge_index))
        x = self.conv2(x, drug_graphs.edge_index)

        return global_mean_pool(x, drug_graphs.batch)  # [BATCH_SIZE, HIDDEN_DIM]


## (3) DrugResponseModel

In [38]:
class DrugResponseModel(nn.Module):
    """Predict a response probability for each (cell line, drug) pair.

    Gene path:
      1. Gene values are lifted to vectors.
      2. They are pooled per pathway by a GCN over that pathway's gene adjacency.
      3. The per-pathway results are averaged over pathways.

    Drug path: substructure values are looked up in an embedding table and
    pooled by a GCN over the drug graph.

    The two pooled vectors are concatenated and passed through a two-layer head
    ending in a sigmoid.
    """

    def __init__(self, num_genes, num_substructures, hidden_dim, final_dim):
        super(DrugResponseModel, self).__init__()
        # Embedding layers
        self.gene_embedding_layer = GeneEmbeddingLayer(num_genes, GENE_EMBEDDING_DIM)
        self.substructure_embedding_layer = SubstructureEmbeddingLayer(num_substructures, SUBSTRUCTURE_EMBEDDING_DIM)

        # Cross-attention layers. They are currently NOT applied in forward (the
        # cross-attention path is disabled) but stay registered so checkpoints keep
        # the same state_dict keys.
        self.Gene2Sub_cross_attention = CrossAttention(query_dim=GENE_EMBEDDING_DIM, key_dim=SUBSTRUCTURE_EMBEDDING_DIM)
        self.Sub2Gene_cross_attention = CrossAttention(query_dim=SUBSTRUCTURE_EMBEDDING_DIM, key_dim=GENE_EMBEDDING_DIM)

        # Graph embedding layers
        self.pathway_graph = GraphEmbedding(GENE_EMBEDDING_DIM, hidden_dim)
        self.drug_graph = DrugGraphEmbedding(SUBSTRUCTURE_EMBEDDING_DIM, hidden_dim)

        # Prediction head
        self.fc1 = nn.Linear(2 * hidden_dim, final_dim)
        self.fc2 = nn.Linear(final_dim, OUTPUT_DIM)
        self.sigmoid = nn.Sigmoid()

    def forward(self, gene_embeddings, gene_adjacencies, substructure_embeddings, drug_graphs):
        """
        Args:
            gene_embeddings: [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES] gene values.
            gene_adjacencies: [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES, NUM_GENES].
            substructure_embeddings: [BATCH_SIZE, NUM_PATHWAYS, NUM_SUBSTRUCTURES];
                cast to integer indices into the substructure embedding table.
            drug_graphs: batched PyTorch Geometric drug graphs, one per sample.

        Returns:
            [BATCH_SIZE, OUTPUT_DIM] sigmoid outputs.
        """
        # The substructure values are used directly as embedding indices (the cast
        # truncates toward zero).
        # NOTE(review): assumes they are non-negative integers < num_substructures
        # (e.g. 0/1 fingerprint bits) — confirm. long is accepted by nn.Embedding on
        # every PyTorch version.
        substructure_indices = substructure_embeddings.long()

        gene_emb = self.gene_embedding_layer(gene_embeddings)              # [B, P, G, GENE_EMBEDDING_DIM]
        sub_emb = self.substructure_embedding_layer(substructure_indices)  # [B, P, S, SUBSTRUCTURE_EMBEDDING_DIM]

        # Pathway graphs: each (sample, pathway) pair is an independent graph, so fold
        # the pathway axis into the batch axis and run the GCN once instead of once
        # per pathway. The result is reshaped back to [B, P, HIDDEN_DIM].
        batch_size, n_pathways, n_genes, gene_dim = gene_emb.shape
        flat_gene_emb = gene_emb.reshape(batch_size * n_pathways, n_genes, gene_dim)
        flat_adj = gene_adjacencies[:, :n_pathways].reshape(
            batch_size * n_pathways, *gene_adjacencies.shape[-2:]
        )
        pathway_graph_embeddings = self.pathway_graph(flat_gene_emb, flat_adj).reshape(
            batch_size, n_pathways, -1
        )                                                                  # [B, P, HIDDEN_DIM]
        final_pathway_embedding = torch.mean(pathway_graph_embeddings, dim=1)  # [B, HIDDEN_DIM]

        # Drug graph: the per-pathway substructure embeddings are sub_emb itself
        # (restricted to the gene pathway count).
        final_drug_embedding = self.drug_graph(drug_graphs, sub_emb[:, :n_pathways])  # [B, HIDDEN_DIM]

        combined_embedding = torch.cat((final_pathway_embedding, final_drug_embedding), dim=-1)  # [B, 2 * HIDDEN_DIM]

        # ReLU between the two Linear layers: without a nonlinearity fc2(fc1(x))
        # collapses to a single affine map and fc1 adds no capacity.
        x = F.relu(self.fc1(combined_embedding))  # [B, FINAL_DIM]
        x = self.fc2(x)                           # [B, OUTPUT_DIM]

        return self.sigmoid(x)
    

In [None]:
# Build the model directly on the selected device. The model already ends in a
# sigmoid, so BCELoss receives its outputs as-is.
model = DrugResponseModel(num_genes, num_substructures, HIDDEN_DIM, FINAL_DIM).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Model device: {next(model.parameters()).device}")
print(f"CUDA is available: {torch.cuda.is_available()}")

In [None]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in data_loader:
        # Batch tensors use batch_* names so they do not overwrite the notebook's
        # global gene_embeddings dict / gene_adjacencies tensor built earlier. Those
        # globals are needed again if the dataset cells are re-run.
        batch_genes = batch['gene_embeddings'].to(device)
        batch_adj = batch['gene_adjacencies'].to(device)
        batch_subs = batch['substructure_embeddings'].to(device)
        batch_graphs = batch['drug_graphs'].to(device)
        batch_labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(batch_genes, batch_adj, batch_subs, batch_graphs)  # [B, 1]
        # squeeze(-1) keeps shape [1] for a batch of size 1. A bare squeeze() would
        # give a 0-d tensor, and BCELoss raises on the size mismatch with the labels.
        loss = criterion(outputs.squeeze(-1), batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) guards against an empty loader.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/max(len(data_loader), 1):.4f}")