In [1]:
import torch
import torch.nn as nn
import torchvision.models as models

# Cross-Attention Module
class CrossAttention(nn.Module):
    def __init__(self, embed_dim):
        super(CrossAttention, self).__init__()
        self.query_layer = nn.Linear(embed_dim, embed_dim)
        self.key_layer = nn.Linear(embed_dim, embed_dim)
        self.value_layer = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value):
        # Compute attention weights
        query = self.query_layer(query)
        key = self.key_layer(key)
        value = self.value_layer(value)
        attention_weights = self.softmax(torch.matmul(query, key.transpose(-2, -1)))
        
        # Apply attention
        attended_features = torch.matmul(attention_weights, value)
        return attended_features

# Image Encoder
class ImageEncoder(nn.Module):
    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.cnn = models.resnet18(pretrained=True)
        self.cnn.fc = nn.Identity()
        self.fc = nn.Linear(512, 256)

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)
        return x

# Gene Encoder
class GeneEncoder(nn.Module):
    def __init__(self, input_dim):
        super(GeneEncoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        x = self.fc(x)
        return x


class SiameseMultiModalModel(nn.Module):
    def __init__(self, gene_input_dim):
        super(SiameseMultiModalModel, self).__init__()
        self.image_encoder = ImageEncoder()
        self.gene_encoder = GeneEncoder(gene_input_dim)
        self.cross_attention = CrossAttention(embed_dim=256)
        self.bilinear_pooling = nn.Bilinear(256, 256, 128)
        self.projection_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64)
        )

    def forward(self, image1, gene1, image2, gene2):
        # Encode first pair
        img_features1 = self.image_encoder(image1)
        gene_features1 = self.gene_encoder(gene1)

        # Encode second pair
        img_features2 = self.image_encoder(image2)
        gene_features2 = self.gene_encoder(gene2)

        # Cross-attention for first pair
        attended_img1 = self.cross_attention(gene_features1.unsqueeze(1), 
                                             img_features1.unsqueeze(1), 
                                             img_features1.unsqueeze(1)).squeeze(1)
        attended_gene1 = self.cross_attention(img_features1.unsqueeze(1), 
                                              gene_features1.unsqueeze(1), 
                                              gene_features1.unsqueeze(1)).squeeze(1)

        # Cross-attention for second pair
        attended_img2 = self.cross_attention(gene_features2.unsqueeze(1), 
                                             img_features2.unsqueeze(1), 
                                             img_features2.unsqueeze(1)).squeeze(1)
        attended_gene2 = self.cross_attention(img_features2.unsqueeze(1), 
                                              gene_features2.unsqueeze(1), 
                                              gene_features2.unsqueeze(1)).squeeze(1)

        # Bilinear pooling for both pairs
        fused_features1 = self.bilinear_pooling(attended_img1, attended_gene1)
        fused_features2 = self.bilinear_pooling(attended_img2, attended_gene2)

        # Project to shared embedding space
        embed1 = self.projection_head(fused_features1)
        embed2 = self.projection_head(fused_features2)

        return embed1, embed2

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embed1, embed2, label):
        # Compute pairwise distance
        distance = nn.functional.pairwise_distance(embed1, embed2)
        
        # Contrastive loss
        loss = torch.mean((1 - label) * torch.pow(distance, 2) + 
                          label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2))
        return loss

import random
import torch.utils.data as data

class PatchDataset(data.Dataset):
    def __init__(self, patches, gene_data, labels, max_patches_per_patient=10):
        self.patches = patches  # Dict with patient_id -> list of image patches
        self.gene_data = gene_data  # Dict with patient_id -> gene expression
        self.labels = labels  # Dict with patient_id -> label
        self.max_patches_per_patient = max_patches_per_patient

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        patient_id = list(self.patches.keys())[idx]
        all_patches = self.patches[patient_id]
        
        # Randomly sample patches
        if len(all_patches) > self.max_patches_per_patient:
            selected_patches = random.sample(all_patches, self.max_patches_per_patient)
        else:
            selected_patches = all_patches
        
        patches_tensor = torch.stack(selected_patches)  # Stack patches
        gene_tensor = self.gene_data[patient_id]
        label = self.labels[patient_id]
        
        return patches_tensor, gene_tensor, label




ModuleNotFoundError: No module named 'torchvision'

In [None]:


from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

class HistopathologyDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = HistopathologyDataset(image_paths, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)



metadata = pd.read_csv("metadata.csv")
metadata = pd.get_dummies(metadata, columns=["Subtype"])

In [None]:
from torch.optim.lr_scheduler import StepLR

