In [1]:
import warnings
warnings.filterwarnings('ignore')
import os
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.metrics import f1_score
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from scipy import sparse
from module import *
def setup_seed(seed):
    """Seed the random number generators used in this notebook and pin backend flags.

    Seeds NumPy, Python's ``random`` module and PyTorch (current CUDA device,
    all CUDA devices, and the default generator), exports ``PYTHONHASHSEED``,
    sets the cuDNN flags and the float32 matmul precision, then prints a
    confirmation line.

    Args:
        seed: Value handed to every seeding call (also stringified into the
            ``PYTHONHASHSEED`` environment variable).
    """
    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    # this assignment presumably only matters for child processes — confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeding calls as before, applied in the same order.
    seeders = (
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.enabled = False  # cuDNN is switched off entirely, not just made deterministic
    cudnn.benchmark = False

    torch.set_float32_matmul_precision('high')
    print("seed set ok!")
def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current PyTorch initial seed.

    The PyTorch seed is reduced modulo 2**32 before being passed on.

    Args:
        worker_id: Unused. NOTE(review): presumably present so the function can
            be passed as a DataLoader ``worker_init_fn``; it is not referenced
            anywhere in the visible code — confirm.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


class scDataset(Dataset):
    """Dataset of per-cell ``.npy`` vectors split into a gene part and a peak part.

    Item ``idx`` is loaded from ``<path>/<index[idx]>.npy``. The first
    ``n_genes`` entries of the loaded array form the gene vector, the remaining
    entries form the peak vector.

    Args:
        index: Sequence of file stems (one per sample); ``str()`` is applied to
            each entry before ``'.npy'`` is appended.
        path: Directory holding the ``.npy`` files. Defaults to
            ``DEFAULT_PATH`` (the location previously hard-coded here).
        n_genes: Number of leading entries that belong to the gene vector.
    """

    # Previously hard-coded data directory; kept as the default so that
    # ``scDataset(index)`` behaves exactly as before.
    DEFAULT_PATH = '/home/share/huadjyin/home/zhouxuanchi/HIV/atac_to_gene_new_data_0218/adata_process'

    def __init__(self, index, path=None, n_genes=582):
        self.path = self.DEFAULT_PATH if path is None else path
        self.index_list = index
        self.n_genes = n_genes

    def __len__(self):
        """Return the number of samples (length of the index list)."""
        return len(self.index_list)

    def get_np_array(self, filename):
        """Load ``filename`` from the dataset directory with ``np.load``."""
        return np.load(os.path.join(self.path, filename))

    def __getitem__(self, idx):
        """Return ``(gene, peak, mask)`` as bfloat16 tensors for sample ``idx``.

        ``mask`` is 1.0 where the gene value is non-zero and 0.0 elsewhere.
        """
        index_name = self.index_list[idx]
        array_idx = self.get_np_array(str(index_name) + '.npy')
        gene = torch.tensor(array_idx[:self.n_genes], dtype=torch.bfloat16)
        peak = torch.tensor(array_idx[self.n_genes:], dtype=torch.bfloat16)

        # `.to()` converts the boolean tensor directly; wrapping an existing
        # tensor in torch.tensor(...) copy-constructs it and emits a UserWarning.
        mask = (gene != 0).to(torch.bfloat16)
        return gene, peak, mask

    
import pickle
def save_data(data, filename):
    """Pickle ``data`` into ``filename`` and print a confirmation message.

    Args:
        data: Object to serialise with ``pickle``.
        filename: Destination path; opened in binary write mode (overwritten).
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink).dump(data)
    print(f"Data saved to {filename}")
    
def load_data(filename):
    """Open ``filename`` in binary read mode and return the unpickled object.

    Args:
        filename: Path to a pickle file.

    Returns:
        The object stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as source:
        return pickle.Unpickler(source).load()


class Peak2GeneModel(pl.LightningModule):
    """Lightning module that predicts a gene vector from a peak vector.

    Forward pipeline: the gene mask is mapped through a fixed matrix loaded
    from ``./data/mask_mat.pt``, the peak input is encoded, projected by an
    MLP, pooled over dim 1 using the mapped mask as weights, and decoded.

    Args:
        input_dim: Input/output width of the projection MLP.
        hidden_dim: Hidden width of the projection MLP.
        out_features: ``out_features`` passed to the decoder.
    """

    def __init__(self, input_dim=64, hidden_dim=512, out_features=582):
        super().__init__()
        # Peak encoder.
        # NOTE(review): the meaning of these positional arguments is defined in
        # `module` (not visible here); the literal 64 presumably has to match
        # `input_dim` — confirm before changing `input_dim`.
        self.peak_encoder = TokenizedFAEncoder(5583, 64, True, 7, 0.1, 'layernorm')
        # Projection MLP: input_dim -> hidden_dim -> hidden_dim -> input_dim.
        self.projection_layer = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        # Decoder.
        self.decoder = GatedMLP(64, out_features=out_features)
        # Fixed matrix applied to the gene mask in `forward`. Registered as a
        # NON-persistent buffer: it follows the module across devices (no
        # per-forward re-wrapping/copy) but is not added to the state_dict, so
        # checkpoints saved before this change still load.
        mask_matrix = torch.load('./data/mask_mat.pt', map_location='cpu')
        self.register_buffer('matrix', torch.as_tensor(mask_matrix), persistent=False)

    def forward(self, peak, mask):
        """Predict genes from ``peak``; ``mask`` is multiplied by ``self.matrix``.

        Args:
            peak: Peak input handed to the peak encoder.
            mask: 2-D mask; it is passed to ``torch.mm`` with ``self.matrix``.

        Returns:
            Decoder output for the mask-weighted pooled embedding.
        """
        # No-op when the buffer already lives on the mask's device.
        matrix = self.matrix.to(mask.device)
        mask = torch.mm(mask, matrix)
        # Prepend one all-zero column to the mapped mask.
        # NOTE(review): presumably this lines the mask up with an extra leading
        # token produced by the encoder — confirm against TokenizedFAEncoder.
        mask = torch.cat((torch.zeros(size=(mask.shape[0], 1), dtype=torch.bfloat16, device=mask.device), mask), dim=1)
        peak_embed = self.peak_encoder(peak, mask)
        peak_embed = self.projection_layer(peak_embed)

        # Mask-weighted mean over dim 1. Rows whose weights sum to exactly zero
        # previously produced NaN/inf (division by zero); use a divisor of 1 there.
        weights = mask.unsqueeze(-1).float()
        denom = weights.sum(1)
        denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        peak_embed = (peak_embed * weights).sum(1) / denom

        pred_gene = self.decoder(peak_embed)
        return pred_gene

    def _mse_step(self, batch):
        """Run the model on a ``(gene, peak, mask)`` batch and return the MSE loss."""
        gene, peak, mask = batch
        pred_gene = self(peak, mask)
        return F.mse_loss(pred_gene.view_as(gene), gene)

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        loss = self._mse_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        loss = self._mse_step(batch)
        self.log('val_loss', loss)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return ``(gene, pred_gene)``: the batch's target and its prediction."""
        gene, peak, mask = batch
        pred_gene = self(peak, mask)
        return gene, pred_gene

    def configure_optimizers(self):
        """Return AdamW (lr=1e-4) with a StepLR scheduler (step_size=50, gamma=0.9)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        step_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
        optim_dict = {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': step_lr_scheduler,
                # NOTE(review): StepLR does not consume a monitored metric; this
                # key looks like a leftover and is kept only to leave the
                # returned configuration unchanged — confirm.
                'monitor': 'val_loss',
            }
        }
        return optim_dict

In [4]:
# Load the split dictionary and build the test dataset from its 'test' entry.
dict_fen = load_data('./data/fen.pkl')
X_test = dict_fen['test']
test_dataset = scDataset(X_test)

# Trainer configured for device index 3 with bf16 mixed precision.
trainer = pl.Trainer(accelerator='gpu', devices=[3], precision='bf16-mixed')

# Restore the trained model from its checkpoint, mapping tensors to CPU on load.
model = Peak2GeneModel.load_from_checkpoint(
    "./model/hiv_model-epoch=98-val_loss=0.4079.ckpt",
    map_location='cpu',
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Unshuffled loader: batches follow the order of the test index list.
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
)

In [None]:
# Run inference. Each element of `pred_out` is the (gene, pred_gene) pair that
# Peak2GeneModel.predict_step returns for one batch.
pred_out = trainer.predict(model, test_loader)

# Stack the per-batch tensors into two arrays. The original iterated over an
# undefined name (`data`); the predictions live in `pred_out`.
# Tensors are cast to float16 before the NumPy conversion, as before.
gene = np.concatenate([label.to(torch.float16).cpu().numpy() for label, _ in pred_out])
pred_gene = np.concatenate([pred.to(torch.float16).cpu().numpy() for _, pred in pred_out])

In [None]:
# Bundle the ground-truth and predicted gene arrays and pickle them.
dict_gene = {'label_gene': gene, 'pred_gene': pred_gene}
# The closing quote of the path was missing, which made this cell a SyntaxError.
save_data(dict_gene, './data/dict_gene.pkl')

In [3]:
import pickle
def save_data(data, filename):
    """Pickle ``data`` into ``filename`` and print a confirmation message.

    Args:
        data: Object to serialise with ``pickle``.
        filename: Destination path; opened in binary write mode (overwritten).
    """
    with open(filename, 'wb') as sink:
        pickle.Pickler(sink).dump(data)
    print(f"Data saved to {filename}")
    
def load_data(filename):
    """Open ``filename`` in binary read mode and return the unpickled object.

    Args:
        filename: Path to a pickle file.

    Returns:
        The object stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as source:
        return pickle.Unpickler(source).load()

# Reload the pickled dictionary holding the 'label_gene' and 'pred_gene' arrays.
dict_gene = load_data(
    '/home/share/huadjyin/home/zhouxuanchi/HIV_code/peak2gene/data/dict_gene.pkl'
)

In [6]:
# Reload the split dictionary; its 'test' entry is used below to select cells.
fen = load_data(
    '/home/share/huadjyin/home/zhouxuanchi/HIV_code/peak2gene/data/fen.pkl'
)

In [28]:
import numpy as np
# Put labels and predictions side by side: columns are [label_gene | pred_gene].
all_data = np.concatenate(
    [dict_gene['label_gene'], dict_gene['pred_gene']],
    axis=1,
)

In [9]:
import scanpy as sc

In [11]:
# Open the processed AnnData file in backed, read-only mode.
adata = sc.read_h5ad(
    './data/adata_process.h5ad',
    backed='r',
)

In [13]:
# Keep only the cells listed under fen['test'].
# NOTE(review): this rebinds `adata`, so the cell is not idempotent — re-running
# it indexes the already-subset object. Re-run the read_h5ad cell first if needed.
test_ids = fen['test']
adata = adata[test_ids]

In [18]:
# Show the metadata columns that are carried over to the output object.
adata.obs.loc[
    :,
    ['rna_cellname', 'rna_sample', 'rna_stage', 'celltype_L1', 'celltype_L2', 'celltype_L3'],
]

Unnamed: 0,rna_cellname,rna_sample,rna_stage,celltype_L1,celltype_L2,celltype_L3
matched_cell_2407805,PD-H272-2-CELL2004_N1,PD-H272,IRs,CD8+ T & unconvensional T,CD8 CTL,CD8_CTL-GZMB
matched_cell_2077838,PD-H231-1-CELL6935_N1,PD-H231,INRs,B,Naive B,Naive_B-TCL1A
matched_cell_1341115,PD-H286-1-CELL3377_N1,PD-H286,INRs,Myeloid,cMono,cMono-IFI44L
matched_cell_2113702,PD-H276-2-CELL10141_N1,PD-H276,INRs,B,Naive B,Naive_B-TCL1A
matched_cell_657517,PD-H308-2-CELL2326_N2,PD-H308,IRs,CD4+ T,CD4 Naive T,CD4_Naive_T-CCR7
...,...,...,...,...,...,...
matched_cell_865626,HD-H147-2-CELL1306_N2,HD-H147,HDs,CD8+ T & unconvensional T,CD8 Tem,CD8_Tem-GZMK
matched_cell_408100,HD-H157-1-CELL2739_N1,HD-H157,HDs,CD4+ T,CD4 Tcm,CD4_Tcm-GPR183
matched_cell_1717895,HD-H140-2-CELL7623_N1,HD-H140,HDs,CD8+ T & unconvensional T,NKT,NKT-NCR1
matched_cell_515268,HD-H313-1-CELL1987_N2,HD-H313,HDs,CD4+ T,CD4 Naive T,CD4_Naive_T-CCR7


In [38]:
import anndata as ad
# Wrap the combined label/prediction matrix in an AnnData object, attaching the
# selected per-cell metadata columns as `obs`.
adata_output = ad.AnnData(
    all_data,
    obs=adata.obs.loc[
        :,
        ['rna_cellname', 'rna_sample', 'rna_stage', 'celltype_L1', 'celltype_L2', 'celltype_L3'],
    ],
)

In [39]:
# Name the columns: the first 582 variable names as-is, followed by the same
# names prefixed with 'pred_' for the predicted half.
gene_names = adata.var_names[:582].tolist()
adata_output.var_names = gene_names + ['pred_' + name for name in gene_names]

In [43]:
# Write the label/prediction AnnData object to disk.
sc.write(
    './data/adata_pred_output.h5ad',
    adata_output,
)