In [1]:
# Notebook setup: move to the parent directory, configure CUDA device ordering,
# load the pickled train/test datasets and seed the project's RNG helper.
import os
# NOTE(review): this is not idempotent -- re-running the cell moves up one more
# directory level. Presumably the project modules (imb, datapre, model,
# evaluation) and the data directory live in the parent folder -- confirm.
os.chdir("..")
# Must be set before any CUDA initialisation so that device indices follow the
# PCI bus order; GPU_ID below is interpreted under that ordering.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import imb
import pickle
# Dataset directory, relative to the working directory set above.
dirname = "mouse_brain_sagittal_anterior"
# Device index passed to cp.cuda.Device(...) in later cells.
GPU_ID = 1
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(dirname + "/train_data.pkl", "rb") as file:
    train_data = pickle.load(file)
with open(dirname + "/test_data.pkl", "rb") as file:
    test_data = pickle.load(file)
import datapre as DP
# Presumably seeds the random number generators used by the project (the
# implementation lives in datapre and is not shown here) -- confirm.
DP.setup_seed(1)

In [2]:
# Keep handles to CuPy's default device and pinned-host memory pools so later
# cells can call free_all_blocks() on them after GPU work.
import cupy as cp
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

In [3]:
# Run the project's negative-sample filter on the training set on GPU GPU_ID.
# imb.eliminate_BD_neg is not shown here; it returns three index collections,
# of which only marked_neg_index is used by the following cells.
with cp.cuda.Device(GPU_ID):
    pos_index, neg_index, marked_neg_index = imb.eliminate_BD_neg(train_data.feature, train_data.label, k = 20)
    # Release unused blocks cached by the CuPy pools.
    # NOTE(review): this runs before marked_neg_index is copied to the host, so
    # blocks still referenced by the returned arrays are presumably not released.
    mempool.free_all_blocks()
    pinned_mempool.free_all_blocks()
    # Convert to a NumPy array so it can be used for host-side indexing below.
    marked_neg_index = cp.asnumpy(marked_neg_index)

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
# Lower scanpy's logging verbosity (see the scanpy settings docs for levels).
sc.settings.verbosity = 1
def spot_plot(adata, pair_index, pair_color, library_id = "V1_Mouse_Brain_Sagittal_Anterior"):
    """Draw spot pairs as line segments on top of the hires spatial plot.

    Parameters
    ----------
    adata
        Annotated data object. Reads ``adata.obsm["spatial"]`` (indexed by the
        values in ``pair_index``) and
        ``adata.uns["spatial"][library_id]["scalefactors"]["tissue_hires_scalef"]``;
        it is also passed to ``sc.pl.spatial`` coloured by ``"clusters"``.
    pair_index
        Integer array with one row per pair; the first two entries of a row
        are the indices of the two spots to connect. A single 1-D pair is
        accepted as well, and an empty array draws no lines.
    pair_color
        Matplotlib colour used for every connecting line.
    library_id
        Key under ``adata.uns["spatial"]`` that holds the scale factors.
        Defaults to the sample this notebook works on.

    Returns
    -------
    None
        The figure is created with ``plt.subplots`` and left for the notebook
        to display.
    """
    pair_index = np.atleast_2d(np.asarray(pair_index))
    scalefactor = adata.uns["spatial"][library_id]["scalefactors"]["tissue_hires_scalef"]
    # One fancy-index lookup for all pairs: the result has shape
    # (n_pairs, spots_per_pair, n_coords), already scaled to hires pixels.
    # Fancy indexing returns a new array, so obsm["spatial"] is never modified.
    segments = adata.obsm["spatial"][pair_index] * scalefactor

    _, ax = plt.subplots(constrained_layout = True, figsize = (8, 6))
    sc.pl.spatial(
        adata,
        img_key = "hires",
        size = 1.2,
        show = False,
        ax = ax,
        zorder = 1,
        color = "clusters"
    )
    if pair_index.size == 0:
        # Nothing to connect; keep the background plot only.
        return
    # zorder 2 keeps the lines above the spots drawn by sc.pl.spatial (zorder 1).
    for start, end in zip(segments[:, 0], segments[:, 1]):
        ax.plot(
            [start[0], end[0]],
            [start[1], end[1]],
            alpha = 0.7,
            zorder = 2,
            color = pair_color
        )

