<a href="https://colab.research.google.com/github/sevendaystoglory/temp/blob/main/rna-histone-interaction-training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import torch
# Load a single shard so its record structure can be inspected in the next cell.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
em = torch.load("/content/ncbi_training_data/embeddings_0_part_0.pt")

  em = torch.load("/content/ncbi_training_data/embeddings_0_part_0.pt")


In [17]:
# Length of the first record's 'Interactor1.Embeddings' entry.
# NOTE(review): presumably this is the gene embedding width that the model's gene_dim must match — confirm.
len(em[0]['Interactor1.Embeddings'])

640

In [22]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# InteractionPredictor: projects both embeddings into a shared space and scores the pair by cosine similarity.
class InteractionPredictor(nn.Module):
    """Score a gene/histone embedding pair by cosine similarity in a shared space.

    Each input is reduced to one vector per sample, linearly projected to
    ``common_dim``, and the two projections are compared with cosine
    similarity rescaled from [-1, 1] to [0, 1].

    Args:
        gene_dim: Size of the last dimension of the gene embeddings.
        histone_dim: Size of the last dimension of the histone embeddings.
        common_dim: Size of the shared projection space.
        pooling: ``'cls'`` (take the first token) or ``'mean'`` (average over
            tokens). Only affects inputs with a token axis, i.e. 3-D
            ``(batch, tokens, dim)``; 2-D ``(batch, dim)`` inputs are treated
            as already pooled. The value is validated when ``forward`` runs.
    """

    def __init__(self, gene_dim=640, histone_dim=768, common_dim=256, pooling='cls'):
        super().__init__()
        self.pooling = pooling
        self.gene_proj = nn.Linear(gene_dim, common_dim)
        self.histone_proj = nn.Linear(histone_dim, common_dim)

    def _pool(self, embeddings):
        """Reduce ``embeddings`` to one vector per sample according to ``self.pooling``."""
        if self.pooling not in ('cls', 'mean'):
            raise ValueError("Pooling method not recognized. Choose 'cls' or 'mean'.")
        # Inputs without a token axis are already one vector per sample.
        if embeddings.dim() <= 2:
            return embeddings
        if self.pooling == 'cls':
            # By convention the first token carries the sequence summary.
            return embeddings[:, 0]
        return embeddings.mean(dim=1)

    def forward(self, gene_embeddings, histone_embeddings, true_score=None):
        """Return the MSE loss against ``true_score``, or the scores themselves.

        Args:
            gene_embeddings: ``(batch, gene_dim)`` or ``(batch, tokens, gene_dim)``.
            histone_embeddings: ``(batch, histone_dim)`` or
                ``(batch, tokens, histone_dim)``.
            true_score: Optional targets with one value per sample. When
                omitted, the predicted scores are returned instead of a loss.

        Returns:
            A scalar MSE loss if ``true_score`` is given, otherwise a
            ``(batch,)`` tensor of predicted scores in [0, 1].

        Raises:
            ValueError: If ``self.pooling`` is not ``'cls'`` or ``'mean'``.
            RuntimeError: If ``true_score`` does not hold exactly one value
                per predicted score.
        """
        gene_latent = self.gene_proj(self._pool(gene_embeddings))
        histone_latent = self.histone_proj(self._pool(histone_embeddings))

        # Cosine similarity lies in [-1, 1]; shift and scale it to [0, 1].
        similarity = F.cosine_similarity(gene_latent, histone_latent, dim=1)
        predicted_score = (similarity + 1) / 2

        if true_score is None:
            return predicted_score

        # A (batch, 1) target would silently broadcast against (batch,)
        # predictions and produce a (batch, batch) error matrix; align the
        # shapes explicitly (reshape raises if the element counts differ).
        if true_score.shape != predicted_score.shape:
            true_score = true_score.reshape(predicted_score.shape)
        return F.mse_loss(predicted_score, true_score)

