In [59]:
import scanpy as sc
import torch
from torch import nn
import numpy as np
from sklearn.model_selection import train_test_split
import lightning.pytorch as pl
from torch.optim import Adam
from torch.nn import Linear
from typing import List, Optional, Callable
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.distributions import ContinuousBernoulli
from lightning.pytorch.loggers import TensorBoardLogger

In [2]:
# Read the AnnData object stored under ../datasets/processed (path is relative
# to the notebook's working directory).
adata = sc.read_h5ad(filename="../datasets/processed/pbmc68k.h5ad")

In [4]:
class MLP(torch.nn.Sequential):
    """Feed-forward stack of ``Linear`` layers assembled as an ``nn.Sequential``.

    Every entry of ``hidden_channels`` except the last contributes
    ``Linear -> activation -> Dropout``. The last entry contributes a bare
    ``Linear`` (the output layer), optionally followed by ``final_activation``.

    Args:
        in_channels: Size of the input feature dimension.
        hidden_channels: Output size of each ``Linear`` layer, in order. Must
            hold at least one entry; the last entry is the output size.
        activation_layer: Zero-argument factory for the module placed after
            each hidden ``Linear``. ``None`` omits the activation.
        dropout: Probability handed to the ``Dropout`` after each hidden layer.
        final_activation: Zero-argument factory for a module appended after the
            output layer. ``None`` leaves the output linear.

    Raises:
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dropout: float = 0.0,
        final_activation: Optional[Callable[..., torch.nn.Module]] = None,
    ):
        # Without at least one size there is no output layer to build.
        if len(hidden_channels) == 0:
            raise ValueError("hidden_channels must contain at least one layer size")

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim))
            # The annotation allows None, so the factory is only called when set.
            if activation_layer is not None:
                layers.append(activation_layer())
            # Dropout is always appended (even for p=0) so that the positional
            # indices of the Linear layers inside the Sequential stay stable.
            layers.append(torch.nn.Dropout(dropout))
            in_dim = hidden_dim

        # Output layer: no activation/dropout unless final_activation is given.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1]))

        if final_activation is not None:
            layers.append(final_activation())

        super().__init__(*layers)


In [32]:
class MLPAutoEncoder(pl.LightningModule):
    """MLP autoencoder trained to reconstruct clean inputs from noisy ones.

    The encoder and the decoder are both ``MLP`` stacks that end in a
    ``Sigmoid``, so the latent code and the reconstruction are squashed
    through a sigmoid. The decoder's last layer maps back to ``input_dim``.

    Args:
        input_dim: Number of input (and reconstructed) features.
        units_encoder: Layer sizes of the encoder; the last entry is the
            latent dimension.
        units_decoder: Hidden layer sizes of the decoder. ``input_dim`` is
            appended automatically as the output size. Any sequence (list or
            tuple) is accepted.
        learning_rate: Learning rate handed to Adam.
    """

    def __init__(
        self,
        input_dim: int,
        units_encoder: List[int],
        units_decoder: List[int],
        learning_rate: float = 0.001,
    ):
        super().__init__()
        self.encoder = MLP(
            in_channels=input_dim,
            hidden_channels=units_encoder,
            final_activation=torch.nn.Sigmoid,
        )
        self.decoder = MLP(
            in_channels=units_encoder[-1],
            # list(...) so that a tuple of sizes does not break the concatenation.
            hidden_channels=list(units_decoder) + [input_dim],
            final_activation=torch.nn.Sigmoid,
        )
        self.learning_rate = learning_rate

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    def _reconstruction_loss(self, batch):
        """MSE between the reconstruction of the noisy input and the clean target.

        ``batch`` is unpacked as a ``(noisy_data, clean_data)`` pair.
        """
        noisy_data, clean_data = batch
        output = self(noisy_data)
        return torch.nn.functional.mse_loss(output, clean_data)

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction MSE as ``train_loss``."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the reconstruction MSE as ``val_loss``."""
        loss = self._reconstruction_loss(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


In [45]:
class CellDreamerMLP(pl.LightningModule):
    """Four-layer SELU perceptron trained with an MSE denoising objective."""

    def __init__(
        self,
        in_dim,
        out_dim=None,
        w=64,
        learning_rate: float = 0.001,
    ):
        """
        Build the network (per the author's note, modelled on the
        SimpleMLPTimeStep architecture).

        Args:
            in_dim (int): Input dimension.
            out_dim (int, optional): Output dimension; ``in_dim`` is used when None.
            w (int, optional): Width of the three hidden layers. Defaults to 64.
            learning_rate (float, optional): Learning rate for Adam. Defaults to 0.001.
        """
        super().__init__()
        # Resolved before save_hyperparameters() so the stored value is the
        # effective output size rather than None.
        out_dim = in_dim if out_dim is None else out_dim

        # Three Linear+SELU stages followed by a plain Linear read-out.
        widths = [in_dim, w, w, w]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(Linear(fan_in, fan_out))
            stages.append(nn.SELU())
        stages.append(Linear(w, out_dim))
        self.net = nn.Sequential(*stages)

        self.save_hyperparameters()
        self.learning_rate = learning_rate

    def forward(self, x):
        """
        Run the input through the sequential network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Network output.
        """
        return self.net(x)

    def _denoising_mse(self, batch):
        """MSE between the prediction for the noisy half of ``batch`` and its clean half."""
        corrupted, target = batch
        prediction = self(corrupted)
        return torch.nn.functional.mse_loss(prediction, target)

    def training_step(self, batch, batch_idx):
        """Compute the denoising MSE, log it as ``train_loss`` and return it."""
        step_loss = self._denoising_mse(batch)
        self.log('train_loss', step_loss)
        return step_loss

    def validation_step(self, batch, batch_idx):
        """Compute the denoising MSE, log it as ``val_loss`` and return it."""
        step_loss = self._denoising_mse(batch)
        self.log('val_loss', step_loss)
        return step_loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return Adam(self.parameters(), lr=self.learning_rate)

In [98]:
class ResnetBlock(nn.Module):
    """Residual block made of two ``LayerNorm -> SiLU -> Linear`` stages.

    The second stage additionally applies ``Dropout`` before its ``Linear``.
    When the input and output widths differ, the skip path is projected with
    a ``Linear`` layer; otherwise the input is added unchanged.

    Args:
        in_dim: Width of the incoming features.
        out_dim: Width of the outgoing features; ``in_dim`` is used when None.
        dropout_prob: Dropout probability used inside the second stage.
    """

    def __init__(self, in_dim, out_dim=None, dropout_prob=0.0):
        super().__init__()
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim

        # Stage 1 changes the width from in_dim to out_dim.
        first_stage = [
            nn.LayerNorm(in_dim),
            nn.SiLU(),
            nn.Linear(in_dim, out_dim),
        ]
        self.net1 = nn.Sequential(*first_stage)

        # Stage 2 keeps the width and adds dropout ahead of the Linear.
        second_stage = [
            nn.LayerNorm(out_dim),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(out_dim, out_dim),
        ]
        self.net2 = nn.Sequential(*second_stage)

        # A projection is only needed when the skip path must change width.
        self.skip_proj = nn.Linear(in_dim, out_dim) if in_dim != out_dim else None

    def forward(self, x):
        """Return the (optionally projected) input plus the two-stage branch output."""
        branch = self.net2(self.net1(x))
        shortcut = x if self.skip_proj is None else self.skip_proj(x)
        return shortcut + branch

class ResNet(pl.LightningModule):
    """Residual MLP that maps a noisy input back to its clean counterpart.

    Architecture: ``Linear(in_dim, hidden_dim)`` -> ``n_blocks`` x
    ``ResnetBlock(hidden_dim)`` -> ``LayerNorm -> SiLU -> Linear(hidden_dim,
    in_dim)`` -> ``Sigmoid``.

    The optimised and logged quantity is the mean squared error between the
    sigmoid output and the target; no likelihood term is evaluated.

    NOTE(review): because the output passes through a sigmoid, targets
    outside the unit interval cannot be matched exactly - confirm that the
    data fed to this model is scaled accordingly.

    Args:
        in_dim: Number of input (and output) features.
        hidden_dim: Width of the residual trunk.
        n_blocks: Number of ``ResnetBlock`` modules in the trunk.
        dropout_prob: Dropout probability forwarded to every block.
        learning_rate: Learning rate handed to Adam. Defaults to 0.001.
    """

    def __init__(self, in_dim, hidden_dim, n_blocks, dropout_prob=0.0, learning_rate: float = 0.001):
        super().__init__()
        self.net_in = nn.Linear(in_dim, hidden_dim)

        self.blocks = nn.ModuleList([
            ResnetBlock(hidden_dim, dropout_prob=dropout_prob) for _ in range(n_blocks)
        ])

        self.net_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, in_dim),
        )
        self.sig = nn.Sigmoid()
        self.learning_rate = learning_rate

    def forward(self, x):
        """Project in, run the residual blocks, project out and apply the sigmoid."""
        h = self.net_in(x)
        for block in self.blocks:
            h = block(h)
        h = self.net_out(h)
        h = self.sig(h)
        return h

    def _mse(self, batch):
        """MSE between the prediction for ``x`` and the target ``y`` of an ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        return nn.functional.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Log the batch MSE as ``train_loss`` and return it for backpropagation."""
        mse = self._mse(batch)
        self.log('train_loss', mse)
        return mse

    def validation_step(self, batch, batch_idx):
        """Log the batch MSE as ``val_loss``."""
        mse = self._mse(batch)
        self.log('val_loss', mse)

    def test_step(self, batch, batch_idx):
        """Log the batch MSE as ``test_loss``."""
        mse = self._mse(batch)
        self.log('test_loss', mse)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [90]:
