In [None]:
import math
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import spateo as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csr_matrix, lil_matrix, vstack

## 1 Read the two input files (bin-level expression table and stain image)

In [5]:
# Load the bin-level expression table together with the matching stain image.
# `bin_file` is opened again by the data-preparation cell further down, which parses it line by line.
bin_file = 'data/Mouse_brain_Adult_GEM_bin1_sub.tsv'
image_file = 'data/Mouse_brain_Adult_sub.tif'
# NOTE(review): presumably returns an AnnData object; later cells read its `.X` and `.layers` — confirm.
adatasub = st.io.read_bgi_agg(bin_file, image_file)

NameError: name 'st' is not defined

In [None]:
# Parameters shared by the segmentation and data-preparation cells.
# bin_size: side length (in pixels) of one merged bin; assigned to `downrs` later on.
bin_size=3
# n_neighbor: number of bins per sample (the centre bin plus its neighbours).
n_neighbor=50
# NOTE(review): r_estimate is not referenced in the visible cells — confirm it is still needed.
r_estimate=15
# Patch descriptors are kept as strings because they are concatenated into output file names;
# startx / starty are converted with int() where they are used numerically.
startx='0'
starty='0'
patchsize='0'

# Make the main matrix available under the 'unspliced' layer name as well.
# NOTE(review): presumably required by a downstream spateo step — confirm.
adatasub.layers['unspliced'] = adatasub.X
# The patch covers the whole matrix: axis 0 is the x extent, axis 1 is the y extent.
patchsizex = adatasub.X.shape[0]
patchsizey = adatasub.X.shape[1]

## 2 Run Watershed

In [None]:
# Nucleus segmentation from the staining image: Otsu mask -> peak detection -> Watershed.

# Create the output folders up front so that saving does not fail on a fresh checkout.
os.makedirs('fig', exist_ok=True)
os.makedirs('data', exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 8), tight_layout=True)
st.cs.mask_nuclei_from_stain(adatasub, otsu_classes=4, otsu_index=1)
st.pl.imshow(adatasub, 'stain_mask', ax=ax)
st.cs.find_peaks_from_mask(adatasub, 'stain', 7)
st.cs.watershed(adatasub, 'stain', 5, out_layer='watershed_labels')

# Overlay the Watershed labels on the stain image.
fig, ax = plt.subplots(figsize=(8, 8), tight_layout=True)
st.pl.imshow(adatasub, 'stain', save_show_or_return='return', ax=ax)
st.pl.imshow(adatasub, 'watershed_labels', labels=True, alpha=0.5, ax=ax)
# Save through the explicit figure handle: `plt.savefig` targets the *current* figure, which may
# already have been shown/closed by the plotting call above and would then be written out blank.
fig.savefig('fig/watershed_labels' + startx + ':' + starty + ':' + '.png')

# Persist the segmented patch for the following steps.
adatasub.write('data/spots' + startx + ':' + starty + ':' + '.h5ad')

### 2.1 Real-valued output (probability map)

