Papermill params:

In [1]:
# Notebook parameters (see the "Papermill params" heading for this cell).
# `workers`: when not None it is passed as `num_workers` to the dataloader, and it is also used as
# the device count for DataParallel / Lightning. `None` (with `lightning = False`) selects
# `ExperimentAxisQueryIterDataPipe` instead of `ExperimentAxisQueryIterableDataset`.
workers = None
lightning = False       # Use PyTorch Lightning
tissue = "tongue"       # "tissue_general" obs filter
is_primary_data = True  # Additional obs filter
cpu = False             # Force CPU mode
# Census release; interpolated into the S3 URL of the SOMA experiment.
census_version = "2024-07-01"
# Passed straight through to the dataset constructor.
batch_size = 128
shuffle = True
# Learning rate handed to the Adam optimizer.
learning_rate = 1e-5
# Number of passes over the dataset (`max_epochs` when Lightning is used).
n_epochs = 20
is_papermill = False    # Papermill invocations should set this to True; `tdbsml benchmark` does this automatically

In [2]:
# Overrides the `tissue` default ("tongue") assigned in the parameters cell.
# NOTE(review): presumably a Papermill-injected parameter override — confirm.
tissue = "ovary"

In [3]:
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

from tiledbsoma_ml import ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset, experiment_dataloader

# Location of the Homo sapiens SOMA experiment for the selected Census release.
CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"}),
)

# Assemble the obs filter: always restrict by tissue, optionally by primary data.
filter_clauses = [f"tissue_general == '{tissue}'"]
if is_primary_data:
    filter_clauses.append("is_primary_data == True")
obs_value_filter = " and ".join(filter_clauses)

# The DataPipe flavour is only used for the plain single-process, non-Lightning path.
if workers is None and not lightning:
    iter_cls = ExperimentAxisQueryIterDataPipe
else:
    iter_cls = ExperimentAxisQueryIterableDataset

obs_query = soma.AxisQuery(value_filter=obs_value_filter)
with experiment.axis_query(measurement_name="RNA", obs_query=obs_query) as query:
    # Labels for the selected cells; the encoder maps each distinct cell type to an integer class.
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder()
    cell_type_encoder.fit(obs_df["cell_type"].unique())

    experiment_dataset = iter_cls(
        query,
        X_name="raw",
        obs_column_names=["cell_type"],
        batch_size=batch_size,
        shuffle=shuffle,
    )

print(f'{len(obs_df)} cells, {len(experiment_dataset)} batches')

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



53751 cells, 420 batches


In [4]:
%%time
# Fetch (soma_joinid, tissue_general) for every primary-data cell in the experiment,
# renaming the tissue column to `tissue`.
with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter="is_primary_data == True")
) as query:
    ts = (
        query
        .obs(column_names=["soma_joinid", "tissue_general"])
        .concat()
        .to_pandas()
        # `is_primary_data` is only used in the filter and was not requested above.
        # NOTE(review): whether the filter column is returned anyway appears to depend on the
        # tiledbsoma version — `errors='ignore'` keeps this cell working either way.
        .drop(columns='is_primary_data', errors='ignore')
        .rename(columns={'tissue_general': 'tissue'})
    )
ts

CPU times: user 2.97 s, sys: 1.68 s, total: 4.64 s
Wall time: 1.22 s


Unnamed: 0,soma_joinid,tissue
0,4308,blood
1,4309,blood
2,4310,blood
3,4311,blood
4,4312,blood
...,...,...
44265927,74322505,brain
44265928,74322506,brain
44265929,74322507,brain
44265930,74322508,brain


In [5]:
# Tissue label of each cell, in the row order of `ts`.
t = ts["tissue"]
t

0           blood
1           blood
2           blood
3           blood
4           blood
            ...  
44265927    brain
44265928    brain
44265929    brain
44265930    brain
44265931    brain
Name: tissue, Length: 44265932, dtype: category
Categories (55, object): ['adipose tissue', 'adrenal gland', 'axilla', 'bladder organ', ..., 'ureter', 'uterus', 'vasculature', 'yolk sac']

In [6]:
# Run-length encode the tissue column: a new run ID starts at every row whose tissue
# differs from the previous row's tissue.
run_ids = t.ne(t.shift()).cumsum().rename(None)
runs = (
    t
    .groupby(run_ids)
    .agg(['first', 'size'])
    .rename(columns={'first': t.name, 'size': 'len'})
)
runs

