## Imports

In [1]:
from healnet.models import HealNet
from healnet.etl import MMDataset
import torch
import einops
from torch.utils.data import Dataset, DataLoader
from typing import *

%load_ext autoreload
%autoreload 2

## Synthetic modalities

We instantiate a synthetic multimodal dataset for demo purposes. 

In [2]:
n = 1000 # number of samples
b = 4 # batch size
img_c = 3 # image channels
tab_c = 1 # tabular channels
tab_d = 5000 # tabular features
h = 512 # image height
w = 512 # image width
n_classes = 2 # classification

tab_tensor = torch.rand(size=(n, tab_c, tab_d)) # assume 5k tabular features
img_tensor = torch.rand(size=(n, img_c, h, w)) # c h w


# derive a target
target = torch.rand(size=(n,))

In [16]:
import os
import torch
import pandas as pd
import h5py
from torch.utils.data import Dataset, DataLoader

# Define the custom dataset class
class CustomDataset(Dataset):
    def __init__(self, df, img_features_path):
        self.df = df.reset_index(drop=True)
        self.img_features_path = img_features_path
        self.tabular_columns = [col for col in df.columns if col.startswith('cnv')]
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        # Extract row data
        row = self.df.iloc[idx]
        wsi_file_name = row["wsi_file_name"][:-5]  # Adjust as needed
        target_label = row["progressor_status"]
        
        # Convert target label from 'NP'/'P' to 0/1
        if target_label == 'P':
            target = 1
        else:
            target = 0
        
        # Load tabular features
        tabular_features = row[self.tabular_columns].values.astype('float32')
        tabular_tensor = torch.tensor(tabular_features)
        
        # Load image features from .h5 file
        img_features_file = os.path.join(self.img_features_path, wsi_file_name + ".h5")
        with h5py.File(img_features_file, 'r') as h5_file:
            # Extract 'cluster_features' dataset
            if "cluster_features" in h5_file:
                cluster_features = h5_file["cluster_features"][:]
            else:
                raise KeyError(f"'cluster_features' not found in {img_features_file}")
            
        # Convert cluster features to a tensor
        img_features_tensor = torch.tensor(cluster_features, dtype=torch.float32)
        
        # Prepare the target tensor
        target_tensor = torch.tensor(target, dtype=torch.float32)
        
        return {
            'tabular': tabular_tensor,
            'image': img_features_tensor,
            'target': target_tensor
        }

# Paths to your data
path_to_img_features = "/scratchc/fmlab/zuberi01/masters/saved_patches/40x_400/features2/h5_files/"
path_to_csv = "/scratchc/fmlab/zuberi01/phd/sequoia-pub/examples/matching_rows_sequoia.csv"

# Load the CSV file
df = pd.read_csv(path_to_csv)

# Split the DataFrame into train, val, and test sets
train_df = df[df["split_0"] == "train"]
val_df = df[df["split_0"] == "val"]
test_df = df[df["split_0"] == "test"]

# Create datasets for each split
train_dataset = CustomDataset(train_df, path_to_img_features)
val_dataset = CustomDataset(val_df, path_to_img_features)
test_dataset = CustomDataset(test_df, path_to_img_features)

# Define loader arguments
loader_args = {
    "batch_size": 4,  # Adjust batch size as needed
    "shuffle": True, 
    "num_workers": 8, 
    "pin_memory": True, 
    "multiprocessing_context": "fork", 
    "persistent_workers": True, 
}

# Create DataLoaders for each dataset
train_loader = DataLoader(train_dataset, **loader_args)
val_loader = DataLoader(val_dataset, **loader_args)
test_loader = DataLoader(test_dataset, **loader_args)

# Example usage: iterate over the training data
for batch in train_loader:
    tabular_data = batch['tabular']
    image_data = batch['image']
    targets = batch['target']
    
    # Now you can feed tabular_data and image_data into your model
    # For example:
    # outputs = model(tabular_data, image_data)
    # loss = loss_fn(outputs, targets)
    # ... (rest of your training loop)


In [35]:
# Import the HealNet class
from healnet.models.healnet import HealNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the number of classes
num_classes = 2  # Adjust based on your dataset

# Instantiate the HealNet model
model = HealNet(
    modalities=2, 
    input_channels=[1, 1],  # Number of channels per modality
    input_axes=[1, 1],       # Updated to match the number of axes per modality
    num_classes=num_classes,
    # Add other parameters as required by your model configuration
).to(device)


In [36]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from healnet.models.healnet import HealNet

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define the number of classes
num_classes = 2  # Adjust based on your dataset

# Instantiate the HealNet model with updated input_axes
model = HealNet(
    modalities=2, 
    input_channels=[1, 1],  # Number of channels per modality
    input_axes=[1, 1],       # Updated to match the number of axes per modality
    num_classes=num_classes,
    # Add other parameters as required by your model configuration
).to(device)