In [None]:
def watershed_with_probability(adatasub, stain_layer='stain', peak_distance=7, num_iterations=100, seed=None):
    """Estimate per-label membership frequencies by re-running Watershed with jittered seeds.

    The seed positions are perturbed with unit-variance Gaussian noise, Watershed is
    re-run `num_iterations` times, and for every label of the unperturbed segmentation
    the fraction of runs in which each pixel received that label is recorded.

    Args:
        adatasub: Object holding the image data; `mask_nuclei_from_stain` and
            `find_peaks_from_mask` are applied to it in place.
        stain_layer: Name of the stain layer handed to `find_peaks_from_mask`.
        peak_distance: Minimum peak distance handed to `find_peaks_from_mask`.
        num_iterations: Number of perturbed Watershed runs (must be positive).
        seed: Optional seed for the jitter. When None the global NumPy random state
            is used, as before.

    Returns:
        dict mapping each label of the unperturbed segmentation to a float array
        (same shape as the gradient image) with values in [0, 1].

    Raises:
        ValueError: If `num_iterations` is not positive.

    NOTE(review): `filters` and `segmentation` are not imported in the visible cells —
    presumably `from skimage import filters, segmentation`; confirm and add the import.
    NOTE(review): one dense float map is allocated per label, and every label is compared
    against the full image in every iteration; this can be very expensive for large patches.
    """
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    # Initial steps
    st.cs.mask_nuclei_from_stain(adatasub, otsu_classes=4, otsu_index=1)
    st.cs.find_peaks_from_mask(adatasub, stain_layer, peak_distance)

    # NOTE(review): the original code read obsm['peaks']; spateo presumably stores the peaks as
    # a label image in layers['<stain_layer>_markers'] instead — confirm against the spateo docs.
    if 'peaks' in adatasub.obsm:
        markers = adatasub.obsm['peaks']
    else:
        markers = adatasub.layers[stain_layer + '_markers']
    markers = np.asarray(markers)

    # NOTE(review): the gradient is taken on adatasub.X, not on the stain layer — confirm intent.
    image = adatasub.X
    if hasattr(image, 'toarray'):  # densify sparse matrices before filtering
        image = image.toarray()
    gradient = filters.sobel(np.asarray(image))

    # Normalise the markers to (coordinates, labels). Watershed needs an integer label image of
    # the same shape as the input, so noise must be applied to seed positions, never to the
    # label values themselves.
    if markers.shape == gradient.shape:
        seed_mask = markers > 0
        seed_coords = np.argwhere(seed_mask)
        seed_labels = markers[seed_mask]
    else:
        # Otherwise interpret the array as one coordinate row per seed.
        seed_coords = np.rint(markers.reshape(-1, gradient.ndim)).astype(int)
        seed_labels = np.arange(1, len(seed_coords) + 1)

    def _label_image(coords):
        """Rasterise seed coordinates into an integer marker image."""
        label_img = np.zeros(gradient.shape, dtype=np.int32)
        label_img[tuple(coords.T)] = seed_labels
        return label_img

    # Clip per axis so that non-square images keep their full extent along the longer axis.
    upper_bound = np.array(gradient.shape) - 1
    seed_coords = np.clip(seed_coords, 0, upper_bound)

    # Run Watershed once to get the set of possible labels
    initial_segmentation = segmentation.watershed(gradient, markers=_label_image(seed_coords))
    unique_labels = np.unique(initial_segmentation)

    # Create a frequency map for each label
    freq_maps = {label: np.zeros(gradient.shape, dtype=float) for label in unique_labels}

    rng = np.random if seed is None else np.random.default_rng(seed)
    for _ in range(num_iterations):
        jitter = rng.normal(scale=1, size=seed_coords.shape)
        perturbed_coords = np.clip(np.rint(seed_coords + jitter).astype(int), 0, upper_bound)

        segmented = segmentation.watershed(gradient, markers=_label_image(perturbed_coords))

        # Update the frequency maps
        for label in unique_labels:
            freq_maps[label] += (segmented == label)

    # Normalize each frequency map
    for label in unique_labels:
        freq_maps[label] /= num_iterations

    return freq_maps

# Compute the frequency maps for the loaded patch.
# The result is a dictionary with one frequency map per Watershed label.
freq_maps_dict = watershed_with_probability(adatasub)

## 3 Prepare data for Model1

In [None]:
# Group the pixel coordinates of every Watershed nucleus by its label.
# The label image is fetched once and scanned with np.nonzero instead of indexing
# adatasub.layers[...] several times per pixel inside a Python double loop.
# NOTE(review): assumes the 'watershed_labels' layer is a dense 2-D integer array — confirm.
watershed_labels_img = np.asarray(adatasub.layers['watershed_labels'])
nucleus_rows, nucleus_cols = np.nonzero(watershed_labels_img)  # label 0 is skipped