Unnamed: 0,tissue,len
1,blood,1324
2,bone marrow,8357
3,blood,7145
4,liver,3240
5,respiratory system,2398
...,...,...
1239870,lung,6
1239871,respiratory system,1
1239872,nose,1
1239873,lung,10


In [7]:
# How many runs there are of each length.
runs["len"].value_counts()

len
1         759001
2         177593
3          81567
4          51947
5          37503
           ...  
125117         1
77536          1
195739         1
167598         1
2584           1
Name: count, Length: 1119, dtype: int64

In [8]:
# Per-tissue histogram of run lengths: one row per (tissue, len) with `num` = number of such runs.
trh = (
    runs
    # `tissue` is categorical; state `observed` explicitly so the result does not depend on
    # pandas' deprecated implicit default (and the FutureWarning goes away).
    .groupby('tissue', observed=False)['len']
    .apply(lambda s: (
        s
        .value_counts(sort=False)
        .rename_axis('len')
    ))
    .rename('num')
    .reset_index()
)
trh

  .groupby('tissue')['len']


Unnamed: 0,tissue,len,num
0,adipose tissue,1372,1
1,adipose tissue,166149,1
2,adipose tissue,12,1
3,adipose tissue,448,1
4,adipose tissue,14,2
...,...,...,...
1659,yolk sac,718,1
1660,yolk sac,1524,1
1661,yolk sac,2499,1
1662,yolk sac,2772,1


In [9]:
# Largest number of runs sharing one (tissue, run length) pair.
trh["num"].max()

209328

In [10]:
# Reach into the dataset object to look at the obs join IDs it holds.
# NOTE(review): `_exp_iter` and `_obs_joinids` are underscore-prefixed (private) attributes,
# so this cell may break with a different tiledbsoma_ml release.
ei = experiment_dataset._exp_iter
ois = ei._obs_joinids
ois.shape

(53751,)

In [11]:
# Display the obs join IDs held by the dataset's iterator.
ois

array([  885329,   885330,   885331, ..., 19163729, 19163730, 19163731])

Are the `obs_joinids` one contiguous, auto-incrementing sequence (each ID exactly one more than the previous)?

In [12]:
# A gap-free ascending sequence spans exactly as many values as it has elements.
len(ois) == ois[-1] - ois[0] + 1

False

In [13]:
# Materialise every partition yielded by the iterator's (private) partitioning method
# and count how many there are.
opi = [*ei._create_obs_joinids_partition()]
len(opi)

1

Are the partition IDs just a shuffle of the `obs_joinids`?

In [14]:
# Sorting the first partition should give back `ois` element for element.
(ois == sorted(opi[0])).all()

True

Indices where the auto-increment sequence is broken (i.e. where a new run of consecutive IDs starts):

In [15]:
import numpy as np
import pandas as pd

# First partition of obs join IDs, in the order the partitioning step produced it.
p = opi[0]
# Wrap it in a Series so pandas' shift/diff helpers can be used below.
s = pd.Series(p)
s

0          889169
1          889170
2          889171
3          889172
4          889173
           ...   
53746    18905610
53747    18905611
53748    18905612
53749    18905613
53750    18905614
Length: 53751, dtype: int64

In [18]:
def find_runs(s):
    """Split a Series into maximal runs of values that each increase by exactly 1.

    Args:
        s: numeric ``pd.Series``; runs are detected in the Series' own order.

    Returns:
        ``pd.DataFrame`` with one row per run and the columns

        - ``idx0`` / ``idx1``: index labels of the run's first / last element,
        - ``id0`` / ``id1``: the values at those positions,
        - ``len``: ``idx1 + 1 - idx0``, which is the number of elements in the run
          when ``s`` has the default 0..n-1 integer index.
    """
    # A run starts wherever the step from the previous value is not exactly +1.
    # The first element always starts a run: its diff is NaN, and NaN != 1 is True.
    breaks = s.diff() != 1
    # A run ends right before the next one starts, and the last element always ends a run.
    # `fill_value=True` keeps the mask boolean; `shift(-1).fillna(True)` would go through
    # object dtype and trigger pandas' downcasting FutureWarning.
    ends = breaks.shift(-1, fill_value=True)

    run_starts = s.index[breaks.to_numpy()]
    run_ends = s.index[ends.to_numpy()]

    df = pd.DataFrame({
        'idx0': run_starts,
        'idx1': run_ends,
        'id0': s.loc[run_starts].to_numpy(),
        'id1': s.loc[run_ends].to_numpy(),
    })
    df['len'] = df['idx1'] + 1 - df['idx0']
    return df

