Papermill params:

In [1]:
workers = None
lightning = False       # Use PyTorch Lightning
tissue = "tongue"       # "tissue_general" obs filter
is_primary_data = True  # Additional obs filter
cpu = False             # Force CPU mode
census_version = "2024-07-01"
batch_size = 128
shuffle = True
learning_rate = 1e-5
n_epochs = 20

In [2]:
# Parameters
tissue = "embryo"


In [3]:
from os import environ as env
is_papermill = bool(env.get("PAPERMILL"))

In [4]:
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"}),
)
obs_value_filter = f"tissue_general == '{tissue}'"
if is_primary_data:
    obs_value_filter += " and is_primary_data == True"

iter_cls = soma_ml.ExperimentAxisQueryIterDataPipe if workers is None and not lightning else soma_ml.ExperimentAxisQueryIterableDataset

with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

    experiment_dataset = iter_cls(
        query,
        X_name="raw",
        obs_column_names=["cell_type"],
        batch_size=batch_size,
        shuffle=shuffle,
    )

print(f'{len(obs_df)} cells, {len(experiment_dataset)} batches')

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



165937 cells, 1297 batches


In [5]:
# PyTorch
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()  # noqa: UP008
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
    

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation
        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)
        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

In [6]:
# Lightning
import pytorch_lightning as pl

class LogisticRegressionLightning(pl.LightningModule):
    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):
        super(LogisticRegressionLightning, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

    def training_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        # X_batch = X_batch.float()
        X_batch = torch.from_numpy(X_batch).float().to(self.device)

        # Perform prediction
        outputs = self(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute loss
        y_batch = torch.from_numpy(
            self.cell_type_encoder.transform(y_batch["cell_type"])
        ).to(self.device)
        loss = self.loss_fn(outputs, y_batch.long())

        # Compute accuracy
        train_correct = (predictions == y_batch).sum().item()
        train_accuracy = train_correct / len(predictions)

        # Log loss and accuracy
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [7]:
if cpu or not torch.cuda.is_available():
    device = "cpu"
else:
    device = "cuda"
device = torch.device(device)
input_dim = experiment_dataset.shape[1]
output_dim = len(cell_type_encoder.classes_)

dl_kwargs = {} if workers is None else dict(num_workers=workers, persistent_workers=True)
train_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)

if lightning:
    model = LogisticRegressionLightning(input_dim, output_dim, cell_type_encoder=cell_type_encoder)
    trainer = pl.Trainer(
        max_epochs=n_epochs,
        strategy="auto" if cpu else "ddp_notebook",
        accelerator="cpu" if cpu else "gpu",
        devices=1 if cpu else workers or 1,
        sync_batchnorm=True if not cpu and workers and workers > 1 else False,
        deterministic=True,
        max_time=None,
        enable_progress_bar=not is_papermill,
    )
    torch.set_float32_matmul_precision("high")
else:
    model = LogisticRegression(input_dim, output_dim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
%%time
if lightning:
    trainer.fit(model, train_dataloaders=train_dataloader)
else:
    for epoch in range(n_epochs):
        if workers is not None:
            experiment_dataset.set_epoch(epoch)
        train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")

Epoch 1: Train Loss: 0.0247633 Accuracy 0.4906


Epoch 2: Train Loss: 0.0236110 Accuracy 0.5638


Epoch 3: Train Loss: 0.0234014 Accuracy 0.6321


Epoch 4: Train Loss: 0.0233098 Accuracy 0.6391


Epoch 5: Train Loss: 0.0232599 Accuracy 0.6423


Epoch 6: Train Loss: 0.0232161 Accuracy 0.6381


Epoch 7: Train Loss: 0.0231697 Accuracy 0.6281


Epoch 8: Train Loss: 0.0231455 Accuracy 0.6364


Epoch 9: Train Loss: 0.0231295 Accuracy 0.6401


Epoch 10: Train Loss: 0.0231176 Accuracy 0.6444


Epoch 11: Train Loss: 0.0231084 Accuracy 0.6454


Epoch 12: Train Loss: 0.0231006 Accuracy 0.6486


Epoch 13: Train Loss: 0.0230945 Accuracy 0.6497


Epoch 14: Train Loss: 0.0230890 Accuracy 0.6522


Epoch 15: Train Loss: 0.0230842 Accuracy 0.6539


Epoch 16: Train Loss: 0.0230801 Accuracy 0.6555


Epoch 17: Train Loss: 0.0230763 Accuracy 0.6578


Epoch 18: Train Loss: 0.0230730 Accuracy 0.6578


Epoch 19: Train Loss: 0.0230698 Accuracy 0.6599


Epoch 20: Train Loss: 0.0230672 Accuracy 0.6606
CPU times: user 33min 50s, sys: 9min 17s, total: 43min 7s
Wall time: 8min 9s


In [9]:
# TODO: split train/test
test_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)
X_batch, y_batch = next(iter(test_dataloader))
X_batch = torch.from_numpy(X_batch)
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))

In [10]:
import pandas as pd

model.eval()
model.to(device)
outputs = model(X_batch.to(device))
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())

cmp_df = pd.DataFrame({
    "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
    "predicted cell type": predicted_cell_types,
})
right, wrong = (cmp_df['actual cell type'] == cmp_df['predicted cell type']).value_counts().values
print('Accuracy: %.1f%% (%d correct, %d incorrect)' % (100 * right / len(cmp_df), right, wrong))
pd.crosstab(cmp_df['actual cell type'], cmp_df['predicted cell type']).replace(0, '')

Accuracy: 68.0% (87 correct, 41 incorrect)


predicted cell type,GABAergic neuron,cortical interneuron,epithelial cell,inhibitory interneuron,macrophage,neural cell,neural progenitor cell,neuron,stromal cell
actual cell type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
GABAergic neuron,22.0,,,,,,,,
cortical interneuron,,2.0,,,,,,,
ependymal cell,,,,,,,1.0,,
epithelial cell,,,5.0,,,,,,
forebrain neuroblast,,2.0,,,,,,,
glial cell,1.0,,,,,,,,
glutamatergic neuron,17.0,,,,,,1.0,,
inhibitory interneuron,,,,2.0,,,,,
macrophage,,,,,7.0,,,,
microglial cell,,,,,,,2.0,,


	Command being timed: "papermill -p tissue embryo benchmark.ipynb embryo/torch.ipynb"
	User time (seconds): 2086.14
	System time (seconds): 578.45
	Percent of CPU this job got: 523%
	Elapsed (wall clock) time (h:mm:ss or m:ss): 8:29.17
	Average shared text size (kbytes): 0
	Average unshared data size (kbytes): 0
	Average stack size (kbytes): 0
	Average total size (kbytes): 0
	Maximum resident set size (kbytes): 9626452
	Average resident set size (kbytes): 0
	Major (requiring I/O) page faults: 0
	Minor (reclaiming a frame) page faults: 97734165
	Voluntary context switches: 2875770
	Involuntary context switches: 167912
	Swaps: 0
	File system inputs: 0
	File system outputs: 1384
	Socket messages sent: 0
	Socket messages received: 0
	Signals delivered: 0
	Page size (bytes): 4096
	Exit status: 0