watershed2x = {}
watershed2y = {}
# np.nonzero walks the array in row-major order, so labels are registered in the same
# first-appearance order as a nested i/j loop would produce.
for nucleus, i, j in zip(watershed_labels_img[nucleus_rows, nucleus_cols], nucleus_rows, nucleus_cols):
    watershed2x.setdefault(nucleus, []).append(i)
    watershed2y.setdefault(nucleus, []).append(j)

# Centroid [mean row, mean column] and pixel count of every nucleus.
watershed2center = {}
sizes = []
for nucleus in watershed2x:
    watershed2center[nucleus] = [np.mean(watershed2x[nucleus]), np.mean(watershed2y[nucleus])]
    sizes.append(len(watershed2x[nucleus]))
# Determine the smallest x and y coordinate in the bin file; they are used afterwards to
# shift all coordinates so that the patch starts at the origin.
xall, yall = [], []
with open(bin_file) as fr:
    header = fr.readline()  # first line holds the column names
    for record in fr:
        gene, x, y, count = record.split()
        xall.append(int(x))
        yall.append(int(y))
xmin, ymin = np.min(xall), np.min(yall)

# Give every gene in the bin file a consecutive integer id, in order of first appearance.
# geneid maps name -> id, id2gene maps id -> name, genecnt ends up as the number of genes.
geneid, id2gene = {}, {}
genecnt = 0
with open(bin_file) as fr:
    header = fr.readline()  # first line holds the column names
    for record in fr:
        gene, x, y, count = record.split()
        if gene in geneid:
            continue
        geneid[gene] = genecnt
        id2gene[genecnt] = gene
        genecnt += 1

# Sum the counts of every gene inside each (downrs x downrs) bin of the patch.
# idx2exp maps the flat bin index -> {gene id -> summed count}.
idx2exp = {}
downrs = bin_size
patch_x_begin = int(startx)
patch_y_begin = int(starty)
patch_x_end = patch_x_begin + int(patchsizex)
patch_y_end = patch_y_begin + int(patchsizey)
bins_per_row = math.ceil(patchsizey / downrs)
with open(bin_file) as fr:
    header = fr.readline()  # first line holds the column names
    for record in fr:
        gene, x, y, count = record.split()
        # Shift the coordinates so that the smallest coordinate in the file becomes 0.
        x = int(x) - xmin
        y = int(y) - ymin
        if gene not in geneid:
            continue
        # Ignore records that fall outside the patch.
        if not (patch_x_begin <= x < patch_x_end and patch_y_begin <= y < patch_y_end):
            continue
        # Flat index: bin row * number of bin columns + bin column.
        idx = int(math.floor((x - patch_x_begin) / downrs) * bins_per_row
                  + math.floor((y - patch_y_begin) / downrs))
        gene_counts = idx2exp.setdefault(idx, {})
        gid = geneid[gene]
        gene_counts[gid] = gene_counts.get(gid, 0) + int(count)

# Assemble the (bins x genes) count matrix from idx2exp.
# The matrix stays int8 to keep the dense sample arrays built later small, but summed counts
# can exceed the int8 range; they are saturated at 127 instead of wrapping around or raising.
n_merged_bins = int(math.ceil(patchsizex / downrs) * math.ceil(patchsizey / downrs))
int8_max = np.iinfo(np.int8).max

entry_rows, entry_cols, entry_vals = [], [], []
for idx, gene_counts in idx2exp.items():
    for gid, summed_count in gene_counts.items():
        entry_rows.append(idx)
        entry_cols.append(gid)
        entry_vals.append(min(summed_count, int8_max))

# Build the CSR matrix in one step from (value, (row, col)) triplets; every (row, col) pair
# occurs once because it comes from nested dictionary keys.
all_exp_merged_bins = csr_matrix(
    (
        np.asarray(entry_vals, dtype=np.int8),
        (np.asarray(entry_rows, dtype=np.int64), np.asarray(entry_cols, dtype=np.int64)),
    ),
    shape=(n_merged_bins, genecnt),
    dtype=np.int8,
)