In [19]:
# Runs of consecutive join IDs within the first partition.
rs = find_runs(s)
rs

  run_ends = s.index[breaks.shift(-1).fillna(True)]


Unnamed: 0,idx0,idx1,id0,id1,len
0,0,63,889169,889232,64
1,64,127,2559714,2559777,64
2,128,191,2544674,2544737,64
3,192,255,892817,892880,64
4,256,319,2550114,2550177,64
...,...,...,...,...,...
835,53431,53494,904593,904656,64
836,53495,53558,2563874,2563937,64
837,53559,53622,2567010,2567073,64
838,53623,53686,2560418,2560481,64


In [22]:
# Distribution of run lengths, ordered by length.
rs["len"].value_counts().sort_index()

len
4        1
16       1
26       1
38       1
48       1
60       1
63       9
64     822
128      3
Name: count, dtype: int64

In [71]:
# Within every run, the positional span must equal the ID span.
(rs["idx1"] - rs["idx0"]).eq(rs["id1"] - rs["id0"]).all()

True

In [62]:
# First ID of each run: keep the entries that are not "previous value + 1".
si = s.loc[s != s.shift() + 1].rename_axis('idx').rename('id')
si

idx
0        18777441
64       18779041
128      18772065
192      18780961
256      18766241
           ...   
19375    18783393
19439    18770849
19503    18765089
19567    18769377
19631    18773537
Name: id, Length: 308, dtype: int64

In [66]:
# Last ID of each run except the final one: the entry just before every later run start.
s.loc[si.index[1:] - 1]

idx
63       18777504
127      18779104
191      18772128
255      18781024
319      18766304
           ...   
19374    18775456
19438    18783456
19502    18770912
19566    18765152
19630    18769440
Length: 307, dtype: int64

In [None]:
# PyTorch
class LogisticRegression(torch.nn.Module):
    """Single linear layer followed by an element-wise sigmoid.

    Args:
        input_dim: number of input features.
        output_dim: number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()  # noqa: UP008
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        return self.linear(x).sigmoid()
    

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``train_dataloader``.

    Each batch is an ``(X, obs)`` pair: ``X`` is converted with ``torch.from_numpy`` and
    ``obs['cell_type']`` is encoded with the notebook-level ``cell_type_encoder``.

    Args:
        model: module to train; called on the float feature tensor.
        train_dataloader: iterable yielding ``(X, obs)`` batches.
        loss_fn: callable taking ``(outputs, integer labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the inputs and labels are moved to.

    Returns:
        ``(train_loss, train_accuracy)``: the sum of the per-batch loss values divided by the
        number of samples seen, and the fraction of samples predicted correctly.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for features, obs in train_dataloader:
        optimizer.zero_grad()

        inputs = torch.from_numpy(features).float().to(device)

        # Forward pass
        outputs = model(inputs)

        # Predicted class = most probable entry after softmax
        predictions = torch.nn.functional.softmax(outputs, 1).argmax(dim=1)

        # Encode the string labels and update the accuracy counters
        labels = torch.from_numpy(cell_type_encoder.transform(obs['cell_type'])).to(device)
        n_correct += (predictions == labels).sum().item()
        n_seen += len(predictions)

        # Loss and parameter update
        loss = loss_fn(outputs, labels.long())
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Lightning
import pytorch_lightning as pl

class LogisticRegressionLightning(pl.LightningModule):
    """Lightning counterpart of the plain-PyTorch logistic regression.

    Args:
        input_dim: number of input features.
        output_dim: number of output units.
        cell_type_encoder: fitted encoder whose ``transform`` maps cell-type labels to integers.
        learning_rate: Adam learning rate; defaults to the notebook-level ``learning_rate``
            as it was when this class was defined.
    """

    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):
        super(LogisticRegressionLightning, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return ``sigmoid(linear(x))``."""
        return self.linear(x).sigmoid()

    def training_step(self, batch, batch_idx):
        """Compute the loss for one ``(X, obs)`` batch and log loss and accuracy."""
        features, obs = batch
        inputs = torch.from_numpy(features).float().to(self.device)

        # Forward pass and predicted class
        outputs = self(inputs)
        predictions = torch.nn.functional.softmax(outputs, 1).argmax(dim=1)

        # Encode the string labels, then compute the loss
        encoded = self.cell_type_encoder.transform(obs["cell_type"])
        labels = torch.from_numpy(encoded).to(self.device)
        loss = self.loss_fn(outputs, labels.long())

        # Batch accuracy
        n_correct = (predictions == labels).sum().item()
        batch_accuracy = n_correct / len(predictions)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", batch_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Device selection: CPU when forced or when CUDA is unavailable.