# Initialize model, optimizer, and scheduler
model = MultiModalWithMetadata(gene_input_dim=10000, metadata_dim=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Reduce LR by 0.5 every 5 epochs

# Training loop
for epoch in range(20):  # Adjust number of epochs
    for images, genes, metadata, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(images, genes, metadata)
        loss = criterion(outputs, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

    scheduler.step()  # Update learning rate
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

In [None]:
from torchvision.transforms.functional import normalize

def generate_gradcam(image, model, target_layer):
    gradcam = GradCAM(model=model, target_layer=target_layer, use_cuda=True)
    heatmap = gradcam(image)
    return heatmap

import shap

# SHAP values for gene encoder
explainer = shap.DeepExplainer(model.gene_encoder, gene_data_sample)
shap_values = explainer.shap_values(gene_test_data)
shap.summary_plot(shap_values, gene_test_data)

In [14]:
# process images
import os
import cv2
import torch
from tqdm import tqdm

base_path = "/home/ramanuja-simha/Downloads/28050083/TCGA-BRCA-A2-DEEPMED-TILES/BLOCKS_NORM_MACENKO/"  
image_size = (256, 256)  
max_patches_per_patient = 1 
mean = [0.485, 0.456, 0.406]  
std = [0.229, 0.224, 0.225]   
all_patient_data_image = {}
patient_folders = os.listdir(base_path)

for patient in tqdm(patient_folders, desc="Processing Patients"):
    patient_path = os.path.join(base_path, patient)
    patient = patient.split("-")[:3]
    patient = "-".join(patient)
    patient += "-01"
    if not os.path.isdir(patient_path):
        continue
    
    processed_slices = []
    slice_files = os.listdir(patient_path)
    for slice_file in slice_files[:max_patches_per_patient]:
        slice_path = os.path.join(patient_path, slice_file)
        try:
            img = cv2.imread(slice_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)                  
            img = cv2.resize(img, image_size)
            img = img / 255.0
            img = (img - mean) / std
            img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
            processed_slices.append(img_tensor)
        except Exception as e:
            print(f"Error processing {slice_path}: {e}")
            continue
    if processed_slices:
        all_patient_data_image[patient] = torch.stack(processed_slices)  
print("patients:",len(all_patient_data_image))
print("patient:",next(iter(all_patient_data_image.keys())))
print("image features:",all_patient_data_image[next(iter(all_patient_data_image.keys()))].shape)

Processing Patients: 100%|███████████████████| 100/100 [00:00<00:00, 103.49it/s]

patients: 100
patient: TCGA-A2-A0ER-01
image features: torch.Size([1, 3, 256, 256])





In [15]:
# process gene expr profiling
import pandas as pd
from sklearn.impute import KNNImputer
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold
from sklearn.decomposition import PCA

file_path = "/home/ramanuja-simha/Downloads/28050083/TCGA-BRCA-RNA-Seq.csv"
gene_expression_df = pd.read_csv(file_path, sep="\t", header=0, index_col=0)  
print("gene expr:",gene_expression_df.shape)

patient_ids = list(all_patient_data_image.keys())
print("patient ids:",len(patient_ids))

filtered_gene_expression_df = gene_expression_df[patient_ids]
transposed_df = filtered_gene_expression_df.T
print("filter:",transposed_df.shape)

row_ids = transposed_df.index.tolist()
imputer = KNNImputer(n_neighbors=5)
gene_data_imputed = imputer.fit_transform(transposed_df)
print("imputer:",gene_data_imputed.shape)
scaler = StandardScaler()
gene_data_normalized = scaler.fit_transform(gene_data_imputed)
print("scaler:",gene_data_normalized.shape)
selector = VarianceThreshold(threshold=0.01)
gene_data_selected = selector.fit_transform(gene_data_normalized)
print("selector:",gene_data_selected.shape)
pca = PCA(n_components=100)
gene_data_reduced = pca.fit_transform(gene_data_selected)
print("pca:",gene_data_reduced.shape)
gene_data_reduced = pd.DataFrame(gene_data_reduced, index=row_ids)

all_patient_data_geneexpr = gene_data_reduced.to_dict(orient='index')
print("patients:",len(all_patient_data_geneexpr))
print("patient:",next(iter(all_patient_data_geneexpr.keys())))
print("gene expr features:",len(all_patient_data_geneexpr[next(iter(all_patient_data_geneexpr.keys()))]))

gene expr: (20530, 1218)
patient ids: 100
filter: (100, 20530)
imputer: (100, 20530)
scaler: (100, 20530)
selector: (100, 19997)
pca: (100, 100)
patients: 100
patient: TCGA-A2-A0ER-01
gene expr features: 100


In [21]:
import pandas as pd

metadata_path = "/home/ramanuja-simha/Downloads/28050083/TCGA-BRCA-A2-target_variable.xlsx"
metadata_df = pd.read_excel(metadata_path, engine='openpyxl')
col2_col4_dict = dict(zip(metadata_df.iloc[:, 1], metadata_df.iloc[:, 3]))
all_patient_data = {key: col2_col4_dict[key] for key in col2_col4_dict if key in all_patient_data_image}
print(len(all_patient_data))
print(set(all_patient_data.values()))
# Identify the patient IDs with class "x" in all_patient_data
patients_to_remove = [patient_id for patient_id, class_label in all_patient_data.items() if class_label == "x"]

# Remove these patient IDs from all three dictionaries
for patient_id in patients_to_remove:
    # Remove from all_patient_data
    if patient_id in all_patient_data:
        del all_patient_data[patient_id]
    
    # Remove from all_patient_data_image
    if patient_id in all_patient_data_image:
        del all_patient_data_image[patient_id]
    
    # Remove from all_patient_data_geneexpr
    if patient_id in all_patient_data_geneexpr:
        del all_patient_data_geneexpr[patient_id]

# Verify the changes
print(f"Remaining patient IDs in all_patient_data: {len(all_patient_data.keys())}")
print(f"Remaining patient IDs in all_patient_data_image: {len(all_patient_data_image.keys())}")
print(f"Remaining patient IDs in all_patient_data_geneexpr: {len(all_patient_data_geneexpr.keys())}")

93
{0, 1}
Remaining patient IDs in all_patient_data: 93
Remaining patient IDs in all_patient_data_image: 93
Remaining patient IDs in all_patient_data_geneexpr: 93


In [22]:
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from torch import nn, optim
from sklearn.preprocessing import LabelEncoder

class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        # Assuming input image size is 256x256 with 3 channels
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # Convolution layer for feature extraction
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # Additional convolution
        self.pool = nn.MaxPool2d(2, 2)  # Pooling layer
        self.fc1 = nn.Linear(64 * 64 * 64, 512)  # Fully connected layer
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 2)  # Output layer for binary classification

    def forward(self, x):
        # x shape is (batch_size, num_patches, 3, 256, 256)
        batch_size, num_patches, c, h, w = x.shape
        
        # Process each patch separately through convolution layers
        patch_features = []
        for i in range(num_patches):
            patch = x[:, i, :, :, :]  # Extract one patch
            patch = self.pool(torch.relu(self.conv1(patch)))  # Apply conv1 + pooling
            patch = self.pool(torch.relu(self.conv2(patch)))  # Apply conv2 + pooling
            patch = patch.view(patch.size(0), -1)  # Flatten the patch
            patch_features.append(patch)
        
        # Stack the patch features and aggregate across patches (e.g., average)
        patch_features = torch.stack(patch_features, dim=1)  # Shape: (batch_size, num_patches, features)
        
        # Aggregate features from all patches (simple average)
        aggregated_features = patch_features.mean(dim=1)  # Average across patches
        
        # Pass the aggregated features through the fully connected layers
        x = torch.relu(self.fc1(aggregated_features))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

class GeneExpressionModel(nn.Module):
    def __init__(self):
        super(GeneExpressionModel, self).__init__()
        self.fc1 = nn.Linear(100, 512)  # Adjust input size for gene expression data
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 2)  # Binary classification

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def prepare_data(data_dict, target_dict, validation_size=0.2, test_size=0.2):
    # Initialize the LabelEncoder
    label_encoder = LabelEncoder()
    
    # Sort the patient IDs to ensure alignment between features and labels
    patient_ids = sorted(data_dict.keys())
    
    # Convert dictionaries to lists of features and target labels
    features = []
    targets = []

    # Get all target labels (using the sorted patient IDs)
    all_targets = [target_dict[patient_id] for patient_id in patient_ids]
    
    # Fit the LabelEncoder on the target labels
    label_encoder.fit(all_targets)
    
    # Encode targets into integers
    encoded_targets = label_encoder.transform(all_targets)
    
    for patient_id in patient_ids:
        # Ensure the features are converted to tensor
        feature_tensor = torch.tensor(data_dict[patient_id], dtype=torch.float32)
        features.append(feature_tensor)
        
        # Append the corresponding encoded target label
        targets.append(encoded_targets[patient_ids.index(patient_id)])

    # Stack the features into a single tensor
    features = torch.stack(features)
    targets = torch.tensor(targets, dtype=torch.long)  # Ensure target is a tensor with correct dtype

    # Train-test-validation split
    # First, split into train+val and test
    X_train_val, X_test, y_train_val, y_test = train_test_split(features, targets, test_size=test_size, random_state=42)
    
    # Then, split train+val into train and validation
    X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=validation_size / (1 - test_size), random_state=42)
    
    return X_train, X_val, X_test, y_train, y_val, y_test

