In [1]:
# 导入数据集 空间转录组细胞数据分析初赛数据集
import wfio
import pandas as pd

pd.set_option('display.max_rows', 20)
_INPUT = '{"type":15,"uri":"sample_data/85/task-85-r2wzclmecm"}'
# 读取并返回对应的Dataframe
# 参数as_spark: 为True返回Spark DataFrame，为False返回Pandas DataFrame，默认为False
    
df = wfio.read_dataframe(_INPUT,as_spark = False)

In [2]:
import os
import time
import anndata
import argparse
import os.path as osp
import pandas as pd
import warnings
import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch
import torch_geometric
import matplotlib.pyplot as plt
from torch_geometric.loader import ClusterLoader, ClusterData

from graph_model import SpatialModel
from utils import load_data, preprocessing

warnings.filterwarnings("ignore")

def batch_dataloader(dfs, pca_dims=500, k_graph=30, edge_weight=True, num_parts=128, batch_size=32, seed=1234):
    torch.manual_seed(seed)
    adata = load_data(dfs)
    print('Data: %d cells × %d genes.' % (adata.shape[0], adata.shape[1]))
    adata = preprocessing(adata,
                          filter_mt=False,
                          norm_and_log=True,
                          z_score=True)

    if sp.issparse(adata.X):
        adata.X = adata.X.toarray()
    gene_tensor = torch.Tensor(adata.X)
    u, s, v = torch.pca_lowrank(gene_tensor, q=pca_dims)
    gene_tensor = torch.matmul(gene_tensor, v)
    adata.obsm["X_pca"] = gene_tensor.numpy()

    cell_coo = torch.Tensor(adata.obsm["spatial"])

    data = torch_geometric.data.Data(x=gene_tensor, pos=cell_coo)
    data = torch_geometric.transforms.KNNGraph(k=k_graph, loop=False)(data)

    # make distance as edge weights.
    if edge_weight:
        data = torch_geometric.transforms.Distance()(data)
        data.edge_weight = 1 - data.edge_attr[:, 0]
    else:
        data.edge_weight = torch.ones(data.edge_index.size(1))
    
    data.idx = torch.arange(adata.shape[0])

    cluster_data = ClusterData(data, num_parts=num_parts)
    train_loader = ClusterLoader(cluster_data, batch_size=batch_size, shuffle=True)

    return data, adata, train_loader