def add_gaussian_noise(data, noise_factor=0.5):
    """
    Return ``data`` plus zero-mean Gaussian noise scaled by ``noise_factor``.

    The noise is sampled with ``torch.randn_like``, so it has the same shape,
    dtype and device as ``data``. The input tensor is not modified.

    Args:
    data (torch.Tensor): Tensor to perturb.
    noise_factor (float): Multiplier applied to the standard-normal sample.

    Returns:
    torch.Tensor: A new tensor holding the perturbed values.
    """
    scaled_noise = torch.randn_like(data).mul(noise_factor)
    return data + scaled_noise

In [91]:
class NoisyDataset(Dataset):
    """Dataset that pairs every sample with a freshly noised copy of itself.

    Noise is drawn anew on each ``__getitem__`` call, so the same index
    yields a different noisy input every time it is accessed.

    Args:
        data: Indexable container of samples (must support ``len`` and ``[]``).
        noise_factor: Scale of the Gaussian noise passed to ``add_gaussian_noise``.
    """

    def __init__(self, data, noise_factor=0.5):
        self.noise_factor = noise_factor
        self.data = data

    def __len__(self):
        """Number of samples in the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(noisy_sample, clean_sample)`` for position ``idx``."""
        clean_sample = self.data[idx]
        noisy_sample = add_gaussian_noise(clean_sample, self.noise_factor)
        return noisy_sample, clean_sample