# 2. **Build Model for Image Modality (Fully Connected Example)**
def train_model(model, train_data, train_labels, validation_data=None, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        print("epoch:",epoch)
        model.train()  # Ensure the model is in training mode
        optimizer.zero_grad()
        
        # Forward pass on the training data
        outputs = model(train_data)
        loss = criterion(outputs, train_labels)
        
        # Backward pass and optimization step
        loss.backward()
        optimizer.step()
        
        # If validation data is provided, evaluate on the validation set
        if validation_data:
            model.eval()  # Set the model to evaluation mode for validation
            val_data, val_labels = validation_data
            with torch.no_grad():  # No need to compute gradients for validation
                val_outputs = model(val_data)
                val_loss = criterion(val_outputs, val_labels)
                
                # Calculate accuracy or F1 score for validation
                _, predicted = torch.max(val_outputs, 1)  # Get the predicted class labels
                correct = (predicted == val_labels).sum().item()
                val_accuracy = correct / len(val_labels)
                
                # Optionally, you could calculate other metrics like F1 score:
                val_f1 = f1_score(val_labels.cpu(), predicted.cpu(), average='weighted')
                
                print(f"Epoch {epoch+1}/{epochs} - Train Loss: {loss.item():.4f}, "
                      f"Validation Loss: {val_loss.item():.4f}, "
                      f"Validation Accuracy: {val_accuracy:.4f}, "
                      f"Validation F1: {val_f1:.4f}")
        else:
            print(f"Epoch {epoch+1}/{epochs} - Train Loss: {loss.item():.4f}")
        
    return model

def evaluate_model(model, test_data, test_labels):
    with torch.no_grad():
        outputs = model(test_data)
        _, predicted = torch.max(outputs, 1)
        accuracy = accuracy_score(test_labels.numpy(), predicted.numpy())
        f1 = f1_score(test_labels.numpy(), predicted.numpy(), average='weighted')
    return accuracy, f1

In [None]:
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from torch import nn, optim

X_train_img, X_val_img, X_test_img, y_train_img, y_val_img, y_test_img = prepare_data(all_patient_data_image, all_patient_data)
print(f"Shape of input data before model: {X_train_img.shape}")
print(f"Shape of a single image: {X_train_img[0].shape}")
# Initialize, train, and evaluate the Image Model
image_model = ImageModel()
trained_image_model = train_model(image_model, X_train_img, y_train_img, validation_data=(X_val_img, y_val_img))
img_acc, img_f1 = evaluate_model(trained_image_model, X_test_img, y_test_img)

print(f"Image Modality - Accuracy: {img_acc:.4f}, F1 Score: {img_f1:.4f}")



  feature_tensor = torch.tensor(data_dict[patient_id], dtype=torch.float32)


Shape of input data before model: torch.Size([55, 1, 3, 256, 256])
Shape of a single image: torch.Size([1, 3, 256, 256])


In [None]:
# 5. **Build Model for Gene Expression Modality (Fully Connected Example)**


# 6. **Train Gene Expression Model**

def train_gene_model(model, train_data, train_labels, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(train_data)
        loss = criterion(outputs, train_labels)
        loss.backward()
        optimizer.step()
        
    return model

# 7. **Evaluate Gene Expression Model**

def evaluate_gene_expression_model(model, test_data, test_labels):
    with torch.no_grad():
        outputs = model(test_data)
        _, predicted = torch.max(outputs, 1)
        accuracy = accuracy_score(test_labels.numpy(), predicted.numpy())
        f1 = f1_score(test_labels.numpy(), predicted.numpy(), average='weighted')
    return accuracy, f1

# Prepare Gene Expression Data
X_train_gene, X_test_gene, y_train_gene, y_test_gene = prepare_data(all_patient_data_geneexpr, all_patient_data_class)

# Initialize, train, and evaluate the Gene Expression Model
gene_model = GeneExpressionModel()
trained_gene_model = train_gene_model(gene_model, X_train_gene, y_train_gene)
gene_acc, gene_f1 = evaluate_gene_expression_model(trained_gene_model, X_test_gene, y_test_gene)

print(f"Gene Expression Modality - Accuracy: {gene_acc:.4f}, F1 Score: {gene_f1:.4f}")