print("HealNet model instantiated successfully.")

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Number of epochs
num_epochs = 10  # Adjust as needed

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        # Retrieve data and move to device
        tabular_data = batch['tabular'].to(device).unsqueeze(1)  # Shape: [batch_size, 1, num_tabular_features]
        image_data = batch['image'].to(device).unsqueeze(1)      # Shape: [batch_size, 1, num_image_features]
        targets = batch['target'].long().to(device)             # Shape: [batch_size]

        # Debugging: Print shapes for the first batch of each epoch
        if batch_idx == 0:
            print(f"\nEpoch {epoch+1}, Batch {batch_idx+1}")
            print(f"Adjusted Tabular Data Shape: {tabular_data.shape}")
            print(f"Adjusted Image Data Shape: {image_data.shape}")
            print(f"Targets Shape: {targets.shape}")

        # Prepare inputs
        inputs = [tabular_data, image_data]

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)                       # Expected Shape: [batch_size, num_classes]

        # Debugging: Print outputs shape and sample outputs for the first batch
        if batch_idx == 0:
            print(f"Outputs Shape: {outputs.shape}")
            print(f"Outputs Sample: {outputs[:2]}")

        # Compute loss
        loss = loss_fn(outputs, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate loss
        running_loss += loss.item()

    # Calculate average loss for the epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"\nEpoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # Validation loop
    model.eval()
    val_running_loss = 0.0
    with torch.no_grad():
        for val_batch_idx, batch in enumerate(val_loader):
            tabular_data = batch['tabular'].to(device).unsqueeze(1)  # Shape: [batch_size, 1, num_tabular_features]
            image_data = batch['image'].to(device).unsqueeze(1)      # Shape: [batch_size, 1, num_image_features]
            targets = batch['target'].long().to(device)             # Shape: [batch_size]

            # Prepare inputs
            inputs = [tabular_data, image_data]

            # Forward pass
            outputs = model(inputs)
            val_loss = loss_fn(outputs, targets)
            val_running_loss += val_loss.item()

            # Debugging: Print validation outputs shape and sample outputs for the first batch
            if val_batch_idx == 0:
                print(f"Validation Outputs Shape: {outputs.shape}")
                print(f"Validation Outputs Sample: {outputs[:2]}")

    val_epoch_loss = val_running_loss / len(val_loader)
    print(f"Validation Loss: {val_epoch_loss:.4f}\n")

# After training, evaluate on the test set
model.eval()
test_running_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for test_batch_idx, batch in enumerate(tqdm(test_loader)):
        tabular_data = batch['tabular'].to(device).unsqueeze(1)  # Shape: [batch_size, 1, num_tabular_features]
        image_data = batch['image'].to(device).unsqueeze(1)      # Shape: [batch_size, 1, num_image_features]
        targets = batch['target'].long().to(device)             # Shape: [batch_size]

        # Prepare inputs
        inputs = [tabular_data, image_data]

        # Forward pass
        outputs = model(inputs)
        test_loss = loss_fn(outputs, targets)
        test_running_loss += test_loss.item()

        # Compute predictions
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

        # Debugging: Print test outputs shape and sample outputs for the first batch
        if test_batch_idx == 0:
            print(f"Test Outputs Shape: {outputs.shape}")
            print(f"Test Outputs Sample: {outputs[:2]}")

test_epoch_loss = test_running_loss / len(test_loader)
test_accuracy = 100 * correct / total
print(f"Test Loss: {test_epoch_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")


Using device: cpu
HealNet model instantiated successfully.


  0%|          | 0/110 [00:00<?, ?it/s]


Epoch 1, Batch 1
Adjusted Tabular Data Shape: torch.Size([4, 1, 561])
Adjusted Image Data Shape: torch.Size([4, 1, 100, 1024])
Targets Shape: torch.Size([4])





AssertionError: input data for modality 2 must hav the same number of axis as the input axis parameter

In [16]:
data = MMDataset([tab_tensor, img_tensor], target)
train, test, val = torch.utils.data.random_split(data, [0.7, 0.15, 0.15]) # create 70-15-15 train-val-test split

loader_args = {
    "shuffle": True, 
    "num_workers": 8, 
    "pin_memory": True, 
    "multiprocessing_context": "fork", 
    "persistent_workers": True, 
}

train_loader = DataLoader(train, **loader_args)
val_loader = DataLoader(val, **loader_args)
test_loader = DataLoader(test, **loader_args)
# example use

In [17]:
# example use
[tab_sample, img_sample], target = data[0]

# emulate batch dimension
tab_sample = einops.repeat(tab_sample, 'c d -> b c d', b=1)
img_sample = einops.repeat(img_sample, 'c h w -> b c (h w)', b=1)

In [18]:
img_sample.shape

torch.Size([1, 3, 262144])

In [19]:
model = HealNet(
            modalities=2, 
            input_channels=[tab_c, img_c], 
            input_axes=[1, 1], # channel axes (0-indexed)
            num_classes = n_classes,  
        )

In [21]:
# forward pass
model([tab_sample, img_sample])

tensor([[0.9904, 0.2519]], grad_fn=<AddmmBackward0>)