In [1]:
from model import *
import sys

from attention import pHLA_attns_draw_save
from mutation import *
from utils import Logger, cut_peptide_to_specific_length

In [2]:
import math
from sklearn import metrics
from sklearn import preprocessing
import numpy as np
import pandas as pd
import re
import time
import datetime
import random
random.seed(1234)
from scipy import interp
import warnings
# warnings.filterwarnings("ignore")
from tqdm import tqdm
from pandas import Series, DataFrame
from typing import List
from torch import Tensor
from collections import Counter
from collections import OrderedDict
from functools import reduce
from tqdm import tqdm, trange
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

import difflib

# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU and CUDA) with
# one value so that runs are reproducible. This supersedes random.seed(1234) above.
seed = 19961231
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import os
import argparse
import logging
import sys

In [3]:
def _str2bool(value):
    """Convert a command-line flag value to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so boolean options need explicit
    parsing. Actual bools (e.g. the defaults) are returned unchanged; common
    yes/no spellings are accepted case-insensitively.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


parser = argparse.ArgumentParser(usage = 'peptide-HLA-I binding prediction')
parser.add_argument('--peptide_file', type = str, help = 'the path of the .fasta file contains peptides')
parser.add_argument('--HLA_file', type = str, help = 'the path of the .fasta file contains sequence')
parser.add_argument('--threshold', type = float, default = 0.5, help = 'the threshold to define predicted binder, float from 0 - 1, the recommended value is 0.5')
parser.add_argument('--cut_peptide', type = _str2bool, default = False, help = 'Whether to split peptides larger than cut_length?')
parser.add_argument('--cut_length', type = int, default = 9, help = 'if there is a peptide sequence length > 15, we will segment the peptide according the length you choose, from 8 - 15')
parser.add_argument('--output_attention', type = _str2bool, default = False, help = 'Output the mutual influence of peptide and HLA on the binding?')
parser.add_argument('--output_heatmap', type = _str2bool, default = False, help = 'Visualize the mutual influence of peptide and HLA on the binding?')
parser.add_argument('--output_mutation', type = _str2bool, default = True, help = 'Whether to perform mutations with better affinity for each sample?')
parser.add_argument('--output_dir', type = str, help = 'The directory where the output results are stored.')

# Inside the notebook no CLI arguments exist, so parse an empty list (defaults only).
args = parser.parse_args(args = [])
args

Namespace(HLA_file=None, cut_length=9, cut_peptide=False, output_attention=False, output_dir=None, output_heatmap=False, output_mutation=True, peptide_file=None, threshold=0.5)

In [4]:
# Notebook-only overrides of the CLI arguments: hard-coded example inputs,
# with all results written under ./results/ (created if missing).
args.peptide_file, args.HLA_file, args.output_dir = 'peptides.fasta', 'hlas.fasta', './results/'
os.makedirs(args.output_dir, exist_ok=True)

cut_length = args.cut_length

In [5]:
def _abort(message):
    """Write `message` as a critical entry to ./error.log and stop the run with exit status 1."""
    log = Logger('./error.log')
    log.logger.critical(message)
    sys.exit(1)


# Validate the run configuration before touching any input files.
# `not (0 < x < 1)` also rejects NaN, which passes the `x <= 0 or x >= 1` form.
if not (0 < args.threshold < 1):
    _abort('The threshold invalid, please check whether it ranges from 0-1.')
if not args.peptide_file:
    _abort('The peptide file is empty.')
if not args.HLA_file:
    _abort('The HLA file is empty.')
if not args.output_dir:
    _abort('Please fill the output file directory.')

# Read input files (读取文件)

In [6]:

# The 20 standard amino-acid one-letter codes accepted in peptide and HLA sequences.
AMINO_ACIDS = frozenset('ARNDCQEGHILKMFPSTWYV')
# Peptides at least this long are always split into cut_length windows.
# NOTE(review): the --cut_length help text says peptides "> 15" are segmented, but
# 15-mers are segmented here too — confirm against the model's maximum peptide length.
LONG_PEPTIDE_LENGTH = 15
# Above this many samples the heatmap and mutation outputs are disabled (too large).
MAX_SAMPLES_FOR_DETAILED_OUTPUT = 1000


def _read_fasta(path):
    """Read a FASTA file and return its records as a list of (name, sequence) tuples.

    The name is the header line with '>' and tab characters removed. Sequence lines
    are tab-free, upper-cased and concatenated, so multi-line records are supported.
    Blank lines, surrounding whitespace and '\\r' line endings are ignored; sequence
    text that appears before the first header is dropped.
    """
    records = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                records.append([line.replace('>', '').replace('\t', ''), ''])
            elif records:
                records[-1][1] += line.replace('\t', '').upper()
    return [tuple(record) for record in records]


def _is_valid_sequence(seq):
    """True if `seq` is non-empty and uses only the 20 standard amino-acid letters."""
    return bool(seq) and set(seq) <= AMINO_ACIDS


def _expand_peptide(pep):
    """Return the list of peptides to predict for one input peptide.

    Peptides of length >= LONG_PEPTIDE_LENGTH are always replaced by all their
    overlapping windows of `cut_length`. Shorter peptides are kept whole; when
    --cut_peptide is set and the peptide is longer than `cut_length`, all its
    windows are appended after the full peptide.
    """
    length = len(pep)
    windows = [pep[start : start + cut_length] for start in range(length - cut_length + 1)]
    if length >= LONG_PEPTIDE_LENGTH:
        return windows
    if args.cut_peptide and length > cut_length:
        return [pep] + windows
    return [pep]


# Records are paired by position: the i-th peptide is predicted against the i-th HLA.
# Names come from the HLA file headers (peptide headers are ignored); records beyond
# the shorter file are dropped by zip().
peptide_records = _read_fasta(args.peptide_file)
hla_records = _read_fasta(args.HLA_file)
ori_peptides = [seq for _, seq in peptide_records]
ori_HLA_names = [name for name, _ in hla_records]
ori_HLA_sequences = [seq for _, seq in hla_records]

peptides, HLA_names, HLA_sequences = [], [], []
for pep, hla_name, hla_seq in zip(ori_peptides, ori_HLA_names, ori_HLA_sequences):
    # Silently skip pairs containing empty or non-standard sequences.
    if not (_is_valid_sequence(pep) and _is_valid_sequence(hla_seq)):
        continue
    expanded = _expand_peptide(pep)
    peptides.extend(expanded)
    HLA_sequences.extend([hla_seq] * len(expanded))
    HLA_names.extend([hla_name] * len(expanded))

predict_data = pd.DataFrame({'HLA': HLA_names, 'HLA_sequence': HLA_sequences, 'peptide': peptides})
if predict_data.shape[0] == 0:
    log = Logger('./error.log')
    log.logger.critical('No suitable data could be predicted. Please check your input data.')
    sys.exit(1)

if predict_data.shape[0] > MAX_SAMPLES_FOR_DETAILED_OUTPUT:
    # Too many samples: still predict, but skip the per-sample heatmaps and
    # mutation scans whose combined output would be very large.
    args.output_heatmap = False
    args.output_mutation = False

    log = Logger('./error.log')
    log.logger.critical('Do not generate heatmap and mutation. Because the number of predict data > {}, and the output will be large.'.format(MAX_SAMPLES_FOR_DETAILED_OUTPUT))

predict_data, predict_pep_inputs, predict_hla_inputs, predict_loader = read_predict_data(predict_data, batch_size)
predict_data

# Samples =  17


Unnamed: 0,HLA,HLA_sequence,peptide
0,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQSA
1,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQ
2,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQPI
3,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,KVYEGVWKK
4,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,VYEGVWKKA
5,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,YEGVWKKAE
6,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,EGVWKKAEA
7,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,GVWKKAEAF
8,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,VWKKAEAFI
9,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,WKKAEAFIQ


# Prediction (预测)

In [7]:
# Inference device: CPU unless use_cuda is switched to True.
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")

model_file = '../model/model_layer1_multihead9_fold4.pkl'

model_eval = Transformer().to(device)
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only machine.
model_eval.load_state_dict(torch.load(model_file, map_location = device), strict = True)
model_eval.eval()

y_pred, y_prob, attns = eval_step(model_eval, predict_loader, args.threshold, use_cuda)

predict_data['y_pred'], predict_data['y_prob'] = y_pred, y_prob
predict_data = predict_data.round({'y_prob': 4})

predict_data.to_csv(os.path.join(args.output_dir, 'predict_results.csv'), index = False)
predict_data

Unnamed: 0,HLA,HLA_sequence,peptide,y_pred,y_prob
0,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQSA,0,0.0001
1,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQ,0,0.0008
2,HLA-A*11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,AEAFIQPI,1,0.5539
3,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,KVYEGVWKK,1,1.0
4,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,VYEGVWKKA,0,0.0
5,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,YEGVWKKAE,0,0.0
6,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,EGVWKKAEA,0,0.0
7,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,GVWKKAEAF,0,0.0
8,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,VWKKAEAFI,0,0.0
9,HLA-A*68:01,YYAMYRNNVAQTDVDTLYIMYRDYTWAVWAYTWY,WKKAEAFIQ,0,0.0


# Attention output and heatmaps (作图)

In [8]:
def _ensure_dir(path):
    """Create `path` (with parents) when it does not exist yet; return `path` unchanged."""
    if not os.path.exists(path):
        os.makedirs(path)
    return path


# Export attention values and/or draw heatmaps for every prediction. Each output
# gets its directory path when enabled, or False when disabled.
if args.output_heatmap or args.output_attention:
    attn_savepath = _ensure_dir(args.output_dir + '/attention/') if args.output_attention else False
    fig_savepath = _ensure_dir(args.output_dir + '/figures/') if args.output_heatmap else False

    for hla, pep in predict_data[['HLA', 'peptide']].itertuples(index = False, name = None):
        pHLA_attns_draw_save(predict_data, attns, hla, pep, attn_savepath, fig_savepath)

# Mutation scanning (突变)

In [9]:
from tqdm import trange
# For every predicted (HLA, peptide) pair, build candidate mutated peptides with
# pHLA_mutation_peptides and score them with the same model. The results go to
# <output_dir>/mutation/<HLA>_<peptide>_mutation.csv and are also printed in a
# space-separated form that IEDB tools accept.
if args.output_mutation:
    mut_savepath = args.output_dir + '/mutation/'
    os.makedirs(mut_savepath, exist_ok=True)

    # The checkpoint is identical for every sample, so load it once instead of once
    # per peptide. map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model_eval = Transformer().to(device)
    model_eval.load_state_dict(torch.load(model_file, map_location = device), strict = True)

    for idx in trange(len(predict_data)):
        row = predict_data.iloc[idx]
        peptide: str = row.peptide
        hla: str = row.HLA

        # Only peptides of length 8-14 are mutated; all others are skipped.
        if not 8 <= len(peptide) <= 14:
            continue

        mut_peptides_df = pHLA_mutation_peptides(predict_data, attns, hla = hla, peptide = peptide)
        mut_data, _, _, mut_loader = read_predict_data(mut_peptides_df, batch_size)

        model_eval.eval()
        y_pred, y_prob, _ = eval_step(model_eval, mut_loader, args.threshold, use_cuda)

        mut_data['y_pred'], mut_data['y_prob'] = y_pred, y_prob
        mut_data = mut_data.round({'y_prob': 4})
        mut_data.to_csv(mut_savepath + '{}_{}_mutation.csv'.format(hla, peptide), index = False)
        print('********** {} | {} → # Mutation peptides = {}'.format(hla, peptide, mut_data.shape[0]-1))

        mut_peptides_IEDBfmt = ' '.join(mut_data.mutation_peptide)
        print('If you want to use IEDB tools to predict IC50, please use these format: \n {}'.format(mut_peptides_IEDBfmt))

  0%|          | 0/17 [00:00<?, ?it/s]

No HLA-A*11:01 with 8, Use the overall attention for pepAAtype-peppsition





KeyError: 'sum'

torch.Tensor

# Compress files (压缩文件)

In [12]:
# mut_mb = 25/1024
# fig_mb = 2
# attn_mb = 10/1024
# (mut_mb + fig_mb + attn_mb)*2000 

In [2]:
from utils import make_zip
# Archive the web-server directory into a sibling "<dir>.zip" file. The default is
# the original developer path; set the WEBSERVER_DIR environment variable to run elsewhere.
webserver_dir = os.environ.get('WEBSERVER_DIR', '/home/chujunyi/5_ZY_MHC/webserver')
make_zip(source_dir = webserver_dir, output_filename = webserver_dir.rstrip('/') + '.zip')

In [10]:
# make_zip(source_dir = args.output_dir, output_filename = './results.zip')
# import shutil
# shutil.rmtree(args.output_dir)