class Trainer:
    def __init__(self, input_dims):
        self.input_dims = input_dims
        self.device = torch.device('cpu')

        gae_dims = [32, 8]
        dae_dims = [100, 20]
        self.model = SpatialModel(input_dims=self.input_dims,
                                  gae_dims=gae_dims,
                                  dae_dims=dae_dims).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01, weight_decay=1e-4)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=1.0)
        self.scaler = torch.cuda.amp.GradScaler()

    def load_checkpoint(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def save_checkpoint(self, path):
        state = {'model': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict()}
        torch.save(state, path)

    def train(self, train_loader, epochs=200, w_dae=1.0, w_gae=1.0):
        self.model.train()
        start_time = time.time()
        for epoch in range(1, epochs + 1):
            train_loss = 0
            for batch, data in enumerate(train_loader, start=1):
                data = data.to(self.device, non_blocking=True)
                inputs = data.x
                edge_index = data.edge_index
                edge_weight = data.edge_weight
                with torch.cuda.amp.autocast():
                    feat, dae_loss, gae_loss = self.model(inputs, edge_index, edge_weight)
                    loss = w_dae * dae_loss + w_gae * gae_loss
                train_loss += loss.item()
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.scheduler.step()
                train_loss = train_loss / len(train_loader)
                process_time = time.time() - start_time
                print("  [ Epoch %d\t Batch %d ] Loss: %.5f, Time: %.2f s" % (epoch, batch, train_loss, process_time))

    def inference(self, test_loader, cell_nums):
        self.model.eval()
        output = np.zeros((cell_nums, self.model.feat_dims))
        for data in test_loader:
            data = data.to(self.device)
            idx = data.idx.detach().cpu().numpy()
            feat, _, _ = self.model(data.x, data.edge_index, data.edge_weight)
            output[idx] = feat.detach().cpu().numpy()
        return output


def cluster_block(feat, adata, indices, save_path, n_neighbors=30, resolution=0.5):
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    print('clustering ......')
    st = time.time()
    adata_feat = anndata.AnnData(feat[indices], obs=pd.DataFrame(index=map(str, indices)))
    adata_feat.obsm["spatial"] = adata.obsm["spatial"][indices]
    adata_feat.obsm["X_input"] = adata.obsm["X_pca"][indices]
    sc.pp.neighbors(adata_feat, n_neighbors=n_neighbors)
    sc.tl.leiden(adata_feat, resolution=resolution)
    clusters = adata_feat.obs["leiden"].tolist()
    results = pd.DataFrame({"id": adata[indices].obs.index.tolist(), "label": clusters})
    tag = '5'  # 初赛数据集中的tag都是5
    results['id'] = results['id'].apply(lambda x: tag+'_'+str(x))
    results.to_csv(osp.join(save_path, "submit.csv"), index=False)
    print("cluster results has been saved in path: ", save_path)
    cost_time_cluster = time.time()-st
    print('clustering finished, cost time(s): ',cost_time_cluster)

    return cost_time_cluster

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
epochs = 20
datas = [df]
save_path = '../result'

print('epochs: ', epochs)

print('preparing data......')
st = time.time()
data, adata, train_loader = batch_dataloader(datas)
cost_time_prepare_data = time.time()-st
print('preparing data finished, cost time(s): ', cost_time_prepare_data)

print('training ......')
st = time.time()
trainer = Trainer(input_dims=data.num_features)
trainer.train(train_loader=train_loader, epochs=epochs)
feat = trainer.inference(train_loader, adata.shape[0])
cost_time_train = time.time()-st
print('training finished, cost time(s): ',cost_time_train)

cost_time_cluster = cluster_block(feat=feat, adata=adata, indices=list(range(feat.shape[0])),
                            save_path=save_path, n_neighbors=30, resolution=0.5)

params = {"cost_time_prepare_data": str(cost_time_prepare_data), 
            "cost_time_train": str(cost_time_train), 
            "cost_time_cluster": str(cost_time_cluster)}


# AI靶场平台使用否工具包，用于记录全量运行参数和评估指标。
from wf_analyse.analyse import wflogger
wflogger.log_params(params=params)

epochs:  20
preparing data......
Data: 87845 cells × 27754 genes.


Computing METIS partitioning...
Done!


preparing data finished, cost time(s):  769.4470381736755
training ......
  [ Epoch 1	 Batch 1 ] Loss: 1.32103, Time: 5.30 s
  [ Epoch 1	 Batch 2 ] Loss: 1.56710, Time: 10.08 s
  [ Epoch 1	 Batch 3 ] Loss: 1.64284, Time: 14.68 s
  [ Epoch 1	 Batch 4 ] Loss: 1.53608, Time: 19.28 s
  [ Epoch 2	 Batch 1 ] Loss: 1.20800, Time: 23.92 s
  [ Epoch 2	 Batch 2 ] Loss: 1.43579, Time: 28.47 s
  [ Epoch 2	 Batch 3 ] Loss: 1.47239, Time: 33.58 s
  [ Epoch 2	 Batch 4 ] Loss: 1.47082, Time: 38.27 s
  [ Epoch 3	 Batch 1 ] Loss: 1.07128, Time: 43.09 s
  [ Epoch 3	 Batch 2 ] Loss: 1.48448, Time: 48.08 s
  [ Epoch 3	 Batch 3 ] Loss: 1.37624, Time: 52.88 s
  [ Epoch 3	 Batch 4 ] Loss: 1.47025, Time: 57.47 s
  [ Epoch 4	 Batch 1 ] Loss: 1.02382, Time: 62.38 s
  [ Epoch 4	 Batch 2 ] Loss: 1.39670, Time: 67.20 s
  [ Epoch 4	 Batch 3 ] Loss: 1.48581, Time: 73.98 s
  [ Epoch 4	 Batch 4 ] Loss: 1.42416, Time: 78.58 s
  [ Epoch 5	 Batch 1 ] Loss: 1.10072, Time: 83.41 s
  [ Epoch 5	 Batch 2 ] Loss: 1.26134, Time: