### 버전 기록
- Model Baseline 코드 (2024.12.30)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from itertools import product


In [2]:
# num_samples = NUM_CELL_LINES (1280) * NUM_DRUGS (69): one sample per (cell line, drug) pair.
num_pathways = 245
num_genes = 231
num_drugs = 69
# NOTE(review): the substructure tensors are annotated elsewhere with 193 substructures,
# while this value sizes the substructure nn.Embedding table (indices must be < this).
# Confirm 170 is the intended vocabulary size.
num_substructures = 170

In [None]:
# Prefer GPU index 2, but fall back to the first GPU when fewer than three are
# visible; a non-existent ordinal would only fail later, at `.to(device)`.
if torch.cuda.is_available():
    _gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")
print(f"CUDA is available: {torch.cuda.is_available()}")

## 데이터 준비

In [None]:
# NOTE(review): torch.load unpickles its input; only load trusted files.
data = torch.load('data.pt')

gene_embeddings = data['gene_embeddings']  # expected [1280, 245, 231] (cell_line x pathway x gene)
gene_adjacencies = data['gene_adjacencies']  # expected [245, 231, 231]
substructure_embeddings = data['substructure_embeddings']  # expected [69, 245, 193] (drug x pathway x substructure)
substructure_adjacencies = data['substructure_adjacencies']  # expected [69, 245, 193, 193]

# TODO(review): labels are random placeholders (unseeded), not real drug responses.
# Replace them before drawing any conclusion from training. The shape follows the
# loaded data so it always matches sample_indices below: [num_cell_lines, num_drugs].
labels = torch.randint(0, 2, (len(gene_embeddings), num_drugs), dtype=torch.float32)

# Every (cell line, drug) combination: [(0, 0), (0, 1), ..., (num_cell_lines - 1, num_drugs - 1)]
cell_line_indices = range(len(gene_embeddings))
drug_indices = range(num_drugs)
sample_indices = list(product(cell_line_indices, drug_indices))

## Dataset