# Load the Visium sample through scanpy's dataset helper (presumably downloaded
# and cached on first use -- see the scanpy docs) and make gene names unique.
adata = sc.datasets.visium_sge(sample_id = "V1_Mouse_Brain_Sagittal_Anterior")
adata.var_names_make_unique()
# Flag genes whose names start with "mt-" and use that flag as a QC variable;
# inplace = True stores the computed metrics on adata itself.
adata.var["mt"] = adata.var_names.str.startswith("mt-")
sc.pp.calculate_qc_metrics(adata, qc_vars = ["mt"], inplace = True)

# Attach cluster labels stored in a pickle file; spot_plot colours by this column.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(dirname + "/clusters.pkl", "rb") as file:
    adata.obs["clusters"] = pickle.load(file)

# Plot a single example pair (row 20 of train_data.pair_index); the list index
# [[20]] presumably keeps the selection 2-D, i.e. one row of spot indices.
with cp.cuda.Device(GPU_ID):
    spot_plot(adata, train_data.pair_index[[20]], pair_color = "black")

In [4]:
# Keep copies of the features and labels of the marked negatives before the
# next cell removes them from train_data; they are evaluated separately later.
# NOTE(review): features are selected through pair_index_son while labels are
# selected through data_index -- confirm both are aligned for marked_neg_index.
marked_feature = train_data.get_feature(train_data.pair_index_son[marked_neg_index], copy = True)
marked_label = train_data.get_label(train_data.data_index[marked_neg_index], copy = True)

In [5]:
# Remove the marked negatives from the training set, then rebuild its arrays.
# NOTE(review): this mutates train_data in place, so re-running the cell would
# pop and mirror a second time; re-run from the loading cell instead.
train_data.pop(marked_neg_index)
# Presumably augments the data with mirrored copies of the pairs -- confirm
# against the dataset class (not shown here).
train_data.mirror_copy()
# Return values are discarded; presumably these calls refresh
# train_data.feature / train_data.label, which the training cell reads.
train_data.get_feature()
train_data.get_label()

In [6]:
# Train the project's neural-network classifier on the filtered training set.
from model import nn_model
# NOTE(review): the alias `eval` shadows Python's builtin eval() for the rest
# of the notebook; later cells rely on this name.
import evaluation as eval
model = nn_model.NeuralNetworkClassifier(batch_size = 128)
# The log below shows up to 50 epochs with early stopping (patience 7) on the
# validation loss.
model.fit(train_data.feature, train_data.label)

Epoch [1/50]: 100%|██████████| 6980/6980 [00:31<00:00, 222.99it/s, train_loss=0.0074]
Valid: 100%|██████████| 776/776 [00:01<00:00, 455.45it/s, val_loss=0.0014]


Validation loss decreased (inf --> 0.001383)
----------------------------------------------------------------


Epoch [2/50]: 100%|██████████| 6980/6980 [00:28<00:00, 244.58it/s, train_loss=0.0032]
Valid: 100%|██████████| 776/776 [00:01<00:00, 439.60it/s, val_loss=0.0009]


Validation loss decreased (0.001383 --> 0.000942)
----------------------------------------------------------------


Epoch [3/50]: 100%|██████████| 6980/6980 [00:30<00:00, 229.34it/s, train_loss=0.0024]
Valid: 100%|██████████| 776/776 [00:01<00:00, 440.93it/s, val_loss=0.0009]


Validation loss decreased (0.000942 --> 0.000884)
----------------------------------------------------------------


Epoch [4/50]: 100%|██████████| 6980/6980 [00:30<00:00, 232.54it/s, train_loss=0.0021]
Valid: 100%|██████████| 776/776 [00:01<00:00, 468.85it/s, val_loss=0.0006]


Validation loss decreased (0.000884 --> 0.000632)
----------------------------------------------------------------