# Densify adata.X and convert it to a float32 tensor in a single step.
# Sparse matrices expose `.toarray()` (returns a plain ndarray); anything else
# is treated as already array-like. Building the tensor once avoids the extra
# copy and the copy-construct UserWarning of wrapping a tensor in torch.tensor.
X_dense = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
X = torch.tensor(X_dense, dtype=torch.float32)

# Split the dataset into train and validation sets (20% held out for
# validation; fixed random_state so the split is reproducible).
X_train, X_val = train_test_split(X, test_size=0.2, random_state=42)

# Creating datasets: each item is a (noisy, clean) pair, with noise re-sampled
# on every access.
noise_factor = 0.5  # Adjust as needed
train_noisy_dataset = NoisyDataset(X_train, noise_factor=noise_factor)
val_noisy_dataset = NoisyDataset(X_val, noise_factor=noise_factor)

# Creating DataLoaders: only the training loader shuffles.
train_loader = DataLoader(train_noisy_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_noisy_dataset, batch_size=32)

In [99]:
# --- Model -------------------------------------------------------------------
input_dim = X_train.shape[1]  # number of feature columns in the training matrix
units_encoder = [128, 64]
units_decoder = [64, 128]

# Alternative architectures (uncomment one to use it instead of the ResNet):
# model = MLPAutoEncoder(input_dim, units_encoder, units_decoder)
# model = CellDreamerMLP(input_dim)
model = ResNet(input_dim, hidden_dim=64, n_blocks=2)

# --- Logging -----------------------------------------------------------------
# Directory that receives the TensorBoard event files; point this at your own
# location before running.
log_path = "/lustre/groups/ml01/workspace/till.richter/scvdm_trained_models/"
logger = TensorBoardLogger(save_dir=log_path, name="ResNet_Sigmoid")

# --- Training ----------------------------------------------------------------
# max_epochs controls the length of the run; adjust as needed.
trainer = pl.Trainer(logger=logger, max_epochs=50)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /lustre/groups/ml01/workspace/till.richter/scvdm_trained_models/ResNet_Sigmoid
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | net_in  | Linear     | 128 K 
1 | blocks  | ModuleList | 17.2 K
2 | net_out | Sequential | 130 K 
3 | sig     | Sigmoid    | 0     
---------------------------------------
275 K     Trainable params
0         Non-trainable params
275 K     Total params
1.101     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [72]:
# Scratch check: largest entry of a random 5x5 standard-normal matrix.
a = torch.randn(5, 5)
a.max()

tensor(2.5268)