Papermill params:

In [1]:
workers = None
lightning = False       # Use PyTorch Lightning
tissue = "tongue"       # "tissue_general" obs filter
is_primary_data = True  # Additional obs filter
census_version = "2024-07-01"
batch_size = 128
shuffle = True
learning_rate = 1e-5
n_epochs = 20
train_split = 0.8
split_seed = 1

In [2]:
# Parameters
tissue = "tongue"
workers = 4


In [3]:
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"}),
)
obs_value_filter = f"tissue_general == '{tissue}'"
if is_primary_data:
    obs_value_filter += " and is_primary_data == True"

iter_cls = soma_ml.ExperimentAxisQueryIterDataPipe if workers is None and not lightning else soma_ml.ExperimentAxisQueryIterableDataset

with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

    experiment_dataset = iter_cls(
        query,
        X_name="raw",
        obs_column_names=["cell_type"],
        batch_size=batch_size,
        shuffle=shuffle,
    )

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [4]:
# PyTorch
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()  # noqa: UP008
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
    

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation
        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)
        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

In [5]:
# Lightning
import pytorch_lightning as pl

class LogisticRegressionLightning(pl.LightningModule):
    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):
        super(LogisticRegressionLightning, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

    def training_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        # X_batch = X_batch.float()
        X_batch = torch.from_numpy(X_batch).float().to(self.device)

        # Perform prediction
        outputs = self(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute loss
        y_batch = torch.from_numpy(
            self.cell_type_encoder.transform(y_batch["cell_type"])
        ).to(self.device)
        loss = self.loss_fn(outputs, y_batch.long())

        # Compute accuracy
        train_correct = (predictions == y_batch).sum().item()
        train_accuracy = train_correct / len(predictions)

        # Log loss and accuracy
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [6]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
input_dim = experiment_dataset.shape[1]
output_dim = len(cell_type_encoder.classes_)

dl_kwargs = {} if lightning or workers is None else dict(num_workers=workers, persistent_workers=True)
train_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)

if lightning:
    model = LogisticRegressionLightning(input_dim, output_dim, cell_type_encoder=cell_type_encoder)
    trainer = pl.Trainer(max_epochs=n_epochs, devices=list(range(workers or 1)))
    torch.set_float32_matmul_precision("high")
else:
    model = LogisticRegression(input_dim, output_dim).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

switching torch multiprocessing start method from "fork" to "spawn"


In [7]:
%%time
if lightning:
    trainer.fit(model, train_dataloaders=train_dataloader)
else:
    for epoch in range(n_epochs):
        if workers is not None:
            experiment_dataset.set_epoch(epoch)
        train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



Epoch 1: Train Loss: 0.0162257 Accuracy 0.2702


Epoch 2: Train Loss: 0.0148430 Accuracy 0.4639


Epoch 3: Train Loss: 0.0143223 Accuracy 0.5196


Epoch 4: Train Loss: 0.0140449 Accuracy 0.5608


Epoch 5: Train Loss: 0.0139076 Accuracy 0.5948


Epoch 6: Train Loss: 0.0138214 Accuracy 0.6152


Epoch 7: Train Loss: 0.0137290 Accuracy 0.6378


Epoch 8: Train Loss: 0.0136576 Accuracy 0.6617


Epoch 9: Train Loss: 0.0136091 Accuracy 0.7214


Epoch 10: Train Loss: 0.0135205 Accuracy 0.8495


Epoch 11: Train Loss: 0.0134696 Accuracy 0.8912


Epoch 12: Train Loss: 0.0134134 Accuracy 0.9058


Epoch 13: Train Loss: 0.0133501 Accuracy 0.9121


Epoch 14: Train Loss: 0.0133021 Accuracy 0.9198


Epoch 15: Train Loss: 0.0132502 Accuracy 0.9282


Epoch 16: Train Loss: 0.0132061 Accuracy 0.9360


Epoch 17: Train Loss: 0.0131800 Accuracy 0.9388


Epoch 18: Train Loss: 0.0131604 Accuracy 0.9430


Epoch 19: Train Loss: 0.0131407 Accuracy 0.9473


Epoch 20: Train Loss: 0.0131216 Accuracy 0.9505
CPU times: user 1min 22s, sys: 1min 7s, total: 2min 29s
Wall time: 6min 58s


In [8]:
# TODO: split train/test
test_dataloader = soma_ml.experiment_dataloader(experiment_dataset, **dl_kwargs)
X_batch, y_batch = next(iter(test_dataloader))
X_batch = torch.from_numpy(X_batch)
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [9]:
import pandas as pd

model.eval()
model.to(device)
outputs = model(X_batch.to(device))
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())

cmp_df = pd.DataFrame({
    "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
    "predicted cell type": predicted_cell_types,
})
right, wrong = (cmp_df['actual cell type'] == cmp_df['predicted cell type']).value_counts().values
print('Accuracy: %.1f%% (%d correct, %d incorrect)' % (100 * right / len(cmp_df), right, wrong))
pd.crosstab(cmp_df['actual cell type'], cmp_df['predicted cell type']).replace(0, '')

Accuracy: 93.0% (119 correct, 9 incorrect)


predicted cell type,basal cell,epithelial cell,keratinocyte,leukocyte,vein endothelial cell
actual cell type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
basal cell,48.0,2.0,1.0,,
endothelial cell of artery,,,,,2.0
endothelial cell of lymphatic vessel,,,,,1.0
epithelial cell,,25.0,,,
keratinocyte,2.0,,27.0,,
leukocyte,,1.0,,17.0,
vein endothelial cell,,,,,2.0


	Command being timed: "papermill -p tissue tongue -p workers 4 benchmark.ipynb tongue/torch4.ipynb"
	User time (seconds): 100.24
	System time (seconds): 77.27
	Percent of CPU this job got: 38%
	Elapsed (wall clock) time (h:mm:ss or m:ss): 7:35.36
	Average shared text size (kbytes): 0
	Average unshared data size (kbytes): 0
	Average stack size (kbytes): 0
	Average total size (kbytes): 0
	Maximum resident set size (kbytes): 1298044
	Average resident set size (kbytes): 0
	Major (requiring I/O) page faults: 0
	Minor (reclaiming a frame) page faults: 18278497
	Voluntary context switches: 39723
	Involuntary context switches: 1378
	Swaps: 0
	File system inputs: 0
	File system outputs: 1376
	Socket messages sent: 0
	Socket messages received: 0
	Signals delivered: 0
	Page size (bytes): 4096
	Exit status: 0