Epoch [5/50]: 100%|██████████| 6980/6980 [00:29<00:00, 236.30it/s, train_loss=0.0018]
Valid: 100%|██████████| 776/776 [00:01<00:00, 455.77it/s, val_loss=0.0012]


EarlyStopping counter: 1 out of 7
----------------------------------------------------------------


Epoch [6/50]: 100%|██████████| 6980/6980 [00:29<00:00, 233.71it/s, train_loss=0.0019]
Valid: 100%|██████████| 776/776 [00:01<00:00, 485.85it/s, val_loss=0.0009]


EarlyStopping counter: 2 out of 7
----------------------------------------------------------------


Epoch [7/50]: 100%|██████████| 6980/6980 [00:29<00:00, 232.71it/s, train_loss=0.0018]
Valid: 100%|██████████| 776/776 [00:01<00:00, 465.89it/s, val_loss=0.0004]


Validation loss decreased (0.000632 --> 0.000366)
----------------------------------------------------------------


Epoch [8/50]: 100%|██████████| 6980/6980 [00:29<00:00, 239.70it/s, train_loss=0.0016]
Valid: 100%|██████████| 776/776 [00:01<00:00, 455.94it/s, val_loss=0.001]


EarlyStopping counter: 1 out of 7
----------------------------------------------------------------


Epoch [9/50]: 100%|██████████| 6980/6980 [00:30<00:00, 231.50it/s, train_loss=0.0015]
Valid: 100%|██████████| 776/776 [00:01<00:00, 484.16it/s, val_loss=0.0003]


Validation loss decreased (0.000366 --> 0.000342)
----------------------------------------------------------------


Epoch [10/50]: 100%|██████████| 6980/6980 [00:29<00:00, 233.74it/s, train_loss=0.0015]
Valid: 100%|██████████| 776/776 [00:01<00:00, 450.89it/s, val_loss=0.0009]


EarlyStopping counter: 1 out of 7
----------------------------------------------------------------


Epoch [11/50]: 100%|██████████| 6980/6980 [00:28<00:00, 243.05it/s, train_loss=0.0016]
Valid: 100%|██████████| 776/776 [00:01<00:00, 454.84it/s, val_loss=0.0006]


EarlyStopping counter: 2 out of 7
----------------------------------------------------------------


Epoch [12/50]: 100%|██████████| 6980/6980 [00:28<00:00, 241.58it/s, train_loss=0.0017]
Valid: 100%|██████████| 776/776 [00:01<00:00, 493.09it/s, val_loss=0.0006]


EarlyStopping counter: 3 out of 7
----------------------------------------------------------------


Epoch [13/50]: 100%|██████████| 6980/6980 [00:30<00:00, 228.74it/s, train_loss=0.0017]
Valid: 100%|██████████| 776/776 [00:01<00:00, 438.23it/s, val_loss=0.0004]


EarlyStopping counter: 4 out of 7
----------------------------------------------------------------


Epoch [14/50]: 100%|██████████| 6980/6980 [00:28<00:00, 247.53it/s, train_loss=0.0015]
Valid: 100%|██████████| 776/776 [00:01<00:00, 441.69it/s, val_loss=0.0006]


EarlyStopping counter: 5 out of 7
----------------------------------------------------------------


Epoch [15/50]: 100%|██████████| 6980/6980 [00:29<00:00, 233.83it/s, train_loss=0.0016]
Valid: 100%|██████████| 776/776 [00:01<00:00, 446.42it/s, val_loss=0.0007]


EarlyStopping counter: 6 out of 7
----------------------------------------------------------------


Epoch [16/50]: 100%|██████████| 6980/6980 [00:30<00:00, 230.03it/s, train_loss=0.0016]
Valid: 100%|██████████| 776/776 [00:01<00:00, 438.67it/s, val_loss=0.0006]


EarlyStopping counter: 7 out of 7
Early stopping
----------------------------------------------------------------


In [7]:
# train_data (without the marked_neg_index samples): metrics on the same
# filtered data the model was fitted on.
predprob = model.predict_proba(train_data.feature)
eval.evaluate(train_data.label, predprob, verbose = False)