In [6]:
class DrugResponseDataset(Dataset):
    """Map-style dataset with one item per (cell line, drug) pair.

    Each item holds the cell line's gene embedding, the gene adjacencies shared by
    all samples, the drug's substructure embedding and adjacency, and the label.
    """

    def __init__(self, gene_embeddings, gene_adjacencies, substructure_embeddings, substructure_adjacencies, labels, sample_indices):
        self.gene_embeddings = gene_embeddings  # indexed by cell line, e.g. [1280, 245, 231]
        self.gene_adjacencies = gene_adjacencies  # shared by every sample, e.g. [245, 231, 231]
        # Either per drug, [num_drugs, num_pathways, num_substructures] (indexed by drug),
        # or a single [num_pathways, num_substructures] tensor shared by all drugs.
        self.substructure_embeddings = substructure_embeddings
        self.substructure_adjacencies = substructure_adjacencies  # indexed by drug, e.g. [69, 245, 193, 193]
        self.labels = labels  # [num_cell_lines, num_drugs]
        self.sample_indices = sample_indices  # [(cell_line_idx, drug_idx), ...]

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return the tensors for the idx-th (cell line, drug) pair as a dict."""
        cell_line_idx, drug_idx = self.sample_indices[idx]

        substructure_embedding = self.substructure_embeddings
        if substructure_embedding.dim() == 3:
            # Per-drug tensor: select this sample's drug, exactly like substructure_adj.
            substructure_embedding = substructure_embedding[drug_idx]

        return {
            'gene_embedding': self.gene_embeddings[cell_line_idx],  # [num_pathways, num_genes]
            'gene_adj': self.gene_adjacencies,  # shared across samples
            'substructure_embedding': substructure_embedding,  # [num_pathways, num_substructures]
            'substructure_adj': self.substructure_adjacencies[drug_idx],  # [num_pathways, S, S]
            'label': self.labels[cell_line_idx, drug_idx],  # scalar
        }

In [7]:
def collate_fn(batch):
    """Combine a list of DrugResponseDataset items into one batch dict.

    Tensor fields are stacked along a new leading batch axis; the scalar labels
    become a float32 tensor of shape [batch_size].
    """
    # Output key -> per-item key it is stacked from (insertion order = output order).
    field_sources = {
        'gene_embeddings': 'gene_embedding',  # [batch_size, num_pathways, num_genes]
        'gene_adjacencies': 'gene_adj',  # [batch_size, num_pathways, num_genes, num_genes]
        'substructure_embeddings': 'substructure_embedding',  # [batch_size, num_pathways, num_substructures]
        'substructure_adjacencies': 'substructure_adj',  # [batch_size, num_pathways, S, S]
    }
    collated = {
        out_key: torch.stack([sample[src_key] for sample in batch])
        for out_key, src_key in field_sources.items()
    }
    collated['labels'] = torch.tensor([sample['label'] for sample in batch], dtype=torch.float32)
    return collated


In [None]:
# Wrap the full (cell line, drug) grid in a Dataset and a shuffling DataLoader.
dataset = DrugResponseDataset(
    gene_embeddings=gene_embeddings,
    gene_adjacencies=gene_adjacencies,
    substructure_embeddings=substructure_embeddings,
    substructure_adjacencies=substructure_adjacencies,
    labels=labels,
    sample_indices=sample_indices,
)
data_loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Sanity check: print the shape of every field of the first batch only.
_shape_report = (
    ("gene embeddings", "gene_embeddings"),  # [8, 245, 231]
    ("gene adjacencies", "gene_adjacencies"),  # [8, 245, 231, 231]
    ("substructure embeddings", "substructure_embeddings"),  # [8, 245, 193]
    ("substructure adjacencies", "substructure_adjacencies"),  # [8, 245, 193, 193]
    ("labels", "labels"),  # [8]
)
for batch in data_loader:
    for _title, _key in _shape_report:
        print(f"Batch {_title}: {batch[_key].shape}")
    break

## (0) Embedding Layer
### - GeneEmbeddingLayer : FloatTensor -> Linear
### - SubstructureEmbeddingLayer : IntTensor -> nn.Embedding

In [9]:
# Widths of the per-gene and per-substructure embedding vectors.
GENE_EMBEDDING_DIM, SUBSTRUCTURE_EMBEDDING_DIM = 128, 128
# GCN hidden width, first fully connected layer width, and output size (one probability).
HIDDEN_DIM, FINAL_DIM, OUTPUT_DIM = 128, 64, 1

In [10]:
class GeneEmbeddingLayer(nn.Module):
    """Embed every scalar gene value with a shared Linear(1, embedding_dim).

    In)  [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES] (any leading shape is accepted)
    Out) [BATCH_SIZE, NUM_PATHWAYS, NUM_GENES, embedding_dim]
    """

    def __init__(self, num_genes, embedding_dim=GENE_EMBEDDING_DIM):
        super(GeneEmbeddingLayer, self).__init__()
        self.linear = nn.Linear(1, embedding_dim)
        self.num_genes = num_genes

    def forward(self, gene_values):
        # A trailing feature axis of size 1 lets the Linear map each scalar
        # independently. The output shape follows the input (and this layer's own
        # embedding_dim) instead of module-level num_pathways / num_genes /
        # GENE_EMBEDDING_DIM, and unsqueeze also works on non-contiguous input.
        return self.linear(gene_values.unsqueeze(-1))

class SubstructureEmbeddingLayer(nn.Module):
    """Look up a learned vector for every substructure index.

    In)  integer indices, e.g. [BATCH_SIZE, NUM_PATHWAYS, NUM_SUBSTRUCTURES]
    Out) the same shape with a trailing embedding_dim axis
    """

    def __init__(self, num_substructures, embedding_dim=SUBSTRUCTURE_EMBEDDING_DIM):
        super().__init__()
        # Lookup table of shape [num_substructures, embedding_dim].
        self.embedding = nn.Embedding(num_substructures, embedding_dim)

    def forward(self, substructure_indices):
        looked_up = self.embedding(substructure_indices)
        return looked_up

## (1) CrossAttention

In [11]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from ``query_embeddings``; keys and values are projected
    from ``key_embeddings`` into ``query_dim`` features.
    """

    def __init__(self, query_dim, key_dim):
        super(CrossAttention, self).__init__()
        self.query_layer = nn.Linear(query_dim, query_dim)
        self.key_layer = nn.Linear(key_dim, query_dim)
        self.value_layer = nn.Linear(key_dim, query_dim)

    def forward(self, query_embeddings, key_embeddings):
        """Attend from queries to keys.

        Args:
            query_embeddings: [..., num_queries, query_dim]
            key_embeddings: [..., num_keys, key_dim]
        Returns:
            [..., num_queries, query_dim]
        """
        query = self.query_layer(query_embeddings)
        key = self.key_layer(key_embeddings)
        value = self.value_layer(key_embeddings)

        # Scale by 1/sqrt(d). Without it the logit magnitude grows with the feature
        # width, pushing softmax toward one-hot and shrinking its gradients.
        scale = query.size(-1) ** -0.5
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * scale
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of values for each query.
        return torch.matmul(attention_weights, value)