# Custom dataset that wraps an in-memory list of interaction records (the records themselves are read from .pt files further down).
class InteractionDataset(Dataset):
    """Map-style dataset over a list of interaction records.

    Args:
        data_list: A list of dictionaries with keys
            'Interactor1.Embeddings', 'Interactor2.Embeddings', and 'score'.
    """

    def __init__(self, data_list):
        self.data = data_list

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _as_float_tensor(value):
        """Return ``value`` as a float32 tensor that carries no autograd history."""
        if torch.is_tensor(value):
            # Cast existing tensors too (e.g. float64/int64 would otherwise
            # clash with the model's float32 weights); avoids the
            # copy-construct warning torch.tensor(tensor) would emit.
            return value.detach().to(torch.float)
        return torch.tensor(value, dtype=torch.float)

    def __getitem__(self, index):
        """Return ``(gene_embedding, histone_embedding, score)`` as float32 tensors.

        'Interactor1.Embeddings' is returned as the gene embedding and
        'Interactor2.Embeddings' as the histone embedding. No batch dimension
        is added here; batching is left to the DataLoader.
        """
        sample = self.data[index]
        gene_embedding = self._as_float_tensor(sample['Interactor1.Embeddings'])
        histone_embedding = self._as_float_tensor(sample['Interactor2.Embeddings'])
        score = self._as_float_tensor(sample['score'])
        return gene_embedding, histone_embedding, score

# Set the directory containing the .pt files
data_dir = "ncbi_training_data"

# Load all .pt files and combine their record lists.
# os.listdir() returns names in arbitrary order, so sort them: otherwise the
# order of all_data (and therefore the seeded split below) can differ between
# machines or runs.
all_data = []
for file in sorted(os.listdir(data_dir)):
    if file.endswith('.pt'):
        file_path = os.path.join(data_dir, file)
        # NOTE(review): torch.load unpickles the file, so only load shards from
        # a trusted source. Passing weights_only=True may be possible if the
        # records hold only tensors and plain Python types — TODO confirm.
        data = torch.load(file_path)
        all_data.extend(data)

# Fail early with a clear message instead of an opaque error from the split.
if not all_data:
    raise ValueError(f"No samples found in .pt files under {data_dir!r}")

# Split the data into training and test sets (80% train, 20% test)
train_data, test_data = train_test_split(all_data, test_size=0.2, random_state=42)

# Create Dataset and DataLoader instances
train_dataset = InteractionDataset(train_data)
test_dataset = InteractionDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# For reproducibility: seed before the model's weights are initialised.
torch.manual_seed(42)

# Initialize the model, optimizer, and number of training epochs.
model = InteractionPredictor(gene_dim=640, histone_dim=768, common_dim=256, pooling='cls')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 10

  data = torch.load(file_path)


In [24]:
# Baseline: size-weighted mean test loss of the untrained model.
model.eval()
weighted_loss_sum = 0.0
with torch.no_grad():
    for gene_batch, histone_batch, score_batch in test_loader:
        batch_loss = model(gene_batch, histone_batch, score_batch.float())
        # Weight by batch size so a smaller final batch is not over-counted.
        weighted_loss_sum += batch_loss.item() * gene_batch.size(0)
test_loss = weighted_loss_sum / len(test_dataset)
print(f"Test Loss: {test_loss:.4f}")

Test Loss: 0.1712


In [25]:
# Train for num_epochs passes over train_loader, reporting the mean per-sample loss.
for epoch in range(num_epochs):
    model.train()
    weighted_loss_sum = 0.0
    for gene_batch, histone_batch, score_batch in train_loader:
        optimizer.zero_grad()
        # Cast targets to float32; their shape is left as the DataLoader produced it.
        batch_loss = model(gene_batch, histone_batch, score_batch.float())
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch is not over-counted.
        weighted_loss_sum += batch_loss.item() * gene_batch.size(0)
    epoch_loss = weighted_loss_sum / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} training loss: {epoch_loss:.4f}")

# Size-weighted mean loss of the trained model on the held-out test set.
model.eval()
weighted_loss_sum = 0.0
with torch.no_grad():
    for gene_batch, histone_batch, score_batch in test_loader:
        batch_loss = model(gene_batch, histone_batch, score_batch.float())
        # Weight by batch size so a smaller final batch is not over-counted.
        weighted_loss_sum += batch_loss.item() * gene_batch.size(0)
test_loss = weighted_loss_sum / len(test_dataset)
print(f"Test Loss: {test_loss:.4f}")

Epoch 1/10 training loss: 0.0102
Epoch 2/10 training loss: 0.0070
Epoch 3/10 training loss: 0.0068
Epoch 4/10 training loss: 0.0067
Epoch 5/10 training loss: 0.0065
Epoch 6/10 training loss: 0.0063
Epoch 7/10 training loss: 0.0063
Epoch 8/10 training loss: 0.0062
Epoch 9/10 training loss: 0.0062
Epoch 10/10 training loss: 0.0060
Test Loss: 0.0064
