Papermill params:

In [1]:
workers = None
lightning = False       # Use PyTorch Lightning
tissue = "tongue"       # "tissue_general" obs filter
is_primary_data = True  # Additional obs filter
cpu = False             # Force CPU mode
census_version = "2024-07-01"
batch_size = 128
shuffle = True
learning_rate = 1e-5
n_epochs = 20

In [2]:
# Parameters
tissue = "tongue"
workers = 4


In [3]:
def is_running_in_papermill():
    import sys
    return 'pm_config' in sys.modules or 'pm_config' in globals()

is_papermill = is_running_in_papermill()
if is_papermill:
    print("Detected papermill run")

In [4]:
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"}),
)
obs_value_filter = f"tissue_general == '{tissue}'"
if is_primary_data:
    obs_value_filter += " and is_primary_data == True"

iter_cls = soma_ml.ExperimentAxisQueryIterDataPipe if workers is None and not lightning else soma_ml.ExperimentAxisQueryIterableDataset

with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

    experiment_dataset = iter_cls(
        query,
        X_name="raw",
        obs_column_names=["cell_type"],
        batch_size=batch_size,
        shuffle=shuffle,
    )

print(f'{len(obs_df)} cells, {len(experiment_dataset)} batches')

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



15020 cells, 118 batches


In [5]:
# PyTorch
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()  # noqa: UP008
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
    

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation
        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)
        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

In [6]:
# Lightning
import pytorch_lightning as pl

class LogisticRegressionLightning(pl.LightningModule):
    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):
        super(LogisticRegressionLightning, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

    def training_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        # X_batch = X_batch.float()
        X_batch = torch.from_numpy(X_batch).float().to(self.device)

        # Perform prediction
        outputs = self(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute loss
        y_batch = torch.from_numpy(
            self.cell_type_encoder.transform(y_batch["cell_type"])
        ).to(self.device)
        loss = self.loss_fn(outputs, y_batch.long())

        # Compute accuracy
        train_correct = (predictions == y_batch).sum().item()
        train_accuracy = train_correct / len(predictions)

        # Log loss and accuracy
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [7]:
if cpu or not torch.cuda.is_available():
    device = "cpu"
else:
    device = "cuda"
device = torch.device(device)
input_dim = experiment_dataset.shape[1]
output_dim = len(cell_type_encoder.classes_)

dl_kwargs = {} if workers is None else dict(num_workers=workers, persistent_workers=True)
train_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)

if lightning:
    model = LogisticRegressionLightning(input_dim, output_dim, cell_type_encoder=cell_type_encoder)
    trainer = pl.Trainer(
        max_epochs=n_epochs,
        strategy="auto" if cpu else "ddp_notebook",
        accelerator="cpu" if cpu else "gpu",
        devices=1 if cpu else workers or 1,
        sync_batchnorm=True if not cpu and workers and workers > 1 else False,
        deterministic=True,
        max_time=None,
        enable_progress_bar=not is_papermill,
    )
    torch.set_float32_matmul_precision("high")
else:
    model = LogisticRegression(input_dim, output_dim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

switching torch multiprocessing start method from "fork" to "spawn"


In [8]:
%%time
if lightning:
    trainer.fit(model, train_dataloaders=train_dataloader)
else:
    for epoch in range(n_epochs):
        if workers is not None:
            experiment_dataset.set_epoch(epoch)
        train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



Epoch 1: Train Loss: 0.0162859 Accuracy 0.3874


Epoch 2: Train Loss: 0.0145663 Accuracy 0.4810


Epoch 3: Train Loss: 0.0141905 Accuracy 0.4972


Epoch 4: Train Loss: 0.0140145 Accuracy 0.5533


Epoch 5: Train Loss: 0.0138949 Accuracy 0.6115


Epoch 6: Train Loss: 0.0137581 Accuracy 0.6698


Epoch 7: Train Loss: 0.0136083 Accuracy 0.7734


Epoch 8: Train Loss: 0.0135195 Accuracy 0.8613


Epoch 9: Train Loss: 0.0134566 Accuracy 0.8927


Epoch 10: Train Loss: 0.0134073 Accuracy 0.9134


Epoch 11: Train Loss: 0.0133681 Accuracy 0.9246


Epoch 12: Train Loss: 0.0133218 Accuracy 0.9379


Epoch 13: Train Loss: 0.0132595 Accuracy 0.9443


Epoch 14: Train Loss: 0.0132176 Accuracy 0.9499


Epoch 15: Train Loss: 0.0131854 Accuracy 0.9537


Epoch 16: Train Loss: 0.0131572 Accuracy 0.9562


Epoch 17: Train Loss: 0.0131375 Accuracy 0.9587


Epoch 18: Train Loss: 0.0131190 Accuracy 0.9609


Epoch 19: Train Loss: 0.0131028 Accuracy 0.9621


Epoch 20: Train Loss: 0.0130887 Accuracy 0.9642
CPU times: user 1min 12s, sys: 1min 9s, total: 2min 21s
Wall time: 5min 43s


In [9]:
# TODO: split train/test
test_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)
X_batch, y_batch = next(iter(test_dataloader))
X_batch = torch.from_numpy(X_batch)
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [10]:
import pandas as pd

model.eval()
model.to(device)
outputs = model(X_batch.to(device))
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())

cmp_df = pd.DataFrame({
    "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
    "predicted cell type": predicted_cell_types,
})
right, wrong = (cmp_df['actual cell type'] == cmp_df['predicted cell type']).value_counts().values
print('Accuracy: %.1f%% (%d correct, %d incorrect)' % (100 * right / len(cmp_df), right, wrong))
pd.crosstab(cmp_df['actual cell type'], cmp_df['predicted cell type']).replace(0, '')

Accuracy: 98.4% (126 correct, 2 incorrect)


predicted cell type,basal cell,epithelial cell,fibroblast,keratinocyte,leukocyte,pericyte,vein endothelial cell
actual cell type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
basal cell,50.0,2.0,,,,,
epithelial cell,,23.0,,,,,
fibroblast,,,1.0,,,,
keratinocyte,,,,32.0,,,
leukocyte,,,,,17.0,,
pericyte,,,,,,1.0,
vein endothelial cell,,,,,,,2.0


	Command being timed: "papermill -p tissue tongue -p workers 4 benchmark.ipynb tongue/torch4.ipynb"
	User time (seconds): 92.07
	System time (seconds): 77.89
	Percent of CPU this job got: 45%
	Elapsed (wall clock) time (h:mm:ss or m:ss): 6:16.38
	Average shared text size (kbytes): 0
	Average unshared data size (kbytes): 0
	Average stack size (kbytes): 0
	Average total size (kbytes): 0
	Maximum resident set size (kbytes): 1276952
	Average resident set size (kbytes): 0
	Major (requiring I/O) page faults: 0
	Minor (reclaiming a frame) page faults: 18261723
	Voluntary context switches: 38588
	Involuntary context switches: 1936
	Swaps: 0
	File system inputs: 0
	File system outputs: 1512
	Socket messages sent: 0
	Socket messages received: 0
	Signals delivered: 0
	Page size (bytes): 4096
	Exit status: 0