if cpu or not torch.cuda.is_available():
    device = "cpu"
else:
    device = "cuda"
device = torch.device(device)
input_dim = experiment_dataset.shape[1]
output_dim = len(cell_type_encoder.classes_)

# Only pass worker options when `workers` was set. `persistent_workers` requires at least one
# worker process, so it is switched off for `workers == 0`.
dl_kwargs = {} if workers is None else dict(num_workers=workers, persistent_workers=workers > 0)
# `experiment_dataloader` is imported by name from tiledbsoma_ml; this notebook defines no
# `soma_ml` alias, so the previous `soma_ml.experiment_dataloader` raised NameError.
train_dataloader = experiment_dataloader(experiment_dataset, **dl_kwargs)

if lightning:
    model = LogisticRegressionLightning(input_dim, output_dim, cell_type_encoder=cell_type_encoder)
    trainer = pl.Trainer(
        max_epochs=n_epochs,
        strategy="auto" if cpu else "ddp_notebook",
        accelerator="cpu" if cpu else "gpu",
        devices=1 if cpu else workers or 1,
        sync_batchnorm=True if not cpu and workers and workers > 1 else False,
        deterministic=True,
        max_time=None,
        enable_progress_bar=not is_papermill,
    )
    torch.set_float32_matmul_precision("high")
else:
    model = LogisticRegression(input_dim, output_dim).to(device)
    # `workers` defaults to None, and `None > 1` raises TypeError — check for None first.
    if workers is not None and workers > 1:
        gpus = torch.cuda.device_count()
        if gpus < workers:
            raise ValueError(f"Requested {workers=} but only found {gpus=}")
        # Replicate the model across the first `workers` GPUs.
        model = torch.nn.DataParallel(model, device_ids=list(range(workers)))
        print(f"Parallelizing model with {workers=}")
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
%%time
# Train: a hand-written epoch loop for plain PyTorch, or hand over to the Lightning trainer.
if not lightning:
    for epoch in range(n_epochs):
        # Only done when `workers` is set, i.e. when the IterableDataset flavour is in use.
        if workers is not None:
            experiment_dataset.set_epoch(epoch)
        train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")
else:
    trainer.fit(model, train_dataloaders=train_dataloader)

In [None]:
# TODO: split train/test
# Until then, "test" on the first batch of the same dataset the model was trained on.
test_dataloader = experiment_dataloader(experiment_dataset, **dl_kwargs)
features, obs = next(iter(test_dataloader))
X_batch = torch.from_numpy(features)
y_batch = torch.from_numpy(cell_type_encoder.transform(obs['cell_type']))

In [None]:
import pandas as pd

model.eval()
model.to(device)
# Inference only: no gradients are needed for the predictions.
with torch.no_grad():
    outputs = model(X_batch.to(device))
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu().numpy())

cmp_df = pd.DataFrame({
    "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
    "predicted cell type": predicted_cell_types,
})
# Count from the boolean mask directly. Unpacking `value_counts().values` is ordered by
# frequency (so the two numbers swap when most predictions are wrong) and fails to unpack
# when every prediction is right or every prediction is wrong.
is_correct = cmp_df['actual cell type'] == cmp_df['predicted cell type']
right = int(is_correct.sum())
wrong = len(cmp_df) - right
print('Accuracy: %.1f%% (%d correct, %d incorrect)' % (100 * right / len(cmp_df), right, wrong))
# Confusion table; zero cells are blanked for readability.
pd.crosstab(cmp_df['actual cell type'], cmp_df['predicted cell type']).replace(0, '')