In [None]:
# Note: this requires the `dev` branch of Slideflow as of 1/23/23
import multiprocessing as mp
import slideflow as sf
print('Slideflow version:', sf.__version__)

from slideflow.simclr import SlideflowBuilder
from slideflow.simclr import run_simclr

In [None]:
# Open the existing Slideflow project. This notebook assumes tile
# extraction has already been run for it.
P = sf.Project('/mnt/data/projects/CYTOLOGY')

# Build the TFRecord dataset for 96 px tiles at 40x magnification ...
dataset_full = P.dataset(
    tile_px=96,
    tile_um='40x',
)
# ... and (optionally) cap it at a maximum of 5000 tiles per slide.
dataset = dataset_full.clip(5000)

# Fetch the per-slide ground-truth labels from the 'benign_malignant'
# annotation, together with the set of unique label values.
labels, unique_labels = dataset.labels('benign_malignant')

In [None]:
# Create (or load) a single fixed training/validation split.
split_options = {
    # Ground-truth labels for each slide.
    'labels': labels,
    # One fixed split rather than a cross-fold split.
    'val_strategy': 'fixed',
    # Fraction of the data held out for validation.
    'val_fraction': 0.2,
    # JSON file in which the split is stored (a file path, not a directory).
    # NOTE(review): presumably an existing file is reused on later runs --
    # confirm against the Slideflow documentation.
    'splits': '/mnt/data/projects/CYTOLOGY/simclr_splits.json',
}
# The first positional argument is the type of model.
train_dts, val_dts = dataset.train_val_split('categorical', **split_options)

# Report how many tiles ended up in each split.
for caption, subset in (
    ("Training dataset size: ", train_dts),
    ("Validation dataset size: ", val_dts),
):
    print(caption, subset.num_tiles)

In [None]:
# Balance both splits using the 'slide' strategy before handing them
# to the builder.
balanced_train = train_dts.balance(strategy='slide')
balanced_val = val_dts.balance(strategy='slide')

# The dataset builder is what SimCLR uses to create its input pipeline
# for training.
# NOTE(review): num_classes is hard-coded to 2; presumably this matches the
# number of unique 'benign_malignant' labels -- verify against unique_labels.
builder = SlideflowBuilder(
    train_dts=balanced_train,
    val_dts=balanced_val,
    labels=labels,
    num_classes=2,
)

In [None]:
# SimCLR flags, later passed to run_simclr() via its `flags` argument.
simclr_flags = {
    # --- Run mode ---
    'mode': 'train_then_eval',
    'train_mode': 'pretrain',

    # --- Optimization hyperparameters ---
    'train_batch_size': 256,
    'temperature': 0.1,
    'learning_rate': 0.075,
    'learning_rate_scaling': 'sqrt',
    'weight_decay': 1e-4,
    'train_epochs': 100,

    # --- Input ---
    # Same value as the tile_px used when the dataset was loaded.
    'image_size': 96,

    # --- Output ---
    'model_dir': '/mnt/data/tmp/simclr_cytology_full_k1',
    # NOTE(review): presumably the number of epochs between checkpoints -- confirm.
    'checkpoint_epochs': 10,

    # --- Hardware ---
    'use_tpu': False,
}

In [None]:
# Launch SimCLR training, using the dataset builder for the input pipeline
# and the flags dictionary assembled earlier in the notebook.
run_simclr(
    builder,
    flags=simclr_flags,
)