{'Accuracy': 0.9999153806711321,
 'Precision': 0.9928415191887056,
 'Recall': 0.9987997599519904,
 'MCC': 0.9957735335277492,
 'F1_Score': 0.9958117271639411,
 'AUC': 0.9999994039612538,
 'Average Precision': 0.9999426882246032,
 'confusion_matrix': array([[982611,     72],
        [    12,   9986]])}

In [8]:
# The marked_neg_index samples that were removed from the training set.
# They are all negatives (see the confusion matrix below), so recall, F1 and
# AUC are undefined and come out as nan; accuracy is the informative number.
predprob = model.predict_proba(marked_feature)
eval.evaluate(marked_label, predprob, verbose = False)

  Recall = TP / (TP + FN)


{'Accuracy': 0.7340370882906198,
 'Precision': 0.0,
 'Recall': nan,
 'MCC': 0.0,
 'F1_Score': nan,
 'AUC': nan,
 'Average Precision': -0.0,
 'confusion_matrix': array([[980239, 355169],
        [     0,      0]])}

In [9]:
# The full test_data, before any samples are removed from it.
predprob = model.predict_proba(test_data.feature)
eval.evaluate(test_data.label, predprob, verbose = False)

{'Accuracy': 0.843617879730466,
 'Precision': 0.013358861668334712,
 'Recall': 1.0,
 'MCC': 0.10613831444963816,
 'F1_Score': 0.026365510133974583,
 'AUC': 0.9610026235317097,
 'Average Precision': 0.030547853556984434,
 'confusion_matrix': array([[122010,  22674],
        [     0,    307]])}

In [10]:
# Apply the same negative-sample filter to the test set.
# NOTE(review): the filter receives test_data.label, so the subset evaluated
# afterwards is chosen using test labels -- keep this in mind when comparing
# the filtered-test metrics with the full-test metrics.
# This overwrites pos_index, neg_index and marked_neg_index from the train run.
with cp.cuda.Device(GPU_ID):
    pos_index, neg_index, marked_neg_index = imb.eliminate_BD_neg(test_data.feature, test_data.label, k = 20)
    # Release unused blocks cached by the CuPy pools.
    mempool.free_all_blocks()
    pinned_mempool.free_all_blocks()
    # Convert to a NumPy array for host-side use in the next cell.
    marked_neg_index = cp.asnumpy(marked_neg_index)

In [11]:
# Remove the marked negatives from the test set and rebuild its arrays,
# mirroring the steps applied to train_data.
# NOTE(review): mutates test_data in place; the full-test evaluation cannot be
# reproduced after this cell without reloading the pickle.
test_data.pop(marked_neg_index)
# Presumably adds mirrored copies of the pairs: the positive count in the
# confusion matrices goes from 307 (full test) to 614 (after this cell).
test_data.mirror_copy()
# Return values are discarded; presumably these refresh test_data.feature/label.
test_data.get_feature()
test_data.get_label()

In [12]:
# test_data with the marked_neg_index samples removed.
# NOTE(review): verbose = True here, unlike the other evaluation cells --
# confirm whether that difference is intentional.
predprob = model.predict_proba(test_data.feature)
eval.evaluate(test_data.label, predprob, verbose = True)

{'Accuracy': 0.9979214989306263,
 'Precision': 0.8989751098096632,
 'Recall': 1.0,
 'MCC': 0.9471385211978792,
 'F1_Score': 0.9468003084040092,
 'AUC': 0.9995627553426323,
 'Average Precision': 0.9574144760202657,
 'confusion_matrix': array([[32514,    69],
        [    0,   614]])}

使用全部数据训练：所有样本都被预测成负样本。
使用 marked_neg_index 方法：用筛选后的训练样本训练模型，能够很好地预测筛选后的训练样本和测试样本；但是对于被剔除的 marked_neg_index 样本（训练集和测试集中）预测较差——训练集中被剔除的负样本约有 26.6%（355169 / 1335408）被误判为正样本，完整测试集上的假阳性为 22674 个，而剔除后只有 69 个。

可以调整 k：从上面 marked_neg_index 的评估结果看，大部分被剔除的样本（约 73%）仍能预测正确，说明 k = 20 还有进一步调整的空间。