# Wrap the binned matrix in an AnnData object (bins as observations, gene ids as variables)
# and select the 2000 most variable genes.
all_exp_merged_bins_ad = ad.AnnData(
    all_exp_merged_bins,
    obs=pd.DataFrame(index=list(range(all_exp_merged_bins.shape[0]))),
    var=pd.DataFrame(index=list(range(all_exp_merged_bins.shape[1]))),
)
sc.pp.highly_variable_genes(all_exp_merged_bins_ad, n_top_genes=2000, flavor='seurat_v3', span=1.0)

# The var index holds the integer gene ids; convert the selected entries back to int.
highly_variable_mask = all_exp_merged_bins_ad.var.highly_variable
selected_index = [int(i) for i in all_exp_merged_bins_ad.var[highly_variable_mask].index]

# Write the names of the selected genes, one per line.
# The loop variable is not called `id`, so the builtin is no longer shadowed in the kernel.
with open('data/variable_genes' + startx + ':' + starty + ':' + patchsize + ':' + patchsize + '.txt', 'w') as fw:
    for gene_idx in selected_index:
        fw.write(id2gene[gene_idx] + '\n')

# Keep only the highly variable gene columns and convert to a dense array.
# Slicing the sparse matrix first avoids materialising the full (bins x all genes) dense array.
all_exp_merged_bins = all_exp_merged_bins[:, selected_index].toarray()

In [None]:
distance_ratio_threshold = 0.3

# Names expected from earlier cells: downrs, all_exp_merged_bins, adatasub, n_neighbor,
# patchsizey, watershed2center (math and numpy are imported in the first cell).

# A label-0 pixel becomes a test sample when it lies within 30 px (900 = 30 ** 2) of a nucleus
# centre or when its stain intensity exceeds 10; everything else is treated as background.
NUCLEUS_RADIUS_SQ = 900
STAIN_FOREGROUND_MIN = 10


def _collect_neighbor_bins(i, j, offsets, n_rows, n_cols, bins_per_row, step, bin_has_signal, max_neighbors):
    """Collect the first `max_neighbors` usable neighbour bins around pixel (i, j).

    Offsets are visited in the given order. A neighbour is usable when its pixel lies inside
    the image, its flat bin index is valid and the bin has a positive total count.

    Returns:
        List of (x, y, bin_index) tuples; it can be shorter than `max_neighbors`.
    """
    found = []
    for dx, dy in offsets:
        if len(found) == max_neighbors:
            break
        x = i + dx
        y = j + dy
        if not (0 <= x < n_rows and 0 <= y < n_cols):
            continue
        idx_nb = (x // step) * bins_per_row + (y // step)
        if 0 <= idx_nb < len(bin_has_signal) and bin_has_signal[idx_nb]:
            found.append((x, y, idx_nb))
    return found


# Output containers
x_test_tmp = []  # kept (empty) because the save cell still merges it into x_test
x_test = []
x_test_pos = []

x_optimize_train = []
x_optimize_pos = []
x_optimize_labels = []
y_optimize_train = []

# Neighbour offsets in pixel units: square rings of growing radius (1..10 bins); within a ring
# the left column, the right column, the top row and the bottom row are visited in that order.
offsets = []
for dis in range(1, 11):
    for dy in range(-dis, dis + 1):
        offsets.append([-dis * downrs, dy * downrs])
    for dy in range(-dis, dis + 1):
        offsets.append([dis * downrs, dy * downrs])
    for dx in range(-dis + 1, dis):
        offsets.append([dx * downrs, -dis * downrs])
    for dx in range(-dis + 1, dis):
        offsets.append([dx * downrs, dis * downrs])

# Loop-invariant lookups, computed once.
watershed_layer = adatasub.layers['watershed_labels']
stain_layer_img = adatasub.layers['stain']
n_pixel_rows = watershed_layer.shape[0]
n_pixel_cols = watershed_layer.shape[1]
bins_per_row = math.ceil(patchsizey / downrs)
# True for every bin whose total count over the selected genes is positive.
bin_has_signal = np.asarray(all_exp_merged_bins.sum(axis=1)).ravel() > 0
# One [row, col] centre per nucleus, used for the vectorised background test.
nucleus_centers = np.array(list(watershed2center.values()), dtype=float).reshape(-1, 2)

# Visit only the top-left pixel of every bin.
for i in range(0, n_pixel_rows, downrs):
    for j in range(0, n_pixel_cols, downrs):
        idx = (i // downrs) * bins_per_row + (j // downrs)
        label = watershed_layer[i, j]

        # Training samples: pixels that Watershed assigned to a nucleus.
        if label > 0:
            neighbors = _collect_neighbor_bins(
                i, j, offsets, n_pixel_rows, n_pixel_cols, bins_per_row, downrs,
                bin_has_signal, n_neighbor - 1)
            if len(neighbors) + 1 < n_neighbor:
                continue  # not enough expressed neighbours
            neighbor_labels = [watershed_layer[x, y] for x, y, _nb in neighbors]
            x_optimize_train.append(
                [all_exp_merged_bins[idx, :]] + [all_exp_merged_bins[nb, :] for _x, _y, nb in neighbors])
            x_optimize_pos.append([[i, j]] + [[x, y] for x, y, _nb in neighbors])
            x_optimize_labels.append([label] + neighbor_labels)
            # Neighbours without a nucleus get the placeholder centre [-1, -1].
            y_optimize_train.append(
                [watershed2center[label]]
                + [watershed2center.get(nb_label, [-1, -1]) for nb_label in neighbor_labels])

        # Test samples: unassigned pixels that are not background.
        elif label == 0:
            if len(nucleus_centers) == 0:
                continue  # without any nucleus every unassigned pixel counts as background
            squared_dist = (i - nucleus_centers[:, 0]) ** 2 + (j - nucleus_centers[:, 1]) ** 2
            near_nucleus = bool(np.any(squared_dist <= NUCLEUS_RADIUS_SQ))
            if not near_nucleus and not stain_layer_img[i, j] > STAIN_FOREGROUND_MIN:
                continue  # Skip background

            neighbors = _collect_neighbor_bins(
                i, j, offsets, n_pixel_rows, n_pixel_cols, bins_per_row, downrs,
                bin_has_signal, n_neighbor - 1)
            if len(neighbors) + 1 < n_neighbor:
                continue
            x_test.append(
                [all_exp_merged_bins[idx, :]] + [all_exp_merged_bins[nb, :] for _x, _y, nb in neighbors])
            x_test_pos.append([[i, j]] + [[x, y] for x, y, _nb in neighbors])

## 4 Save Data

In [None]:
# After the loops: merge the samples still pending in x_test_tmp and convert to arrays.
x_test.extend(x_test_tmp)
x_test = np.array(x_test)
x_test_pos = np.array(x_test_pos)

x_optimize_train = np.array(x_optimize_train)
x_optimize_pos = np.array(x_optimize_pos)
x_optimize_labels = np.array(x_optimize_labels)
y_optimize_train = np.array(y_optimize_train)

# Save to disk. x_test was previously written twice while x_test_pos was never saved.
np.savez_compressed('data/x_test.npz',x_test=x_test)
np.savez_compressed('data/x_test_pos.npz', x_test_pos=x_test_pos)
np.savez_compressed('data/x_optimize_train.npz', x_optimize_train=x_optimize_train)
np.savez_compressed('data/x_optimize_pos.npz', x_optimize_pos=x_optimize_pos)
np.savez_compressed('data/x_optimize_labels.npz', x_optimize_labels=x_optimize_labels)
np.savez_compressed('data/y_optimize_train.npz', y_optimize_train=y_optimize_train)

## 5 Utilities and Model1

In [None]:
def dir_to_class(y_dir, class_num):
    """Convert 2-D direction vectors into one-hot angular sector labels.

    The angle of (x, y) is measured from the positive y axis towards the positive x axis
    and mapped to [0, 2*pi); the circle is divided into `class_num` equal sectors.

    Args:
        y_dir: Sequence of (x, y) pairs. A pair whose x equals -9999 marks "no direction".
        class_num: Number of angular sectors.

    Returns:
        Float array of shape (len(y_dir), class_num). Each row is one-hot for the sector of
        the corresponding direction, or all zeros for the -9999 sentinel.
    """
    sector = 2 * np.pi / class_num
    # Pre-allocated so that empty input still yields a well-formed (0, class_num) array.
    y_dir_class = np.zeros((len(y_dir), class_num))
    for row, (x, y) in enumerate(y_dir):
        if x == -9999:
            continue  # sentinel: leave the row at zero
        # Angle in (-pi/2, pi/2]; the y == 0 cases are the limits of arctan(x / y).
        if y == 0 and x > 0:
            deg = np.pi / 2
        elif y == 0 and x < 0:
            deg = -np.pi / 2
        elif y == 0 and x == 0:
            deg = 0.0
        else:
            deg = np.arctan(x / y)
        # Shift into [0, 2*pi) depending on the half-plane.
        if y < 0:
            deg += np.pi
        elif x < 0:
            deg += 2 * np.pi
        # A tiny negative x with positive y rounds to exactly 2*pi, which would index one past
        # the last class; 2*pi is the same direction as 0, so wrap around.
        cla = int(deg / sector) % class_num
        y_dir_class[row, cla] = 1
    return y_dir_class

In [None]:
# Transformer that corrects the Watershed assignment
class Model1(nn.Module):
    """Fuse gene, direction and label features and pass them through a Transformer.

    Each of the three inputs is projected to `hidden_dim` by its own linear layer; the
    projections are summed and the sum is fed to `nn.Transformer`.

    Args:
        gene_dim: Size of the last dimension of the gene input.
        dir_dim: Size of the last dimension of the direction input.
        hidden_dim: Shared embedding size (d_model of the Transformer).
        label_dim: Size of the last dimension of the label input.
    """

    def __init__(self,gene_dim,dir_dim,hidden_dim,label_dim):
        super(Model1,self).__init__()

        self.dense_gene = nn.Linear(gene_dim, hidden_dim)
        self.dense_dir = nn.Linear(dir_dim, hidden_dim)
        # watershed label is a number
        self.dense_label = nn.Linear(label_dim, hidden_dim)
        # TODO(1): design the concrete Transformer architecture.
        self.transformer=nn.Transformer(hidden_dim,1)

    def forward(self,gene_data,dir_data,label_data):
        """Return the Transformer output for the fused inputs (same shape as the fused tensor).

        NOTE(review): nn.Transformer is built with its default batch_first=False, so the
        inputs are presumably laid out as (sequence, batch, feature) — confirm.
        """
        gene_data=self.dense_gene(gene_data)
        dir_data=self.dense_dir(dir_data)
        label_data=self.dense_label(label_data)

        fused_data=gene_data+dir_data+label_data

        # nn.Transformer.forward needs both a source and a target sequence; calling it with a
        # single tensor raises a TypeError. The fused sequence is used for both until the
        # architecture is settled (an encoder-only stack may be the better fit — see TODO(1)).
        out=self.transformer(fused_data,fused_data)
        return out

### 5.1 Model1 training

In [None]:
# NOTE(review): this cell is still a skeleton (see TODO(2)). `config`, `epochs` and
# `train_loader` are not defined in the visible cells, Model1 requires gene_dim, dir_dim,
# hidden_dim and label_dim, and Model1.forward takes three tensors — all to be filled in.
model=Model1()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
# TODO(2): flesh out the training procedure (initialisation, iteration, termination).
for epoch in range(epochs):
    # Iterate over the loader directly: enumerate() would bind the batch index to `trains`
    # and the whole batch to `labels`.
    # NOTE(review): assumes the loader yields (inputs, labels) pairs — confirm.
    for trains, labels in train_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(trains)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

## 6 Postprocess and prepare data for Model2

## 7...