In [39]:
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# Report the versions of the main libraries used in this notebook so that
# a run can be matched to its environment.
sc.logging.print_header()
print(f"squidpy=={sq.__version__}")
print(f"tensorflow=={tf.__version__}")

# IPython extensions: `autoreload 2` re-imports edited modules before each
# cell executes; `lab_black` formats code cells with black.
%load_ext autoreload
%autoreload 2
%load_ext lab_black

scanpy==1.6.0 anndata==0.7.5 umap==0.4.6 numpy==1.18.5 scipy==1.5.4 pandas==1.1.4 scikit-learn==0.23.2 statsmodels==0.12.1 python-igraph==0.8.3 leidenalg==0.8.2
squidpy==0.0.0
tensorflow==2.3.0-dev20200610
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


In [2]:
# Load the Visium H&E example data through squidpy's dataset helpers.
# NOTE(review): presumably `adata` is the annotated expression object and
# `img` the matching tissue image container -- confirm against squidpy docs.
adata = sq.datasets.visium_hne_adata()
img = sq.datasets.visium_hne_image()

In [3]:
# Collect one pixel array per spot crop.
spots = img.generate_spot_crops(adata)
spot_list = []
for crop_item in spots:
    # The first element of each yielded item carries the image data.
    crop_array = crop_item[0].data.to_array().values
    # Drop the leading length-1 axis added by `to_array()`
    # (presumably the variable/layer axis -- TODO confirm).
    spot_list.append(crop_array.squeeze(axis=0))

In [56]:
# categories
cluster_labels = adata.obs["cluster"]
# Take both the number of classes and the label -> index mapping from the
# categorical's declared categories. Counting `unique()` values instead would
# under-count whenever a category is unused (e.g. after subsetting), and
# `tf.one_hot` silently emits an all-zero row for any index >= depth.
cluster_categories = cluster_labels.cat.categories
classes = len(cluster_categories)
cluster_map = {v: i for i, v in enumerate(cluster_categories.values)}
# Integer class index per observation, in the same order as `adata.obs`.
labels = np.array([cluster_map[c] for c in cluster_labels], dtype=np.uint8)
# One-hot targets of shape (n_obs, classes) as a float32 numpy array.
labels_ohe = tf.one_hot(labels, depth=classes, dtype=tf.float32).numpy()

In [59]:
# Stratified 80/20 split of the observation indices; the fixed seed keeps
# the partition reproducible across runs.
n_observations = cluster_labels.shape[0]
all_indices = np.arange(n_observations)
train_idx, test_idx = train_test_split(
    all_indices,
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=cluster_labels,
)

In [60]:
# Per-cluster observation counts in each partition, to check the stratification.
train_counts = adata[train_idx, :].obs.cluster.value_counts()
test_counts = adata[test_idx, :].obs.cluster.value_counts()
print(f"Train set : \n {train_counts} \n \n Test set: \n {test_counts}")

Train set : 
 Cortex_1                         227
Thalamus_1                       209
Cortex_2                         206
Cortex_3                         195
Fiber_tract                      181
Hippocampus                      178
Hypothalamus_1                   166
Thalamus_2                       154
Cortex_4                         131
Striatum                         122
Hypothalamus_2                   106
Cortex_5                         103
Lateral_ventricle                 84
Pyramidal_layer_dentate_gyrus     54
Pyramidal_layer                   34
Name: cluster, dtype: int64 
 
 Test set: 
 Cortex_1                         57
Thalamus_1                       52
Cortex_2                         51
Cortex_3                         49
Fiber_tract                      45
Hippocampus                      44
Hypothalamus_1                   42
Thalamus_2                       38
Cortex_4                         33
Striatum                         31
Hypothalamus_2             

In [61]:
def create_dataset(
    x: list,
    y: np.ndarray,
    idx: list,
    augment: bool = True,
    shuffle: bool = True,
    batch_size: int = 64,
):
    """Build a batched ``tf.data.Dataset`` of preprocessed images and targets.

    Parameters
    ----------
    x
        Sequence of image arrays, indexed by the integers in ``idx``.
        Assumes all images share one shape so they can be stacked into a
        single tensor -- TODO confirm for the crops passed in.
    y
        2-D array of per-sample targets; rows are selected with ``idx``.
    idx
        Integer indices (list or 1-D array) of the samples to include.
    augment
        If ``True`` (default), the random flip / rotation / contrast layers
        run in training mode. Pass ``False`` for validation or test data so
        evaluation is deterministic. Resizing and rescaling are always applied.
    shuffle
        If ``True`` (default), shuffle samples with a 1000-element buffer,
        reshuffled on every iteration. Pass ``False`` for evaluation data.
    batch_size
        Number of samples per batch (default 64).

    Returns
    -------
    tf.data.Dataset
        Yields ``(images, targets)`` batches with images resized to 32x32 and
        multiplied by 1/255.

    Raises
    ------
    ValueError
        If ``idx`` is empty.
    """
    if len(idx) == 0:
        raise ValueError("`idx` must contain at least one sample index.")

    ds = tf.data.Dataset.from_tensor_slices(
        ([x[i] for i in idx], y[idx, :])
    )  # create dataset from lists
    if shuffle:
        ds = ds.shuffle(1000, reshuffle_each_iteration=True)  # shuffle
    ds = ds.batch(batch_size)  # create batches

    # Deterministic preprocessing (resize, rescale) followed by random
    # augmentation (flips, rotations, contrast changes).
    data_processing = tf.keras.Sequential(
        [
            preprocessing.Resizing(32, 32),
            preprocessing.Rescaling(1.0 / 255),
            preprocessing.RandomFlip(),
            preprocessing.RandomRotation(0.5),
            preprocessing.RandomContrast(0.5),
        ]
    )
    # Pass `training` explicitly: inside `Dataset.map` the Keras learning
    # phase is not set by `fit`, so without it the random layers may silently
    # run in inference mode (i.e. perform no augmentation at all).
    ds = ds.map(
        lambda images, targets: (data_processing(images, training=augment), targets)
    )

    return ds

In [62]:
# Build the training and held-out datasets from the same images and one-hot
# targets, differing only in which sample indices are selected.
train_ds, test_ds = (
    create_dataset(spot_list, labels_ohe, split_idx)
    for split_idx in (train_idx, test_idx)
)

In [63]:
# Create a model
input_shape = (32, 32, 3)

inputs = tf.keras.layers.Input(shape=input_shape)
# ResNet50 with randomly initialised weights and a `classes`-way
# classification head. With the default `include_top=True` the head ends in
# a softmax activation, so the model outputs class probabilities.
outputs = tf.keras.applications.ResNet50(
    weights=None, input_shape=input_shape, classes=classes
)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    # The network already applies softmax, so the loss must treat its output
    # as probabilities; `from_logits=True` would apply softmax a second time
    # and flatten the gradients.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
)

In [64]:
# Train for 10 epochs, validating on the held-out split after each epoch.
# Keep the returned History (per-epoch metrics) for later inspection instead
# of letting the cell echo its uninformative object repr.
history = model.fit(train_ds, validation_data=test_ds, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7ffc2b3dff70>