## (2) Graph Embedding

In [12]:
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool

class GraphEmbedding(nn.Module):
    """Two-layer GCN that turns a batch of dense-adjacency graphs into one vector per graph.

    Args:
        input_dim: node feature width.
        hidden_dim: width of both GCN layers and of the returned embedding.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, embeddings, adj_matrices):
        """Embed each graph in the batch.

        Args:
            embeddings: node features indexed by sample, e.g. [batch, num_nodes, input_dim].
            adj_matrices: per-sample adjacency. Every nonzero entry becomes an edge; its
                value is not passed as an edge weight. Assumes each adj_matrices[b] is
                2-D [num_nodes, num_nodes], so that nonzero().t() is a [2, E] edge_index.
        Returns:
            [batch, hidden_dim]: mean of each graph's final node features.
        """
        graphs = [
            Data(x=embeddings[b], edge_index=adj_matrices[b].nonzero(as_tuple=False).t())
            for b in range(embeddings.size(0))
        ]
        # Merge into one batched graph; merged.batch is the per-node graph
        # assignment used for pooling below.
        merged = Batch.from_data_list(graphs)

        hidden = torch.relu(self.conv1(merged.x, merged.edge_index))
        hidden = self.conv2(hidden, merged.edge_index)
        return global_mean_pool(hidden, merged.batch)


## (3) DrugResponseModel

In [13]:
class DrugResponseModel(nn.Module):
    """Predict a drug-response probability from pathway gene graphs and drug substructure graphs.

    Per pathway:
    1. Gene and substructure embeddings cross-attend to each other.
    2. Each side is summarized by a GCN (GraphEmbedding).

    The per-pathway graph embeddings are then averaged over pathways, concatenated,
    and passed through a two-layer MLP with a sigmoid output.
    """

    def __init__(self, num_genes, num_substructures, hidden_dim, final_dim):
        super(DrugResponseModel, self).__init__()
        # Embedding Layers
        self.gene_embedding_layer = GeneEmbeddingLayer(num_genes, GENE_EMBEDDING_DIM)
        self.substructure_embedding_layer = SubstructureEmbeddingLayer(num_substructures, SUBSTRUCTURE_EMBEDDING_DIM)

        # Cross Attention Layers
        self.Gene2Sub_cross_attention = CrossAttention(query_dim=GENE_EMBEDDING_DIM, key_dim=SUBSTRUCTURE_EMBEDDING_DIM)
        self.Sub2Gene_cross_attention = CrossAttention(query_dim=SUBSTRUCTURE_EMBEDDING_DIM, key_dim=GENE_EMBEDDING_DIM)

        # Graph Embedding Layers (each outputs hidden_dim features per graph)
        self.pathway_graph = GraphEmbedding(GENE_EMBEDDING_DIM, hidden_dim)
        self.drug_graph = GraphEmbedding(SUBSTRUCTURE_EMBEDDING_DIM, hidden_dim)

        # Fully Connected Layers
        self.fc1 = nn.Linear(2 * hidden_dim, final_dim)
        self.fc2 = nn.Linear(final_dim, OUTPUT_DIM)
        self.sigmoid = nn.Sigmoid()

    def forward(self, gene_embeddings, gene_adjacencies, substructure_embeddings, substructure_adjacencies):
        """Return response probabilities of shape [BATCH_SIZE, OUTPUT_DIM].

        Expected inputs (as built by collate_fn):
            gene_embeddings: [B, num_pathways, num_genes] scalar gene values.
            gene_adjacencies: [B, num_pathways, num_genes, num_genes].
            substructure_embeddings: [B, num_pathways, num_substructures]; values are
                cast to int and used as nn.Embedding indices.
            substructure_adjacencies: [B, num_pathways, S, S] (one graph per pathway),
                or [B, S, S] to reuse one graph for every pathway.
        """
        # nn.Embedding needs integer indices.
        substructure_embeddings = substructure_embeddings.int()

        gene_embeddings = self.gene_embedding_layer(gene_embeddings)  # [B, P, G, GENE_EMBEDDING_DIM]
        substructure_embeddings = self.substructure_embedding_layer(substructure_embeddings)  # [B, P, S, SUBSTRUCTURE_EMBEDDING_DIM]

        # A 4-D adjacency holds a separate substructure graph per pathway and must be
        # sliced like gene_adjacencies. Passed whole, each sample's adjacency would be
        # [P, S, S], whose nonzero().t() is [3, E], not a valid [2, E] edge_index.
        per_pathway_sub_adj = substructure_adjacencies.dim() == 4

        pathway_graph_embeddings = []
        drug_graph_embeddings = []
        for i in range(gene_embeddings.size(1)):
            gene_emb = gene_embeddings[:, i]  # [B, G, GENE_EMBEDDING_DIM]
            sub_emb = substructure_embeddings[:, i]  # [B, S, SUBSTRUCTURE_EMBEDDING_DIM]

            # Each side attends to the other within the same pathway.
            updated_gene_emb = self.Gene2Sub_cross_attention(gene_emb, sub_emb)
            updated_sub_emb = self.Sub2Gene_cross_attention(sub_emb, gene_emb)

            sub_adj = substructure_adjacencies[:, i] if per_pathway_sub_adj else substructure_adjacencies
            pathway_graph_embeddings.append(self.pathway_graph(updated_gene_emb, gene_adjacencies[:, i]))  # [B, hidden_dim]
            drug_graph_embeddings.append(self.drug_graph(updated_sub_emb, sub_adj))  # [B, hidden_dim]

        # Average over pathways: [B, P, hidden_dim] -> [B, hidden_dim].
        final_pathway_embedding = torch.stack(pathway_graph_embeddings, dim=1).mean(dim=1)
        final_drug_embedding = torch.stack(drug_graph_embeddings, dim=1).mean(dim=1)

        combined_embedding = torch.cat((final_pathway_embedding, final_drug_embedding), dim=-1)  # [B, 2 * hidden_dim]

        # ReLU between the two Linear layers; without it fc2(fc1(x)) is a single affine map.
        x = torch.relu(self.fc1(combined_embedding))  # [B, final_dim]
        x = self.fc2(x)  # [B, OUTPUT_DIM]
        return self.sigmoid(x)
    

In [None]:
# Build the model directly on the selected device. The model's forward ends in a
# sigmoid, so plain BCELoss is applied to its probabilities.
model = DrugResponseModel(num_genes, num_substructures, HIDDEN_DIM, FINAL_DIM).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Model device: {next(model.parameters()).device}")
print(f"CUDA is available: {torch.cuda.is_available()}")

In [None]:
# Training Loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in data_loader:
        # Batch-local names, so the module-level data tensors (gene_embeddings,
        # labels, ...) are not overwritten by the last batch.
        batch_gene_emb = batch['gene_embeddings'].to(device)
        batch_gene_adj = batch['gene_adjacencies'].to(device)
        batch_sub_emb = batch['substructure_embeddings'].to(device)
        batch_sub_adj = batch['substructure_adjacencies'].to(device)
        batch_labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(batch_gene_emb, batch_gene_adj, batch_sub_emb, batch_sub_adj)  # [B, 1]
        # squeeze(-1) keeps the batch axis: a bare squeeze() would turn a size-1
        # batch into a 0-d tensor that no longer matches batch_labels' shape [1].
        loss = criterion(outputs.squeeze(-1), batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(data_loader):.4f}")