<a href="https://colab.research.google.com/github/zhuzihan728/COMP0138-Metal-Binding-Site-Prediction/blob/main/colab_scripts/cluster.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import libs

In [None]:
!pip install biopython

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting biopython
  Downloading biopython-1.81-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython
Successfully installed biopython-1.81


In [None]:
import pandas as pd
from Bio import SeqIO
import json

# Extract datasets

In [None]:
!tar -xvf /content/drive/MyDrive/FYP/miniconda -C /root

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
miniconda/lib/python3.7/site-packages/conda_env/cli/__pycache__/main_config.cpython-37.pyc
miniconda/lib/python3.7/site-packages/conda_env/cli/main_config.py
miniconda/lib/python3.7/site-packages/conda_env/cli/main.py
miniconda/lib/python3.7/site-packages/conda_env/cli/main_vars.py
miniconda/lib/python3.7/site-packages/conda_env/exceptions.py
miniconda/lib/python3.7/site-packages/conda_env/installers/
miniconda/lib/python3.7/site-packages/conda_env/installers/__init__.py
miniconda/lib/python3.7/site-packages/conda_env/installers/conda.py
miniconda/lib/python3.7/site-packages/conda_env/installers/base.py
miniconda/lib/python3.7/site-packages/conda_env/installers/__pycache__/
miniconda/lib/python3.7/site-packages/conda_env/installers/__pycache__/pip.cpython-37.pyc
miniconda/lib/python3.7/site-packages/conda_env/installers/__pycache__/conda.cpython-37.pyc
miniconda/lib/python3.7/site-packages/conda_env/installers/__pycache__/__init__.cpython-37.pyc

In [None]:
!tar -xvf /content/drive/MyDrive/FYP/uniprot_datasets -C /content

ChEBI-IDs_for_metal_binding.tsv
NEG_clustered_rep_seq.fasta
NEG_TRAIN.fasta
POS_TRAIN.fasta
POS_TRAIN_FULL.fasta
POS_TRAIN_FULL.tsv
POS_TRAIN.tsv
filtered_combined.fasta
trimed_combined.fasta


In [None]:
# Build the clustering input: all negatives followed by all positives.
# NOTE(review): plain concatenation assumes NEG_TRAIN.fasta ends with a newline; otherwise its
# last sequence line would be glued to the first POS_TRAIN header.
!cat NEG_TRAIN.fasta	POS_TRAIN.fasta > combined.fasta

In [None]:
# Number of records in the clustering input; also read as a global by check_pos_neg_proportion.
total_len = sum(1 for _record in SeqIO.parse("combined.fasta", "fasta"))
print("Full data set size: ", total_len)

Full data set size:  195450


# Helper functions

In [None]:
def check_metal_per(seqs, metal, anno, metal_count_df):
  """Fraction of all `metal`-binding residues that belong to the accessions in `seqs`.

  Args:
    seqs: iterable of protein accessions to restrict the annotation table to.
    metal: ChEBI ID of the metal, e.g. 'CHEBI:29105'.
    anno: annotation DataFrame with at least 'Accession' and 'ChEBI-ID' columns,
      one row per annotated binding residue.
    metal_count_df: DataFrame with 'ChEBI-ID' and 'count' columns giving the
      reference (denominator) residue count for each metal.

  Returns:
    float: (#rows of `anno` for `metal` whose accession is in `seqs`) / reference count.
  """
  # Count matching rows directly. The previous value_counts().to_frame().reset_index()
  # route looked up an 'index' column, which pandas >= 2.0 no longer produces.
  in_seqs = anno['Accession'].isin(seqs)
  cnt = int((anno.loc[in_seqs, 'ChEBI-ID'] == metal).sum())
  total = int(metal_count_df.loc[metal_count_df['ChEBI-ID'] == metal, 'count'].iloc[0])
  return cnt / total

In [None]:
def check_metal_num(seqs, metal, anno):
  """Number of annotated `metal`-binding residues whose accession is in `seqs`.

  Args:
    seqs: iterable of protein accessions.
    metal: ChEBI ID of the metal, e.g. 'CHEBI:29105'.
    anno: annotation DataFrame with at least 'Accession' and 'ChEBI-ID' columns,
      one row per annotated binding residue.

  Returns:
    int: number of rows of `anno` with ChEBI-ID == metal and Accession in `seqs`
    (0 if there are none).
  """
  # Direct boolean count; the old value_counts().reset_index() version relied on an
  # 'index' column that only exists on pandas < 2.0.
  in_seqs = anno['Accession'].isin(seqs)
  return int((anno.loc[in_seqs, 'ChEBI-ID'] == metal).sum())

In [None]:
def check_metal_specific_residue_proportion(acc_ls, source = 'POS_TRAIN_FULL.tsv'):
  """Print, for every metal in `source`, how many of its binding residues fall in `acc_ls`.

  Args:
    acc_ls: iterable of protein accessions (e.g. one data split).
    source: tab-separated annotation file with 'Accession' and 'ChEBI-ID' columns,
      one row per binding residue.

  Each printed line shows the ChEBI ID, its name (from ChEBI-IDs_for_metal_binding.tsv),
  the residue count inside `acc_ls` and that count divided by the metal's total in `source`.
  Metals are listed from most to least frequent in `source`.
  """
  anno = pd.read_csv(source, sep='\t')
  metal_totals = anno['ChEBI-ID'].value_counts()
  metal_id_name_df = pd.read_csv('ChEBI-IDs_for_metal_binding.tsv', sep='\t')
  # Count residues inside acc_ls once, instead of re-filtering `anno` for every metal.
  subset_counts = anno.loc[anno['Accession'].isin(acc_ls), 'ChEBI-ID'].value_counts()
  for metal, total_num in metal_totals.items():
    metal_name = metal_id_name_df[metal_id_name_df['ChEBI-ID']==metal]['Name'].iloc[0]
    num = int(subset_counts.get(metal, 0))
    total_num = int(total_num)
    print(f'{metal:12}| {metal_name:29} | num: {int(num):6} | %: {num/total_num}')

In [None]:
def write_seq_ls2fasta(file_out, ls, source):
  """Copy the records of FASTA file `source` whose accession is in `ls` into `file_out`.

  The accession is the second '|'-separated field of each record id (UniProt-style
  'sp|ACC|NAME'). Records are written in `source` order, and one progress/error line
  is printed per written record.

  Args:
    file_out: path of the FASTA file to create (overwritten if it exists).
    ls: iterable of accessions to keep.
    source: path of the FASTA file to read.
  """
  # Set membership is O(1); `ls` can be an entire split (tens of thousands of accessions),
  # and a list lookup per record made this quadratic.
  wanted = set(ls)
  with open(file_out, 'w') as f_out:
    for seq_record in SeqIO.parse(source, "fasta"):
      seq_acc = seq_record.id.split('|')[1]
      if seq_acc in wanted:
        r = SeqIO.write(seq_record, f_out, 'fasta')

        if r!=1: 
          print('Error while writing sequence: ' + seq_acc)
        else:
          print(f'writing {seq_acc} to train fasta file.')

In [None]:
def fasta2acc_seq_ls(path):
  """Read a FASTA file and return (accessions, sequences) as two parallel lists.

  The accession is the second '|'-separated field of each record id
  (UniProt-style 'sp|ACC|NAME'); sequences are returned as plain strings.
  """
  records = list(SeqIO.parse(path, "fasta"))
  accessions = [record.id.split('|')[1] for record in records]
  sequences = [str(record.seq) for record in records]
  return accessions, sequences

In [None]:
def check_pos_neg_proportion(ls):
  """Print and return how the accessions in `ls` split into positives and negatives.

  Positives are accessions that also occur in POS_TRAIN_FULL.fasta; all others count as
  negatives. The "proportion over full dataset" line divides by the notebook-global
  `total_len` (number of records in combined.fasta).

  Returns:
    tuple (total_num, pos_num, neg_num, pos_portion, neg_portion)
  """
  total_num = len(ls)
  positive_accs, _ = fasta2acc_seq_ls("POS_TRAIN_FULL.fasta")
  pos_num = len(set(positive_accs) & set(ls))
  neg_num = total_num - pos_num
  pos_portion, neg_portion = pos_num / total_num, neg_num / total_num
  report = (
    f'total seq in the set: {total_num}',
    f'proportion over full dataset: {total_num/total_len}',
    f'pos: {pos_num} %: {pos_portion}',
    f'neg: {neg_num} %: {neg_portion}',
  )
  for line in report:
    print(line)
  return total_num, pos_num, neg_num, pos_portion, neg_portion

In [None]:
def identity_above_threshold(m8file, thres):
  """Split the queries of an m8 (BLAST-tabular style) alignment file by sequence identity.

  Args:
    m8file: path to a headerless, tab-separated file with 12 columns: query, target,
      identity, alignment length, mismatches, gap openings, query start/end,
      target start/end, e-value, bit score.
    thres: identity threshold, compared with a strict '>' against the identity column.

  Returns:
    (seq_above_thres, seq_below_thres, proportion):
      queries with at least one hit whose identity > thres; queries with no such hit;
      and len(seq_above_thres) / number of distinct queries.

  Also prints whether the two groups together account for every query (expected True).
  """
  data = pd.read_csv(m8file, sep="\t", index_col=False, header=None)
  # Every label must be unique: both end-position columns used to be called
  # "end position", which made that label ambiguous.
  data.columns = ["query", "target", "sequence identity", "alignment length", "mismatch",
                  "gap opening", "query domain start position", "query domain end position",
                  "target domain start position", "target domain end position",
                  "evalue", "bit score"]

  seq_above_thres = data[data["sequence identity"] > thres]["query"].unique()
  seq_below_thres = data[~data["query"].isin(seq_above_thres)]["query"].unique()
  all_seq = data["query"].unique()
  proportion = len(seq_above_thres) / len(all_seq)
  print(len(all_seq) == len(seq_above_thres) + len(seq_below_thres))
  return seq_above_thres, seq_below_thres, proportion

In [None]:
def read_fasta(fasta_path, split_char="|", id_field=1):
    '''
        Reads in a FASTA file containing one or more sequences.
        split_char and id_field control identifier extraction: the header (without '>')
        is split on split_char and field id_field is kept,
        e.g. split_char="|" and id_field=1 for SwissProt/UniProt headers.
        '/' and '.' in identifiers are replaced with '_' (they are mis-interpreted when loading h5).
        Sequence lines are joined, whitespace and '-' gaps are removed, residues are
        upper-cased and the non-standard residues U, Z and O are mapped to X.
        Returns a dict {identifier: sequence}; the dict is empty if the file has no records.
        Raises ValueError if sequence data appears before the first header.
    '''
    seqs = dict()
    uniprot_id = None
    with open(fasta_path, 'r') as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # replace tokens that are mis-interpreted when loading h5
                uniprot_id = uniprot_id.replace("/", "_").replace(".", "_")
                seqs[uniprot_id] = ''
            else:
                # remove all white-space chars, drop gaps and cast to upper-case
                seq = ''.join(line.split()).upper().replace("-", "")
                if not seq:
                    # blank line: nothing to append (also tolerated before the first header)
                    continue
                if uniprot_id is None:
                    raise ValueError(f"{fasta_path}: sequence data found before the first '>' header")
                # map non-standard AAs to unknown/X
                seq = seq.replace('U', 'X').replace('Z', 'X').replace('O', 'X')
                seqs[uniprot_id] += seq
    print("Read {} sequences.".format(len(seqs)))
    # An empty file used to crash here with StopIteration from next(iter({})).
    if seqs:
        example_id = next(iter(seqs))
        print("Example:\n{}\n{}".format(example_id, seqs[example_id]))

    return seqs

In [None]:
def dataset_metal_binding_summary(acc_ls, source = 'POS_TRAIN_FULL.tsv'):
  """Print a per-metal binding summary for a set of protein accessions.

  Args:
    acc_ls: sized iterable of protein accessions (e.g. one data split).
    source: tab-separated annotation file with 'Accession' and 'ChEBI-ID' columns,
      one row per binding residue.

  For each metal in the fixed `metals` table below, prints: the number of distinct
  positive proteins (accessions also in POS_TRAIN_FULL.fasta) in `acc_ls` annotated with
  that metal, the number of its binding residues in `acc_ls`, and that residue count as a
  share of all its residues in `source` ('nan' if the metal does not occur in `source`).
  Finally prints the number of accessions in `acc_ls` that are not positives.

  Returns:
    (prot_counter, res_counter): two lists indexed in the order of `metals`.
  """
  total_num = len(acc_ls)
  print(f'total seq in the set: {total_num}')

  all_pos_acc_ls, _ = fasta2acc_seq_ls("POS_TRAIN_FULL.fasta")
  metals = {'CHEBI:29105':0,'CHEBI:18420':1,'CHEBI:49883':2,'CHEBI:29108':3,'CHEBI:29035':4,'CHEBI:60240':5,'CHEBI:24875':6,'CHEBI:190135':7,'CHEBI:23378':8,'CHEBI:29103':9,'CHEBI:49786':10,'CHEBI:29101':11,'CHEBI:29034':12,'CHEBI:30408':13,'CHEBI:29036':14,'CHEBI:29033':15,'CHEBI:21137':16,'CHEBI:49552':17,'CHEBI:48775':18,'CHEBI:48828':19,'CHEBI:21143':20,'CHEBI:25213':21,'CHEBI:47739':22,'CHEBI:16793':23,'CHEBI:177874':24,'CHEBI:60400':25,'CHEBI:49415':26,'CHEBI:60504':27,'CHEBI:49713':28}
  anno = pd.read_csv(source, sep='\t')
  # Total residues per metal in `source`; metals absent from `source` simply have no entry.
  metal_totals = anno['ChEBI-ID'].value_counts()
  metal_id_name_df = pd.read_csv('ChEBI-IDs_for_metal_binding.tsv', sep='\t')
  # Size the counters from the table itself instead of a hard-coded 29.
  prot_counter = [0] * len(metals)
  res_counter = [0] * len(metals)
  pos_acc = set(all_pos_acc_ls).intersection(acc_ls)
  # Residues per metal inside acc_ls, computed once rather than once per metal.
  res_in_set = anno.loc[anno['Accession'].isin(acc_ls), 'ChEBI-ID'].value_counts()
  for i, metal in enumerate(metals):
    metal_name = metal_id_name_df[metal_id_name_df['ChEBI-ID']==metal]['Name'].iloc[0]
    temp = anno[anno['ChEBI-ID'] == metal]
    prot_counter[i] += len(temp[temp['Accession'].isin(pos_acc)]['Accession'].unique())
    res_counter[i] += int(res_in_set.get(metal, 0))
    # int(<empty Series>) used to raise TypeError for a metal missing from `source`.
    total_res_num = int(metal_totals.get(metal, 0))
    share = res_counter[i] / total_res_num if total_res_num else float('nan')
    print(f"{metal:13}|{metal_name:30}|#p: {prot_counter[i]:10}|#residue: {res_counter[i]:6}|%residue/all: {share:5.3}")
  print(f"#non-binding protein: {total_num-len(pos_acc)}")
  return prot_counter, res_counter


In [None]:
def retrieve_json(path):
  """Load the JSON document stored at `path` and return the decoded object."""
  with open(path) as handle:
    return json.load(handle)

# Clustering by MMSEQS

In [None]:
# Shell alias for the activate script of the miniconda tree unpacked earlier
# (assumes $HOME is /root, the directory the miniconda archive was extracted into).
%alias activate $HOME/miniconda/bin/activate

In [None]:
# Shell alias to the MMseqs2 (14.7e284) binary inside the unpacked miniconda package cache.
%alias mmseqs $HOME/miniconda/pkgs/mmseqs2-14.7e284-pl5321hf1761c0_0/bin/mmseqs

In [None]:
activate tutorial

In [None]:
mmseqs

MMseqs2 (Many against Many sequence searching) is an open-source software suite for very fast, 
parallelized protein sequence searches and clustering of huge protein sequence data sets.

Please cite: M. Steinegger and J. Soding. MMseqs2 enables sensitive protein sequence searching for the analysis of massive data sets. Nature Biotechnology, doi:10.1038/nbt.3988 (2017).

MMseqs2 Version: 14.7e284
© Martin Steinegger (martin.steinegger@snu.ac.kr)

usage: mmseqs <command> [<args>]

Easy workflows for plain text input/output
  easy-search       	Sensitive homology search
  easy-cluster      	Slower, sensitive clustering
  easy-linclust     	Fast linear time cluster, less sensitive clustering
  easy-taxonomy     	Taxonomic classification
  easy-rbh          	Find reciprocal best hit

Main workflows for database input/output
  search            	Sensitive homology search
  map               	Map nearly identical sequences
  rbh               	Reciprocal best hit search
  linclust          	F

In [None]:
mmseqs easy-cluster

usage: mmseqs easy-cluster <i:fastaFile1[.gz|.bz2]> ... <i:fastaFileN[.gz|.bz2]> <o:clusterPrefix> <tmpDir> [options]
options:                               
 -c FLOAT                       List matches above this fraction of aligned (covered) residues (see --cov-mode) [0.800]
 --cov-mode INT                 0: coverage of query and target
                                1: coverage of target
                                2: coverage of query
                                3: target seq. length has to be at least x% of query length
                                4: query seq. length has to be at least x% of target length
                                5: short seq. needs to be at least x% of the other seq. length [0]
 --alignment-mode INT           How to compute the alignment:
                                0: automatic
                                1: only score and end_pos
                                2: also start_pos and cov
                                3: also seq.i

In [None]:
mmseqs easy-cluster combined.fasta assembly_clustered tmp --cov-mode 5 -c 0.25 --min-seq-id 0.4 -s 7

Create directory tmp
easy-cluster combined.fasta assembly_clustered tmp --cov-mode 5 -c 0.25 --min-seq-id 0.4 -s 7 

MMseqs Version:                     	14.7e284
Substitution matrix                 	aa:blosum62.out,nucl:nucleotide.out
Seed substitution matrix            	aa:VTML80.out,nucl:nucleotide.out
Sensitivity                         	7
k-mer length                        	0
k-score                             	seq:2147483647,prof:2147483647
Alphabet size                       	aa:21,nucl:5
Max sequence length                 	65535
Max results per query               	20
Split database                      	0
Split mode                          	2
Split memory limit                  	0
Coverage threshold                  	0.25
Coverage mode                       	5
Compositional bias                  	1
Compositional bias                  	1
Diagonal scoring                    	true
Exact k-mer matching                	0
Mask residues                       	1
Mask residues prob

# Reading data

## reading dataset and annotations

acc: protein accessions in combined.fasta

In [None]:
# Accessions of every sequence in the clustering input (sequences themselves are not needed here).
acc = fasta2acc_seq_ls("combined.fasta")[0]

anno: the annotation dataframe loaded from POS_TRAIN.tsv (one row per annotated metal-binding residue).

In [None]:
# Residue-level metal-binding annotations for the positive set.
anno = pd.read_csv('POS_TRAIN.tsv', delimiter='\t')
anno

Unnamed: 0,Accession,Evidence,ChEBI-ID,Position
0,Q8INK9,ECO:0000269,CHEBI:29105,157
1,Q8INK9,ECO:0000269,CHEBI:29105,96
2,Q1QT89,ECO:0000269,CHEBI:18420,263
3,P07327,ECO:0000269,CHEBI:29105,101
4,P07327,ECO:0007744,CHEBI:29105,104
...,...,...,...,...
18038,P62339,ECO:0000269,CHEBI:60240,43
18039,P62339,ECO:0000269,CHEBI:60240,23
18040,P62339,ECO:0007744,CHEBI:60240,32
18041,P62339,ECO:0000269,CHEBI:60240,40


metal_count_df: each metal and the number of residues binding that metal in anno (POS_TRAIN.tsv)

In [None]:
# Binding residues per metal, most frequent first, as a two-column table [ChEBI-ID, count].
metal_count_df = anno['ChEBI-ID'].value_counts().rename_axis('ChEBI-ID').reset_index(name='count')
metal_count_df

Unnamed: 0,ChEBI-ID,count
0,CHEBI:29105,5788
1,CHEBI:29108,4768
2,CHEBI:18420,2140
3,CHEBI:29035,1136
4,CHEBI:24875,919
5,CHEBI:49883,885
6,CHEBI:60240,738
7,CHEBI:23378,561
8,CHEBI:190135,348
9,CHEBI:29103,163


In [None]:
# Class-encoding dictionary stored on Drive.
CLASS_ENCODE_PATH = '/content/drive/MyDrive/FYP/dicts/class_encode.json'
class_enc = retrieve_json(CLASS_ENCODE_PATH)

In [None]:
# Recount each metal's binding residues, keeping only accessions present in combined.fasta (acc).
# Counted in one pass: the previous per-metal check_metal_num calls re-filtered `anno` each time
# and depended on the pandas<2-only 'index' column.
residues_in_acc = anno.loc[anno['Accession'].isin(acc), 'ChEBI-ID'].value_counts()
temp_cnt = [int(residues_in_acc.get(metal_id, 0)) for metal_id in metal_count_df['ChEBI-ID']]
metal_count_df = pd.DataFrame({'ChEBI-ID': metal_count_df['ChEBI-ID'], 'count': temp_cnt})
metal_count_df

Unnamed: 0,ChEBI-ID,count
0,CHEBI:29105,5788
1,CHEBI:29108,4768
2,CHEBI:18420,2140
3,CHEBI:29035,1136
4,CHEBI:24875,919
5,CHEBI:49883,885
6,CHEBI:60240,738
7,CHEBI:23378,561
8,CHEBI:190135,348
9,CHEBI:29103,163


## Retrieve cluster results

In [None]:
# Size of the clustering input (should equal total_len).
# The stale "# 262004" note was removed; it contradicted the actual output.
print('number of prot seqs:', sum(1 for _ in SeqIO.parse("combined.fasta", "fasta")))

number of prot seqs: 195450


In [None]:
# One representative sequence per cluster in the MMseqs2 rep_seq output.
# The stale "# 38717" note was removed; it contradicted the actual output.
print('number of clusters:', sum(1 for _ in SeqIO.parse("assembly_clustered_rep_seq.fasta", "fasta")))

number of clusters: 33071


clusters: dataframe of [rep accession, a protein accession in the rep's cluster]


In [None]:
# MMseqs2 cluster table: one row per sequence, [representative accession, member accession].
clusters = pd.read_csv('assembly_clustered_cluster.tsv', sep='\t', header=None, names=['Rep', 'Accession'])
clusters

Unnamed: 0,Rep,Accession
0,A0A0K0IP23,A0A0K0IP23
1,A0A0K0IP23,P22085
2,A5GNU1,A5GNU1
3,A5GNU1,Q7V5D4
4,A5GNU1,Q0I762
...,...,...
195445,Q94252,Q93789
195446,Q94252,O16956
195447,Q99L85,Q99L85
195448,Q99L85,Q6IUP3


In [None]:
# Distinct representatives; should equal the number of clusters reported above.
n_reps = clusters['Rep'].unique().size
print('number of reps: %d' % n_reps)

number of reps: 33071


In [None]:
# The cluster table should have one row per input sequence.
print('number of clustered seqs: %d' % len(clusters))
assert len(clusters) == total_len, 'cluster table size differs from the number of input sequences'

number of clustered seqs: 195450


In [None]:
# Fraction of sequences that are cluster representatives (#clusters / #sequences).
# Computed from the data instead of hard-coding 33071/195450.
clusters['Rep'].nunique() / len(clusters)

0.16920440010232796

## Extract metal binding information of the clusters

cluster_label: merge clusters with annotations and **keep only annotated (metal-binding) sequences**, so clusters without any metal-binding sequence are dropped entirely. \
dataframe of [rep acc, protein acc, evidence of the protein, metal binding to the protein, position of binding site]

In [None]:
# Outer-join clusters with annotations, then keep only fully populated rows
# (sequences that are both clustered and annotated). Not done in place, so re-running is safe.
merged_clusters_anno = pd.merge(clusters, anno, on = 'Accession', how = "outer")
cluster_label = merged_clusters_anno.dropna(axis=0, how='any')
cluster_label

Unnamed: 0,Rep,Accession,Evidence,ChEBI-ID,Position
1073,B1XI84,P80373,ECO:0000269,CHEBI:29105,31.0
1074,B1XI84,P80373,ECO:0000269,CHEBI:29105,12.0
1075,B1XI84,P80373,ECO:0000269,CHEBI:29105,9.0
1076,B1XI84,P80373,ECO:0000269,CHEBI:29105,26.0
1475,P12453,P04608,ECO:0000269,CHEBI:29105,37.0
...,...,...,...,...,...
209731,Q2UNX8,Q4WW81,ECO:0007744,CHEBI:29105,252.0
209732,Q2UNX8,Q4WW81,ECO:0000269,CHEBI:29105,244.0
209733,Q2UNX8,Q4WW81,ECO:0007744,CHEBI:29105,244.0
209734,Q2UNX8,Q4WW81,ECO:0000269,CHEBI:29105,252.0


## Record the number of binding residues with different metals for clusters and sequences

For each cluster (represented by the rep prot accession) where there are metal-binding sequences, count the number of binding residues with each metal. \

rep_metal: dataframe of [cluster rep, metal type, number of binding residues with the metal in the cluster] \

Used as a record table for future data splitting.

In [None]:
# Binding residues per (cluster representative, metal), as [Rep, ChEBI-ID, count].
rep_metal = cluster_label[['Rep', 'ChEBI-ID']].value_counts().reset_index(name='count')
rep_metal

Unnamed: 0,Rep,ChEBI-ID,count
0,P04355,CHEBI:60240,204
1,O31527,CHEBI:29108,164
2,P63098,CHEBI:29108,160
3,P0DP23,CHEBI:29108,153
4,Q9NQV7,CHEBI:29105,112
...,...,...,...
1792,P9WLU3,CHEBI:29035,1
1793,Q8NK92,CHEBI:29108,1
1794,P05987,CHEBI:18420,1
1795,P9WLU3,CHEBI:18420,1


rep_Allmetal: a temporary dataframe of [rep of a cluster, number of binding residues in the cluster regardless of metal type]

reps: a list containing the accessions of all reps whose cluster contains metal-binding sequences, shuffle it for future data splitting.

In [None]:
# Binding residues per cluster representative, all metals pooled: [Rep, count].
rep_Allmetal = cluster_label['Rep'].value_counts().to_frame().reset_index()
rep_Allmetal.columns = ['Rep', 'count']
# Representatives of all clusters with at least one metal-binding sequence, shuffled with a fixed seed.
# NOTE(review): the pre-shuffle order comes from value_counts (descending count), and its
# tie-breaking may differ between pandas versions. Sorting before shuffling would make the split
# version-independent; left unchanged here to preserve the existing split.
reps = list(rep_Allmetal['Rep'].unique())
import random
random.seed(42)
random.shuffle(reps)
reps[:10]

['O95486',
 'O60494',
 'O64332',
 'P0AES2',
 'O13833',
 'P69380',
 'Q3YW59',
 'O61142',
 'Q94GM9',
 'Q96PN6']

For each metal-binding sequence, count the number of binding residues with each metal. \

seq_metal: dataframe of [prot sequence, metal type, number of binding residues with the metal] \

Used as a record table for future data splitting.

In [None]:
# Binding residues per (protein accession, metal), as [Accession, ChEBI-ID, count].
seq_metal = cluster_label[['Accession', 'ChEBI-ID']].value_counts().reset_index(name='count')
seq_metal

Unnamed: 0,Accession,ChEBI-ID,count
0,O31526,CHEBI:29108,88
1,O31527,CHEBI:29108,76
2,E0VIU9,CHEBI:29105,64
3,Q51817,CHEBI:29108,56
4,O75592,CHEBI:29105,48
...,...,...,...
2547,P15848,CHEBI:29108,1
2548,P0A9G6,CHEBI:18420,1
2549,P00971,CHEBI:18420,1
2550,A5W059,CHEBI:18420,1


seq_Allmetal: a temporary dataframe of [a prot seq, number of binding residues to the seq regardless of metal type]

seqs: a list containing the accessions of all metal-binding protein sequences in cluster_label (i.e. those annotated in anno), shuffled for future data splitting.

In [None]:
# Binding residues per protein accession, all metals pooled: [Accession, count].
seq_Allmetal = cluster_label['Accession'].value_counts().to_frame().reset_index()
seq_Allmetal.columns = ['Accession', 'count']
# Accessions of all annotated (metal-binding) sequences, shuffled.
# This continues the Python `random` state left by the seeded shuffle of `reps` in the previous
# cell, so the resulting order depends on that cell having run first (exactly once).
seqs = list(seq_Allmetal['Accession'].unique())
random.shuffle(seqs)
seqs[:10]

['Q5SK67',
 'E9AE57',
 'Q29437',
 'Q9V099',
 'O43813',
 'Q8GXV5',
 'P23532',
 'Q9X0H1',
 'P51688',
 'P46976']

In [None]:
# Clusters that contain at least one metal-binding sequence.
print(f'number of clusters containing metalloprotein: {len(reps)}')

number of clusters containing metalloprotein: 1588


non_metal_reps: all reps whose cluster contains no metal-binding sequences at all.

In [None]:
# Representatives whose cluster contains no metal-binding sequence, in the order they first
# appear in `clusters`. Membership is checked against a set: the previous list lookup did
# ~1.6k comparisons for each of ~33k representatives.
metal_rep_set = set(reps)
non_metal_reps = [rep for rep in clusters['Rep'].unique() if rep not in metal_rep_set]
print('number of clusters not containing metalloprotein: %d' % len(non_metal_reps))

number of clusters not containing metalloprotein: 31483


# Data splitting

- Goal: split a test dataset from the whole dataset with a test:trainval ratio 1:9, making sure that in the test set, the number of metal-binding residues for each metal type is roughly 1:9 of the trainval set, this also ensures that metal-binding residues for each metal type appears at least once in the test set. \


- procedure
  1. establish a table (metal_count_df) from the annotation table (anno), counted over the accessions in combined.fasta. The table records every metal and the number of residues binding to the metal.
  2. establish a table (metal_count_test) from the same annotation table. The table records every metal and the number of residues binding to the metal * 0.1, rounded.
  3. iterate over the randomly shuffled clusters containing at least one metal-binding protein; for each cluster, decide whether it goes to the trainval or the test set: \
    a. for each metal type
      - retrieve the number of residues binding to it in that cluster.
      - compare the number to the record in metal_count_test; if the number is greater than the record (+10 flexibility), append the cluster to the trainval set.
      - compare the number to the record in metal_count_df; if the number equals the record, meaning all residues binding to the metal belong to this cluster, append the cluster to the trainval set.

    b. If the cluster was not appended to the trainval set in a., append it to the test set and update metal_count_test by subtracting the numbers of newly added residues of each metal type from the corresponding records.
  4. For clusters containing no metal-binding protein at all, randomly assign them to the trainval or test set with probabilities 0.9 and 0.1.
  5. Replace the clusters in the trainval and test sets by the proteins in those clusters.
  6. For each metal, compute #binding residues in the test set / #binding residues in the full set; if the proportion is smaller than 0.99, randomly move a protein that binds the metal from the trainval set to the test set. Repeat until the proportion reaches 0.99.

  **Implementation note:** the cluster-splitting cell below implements steps 2–3 in the complementary form. There, `metal_count_test` holds a *trainval* budget of ⌊0.9·count⌋+1 residues per metal. A cluster goes to the **test** set if, for any of its metals, its residue count exceeds the remaining budget by more than 5, or equals that metal's total count. Otherwise it goes to **trainval**, and its residues are subtracted from the budget.

## cluster splitting

In [None]:
import numpy as np

In [None]:
# NOTE(review): `math` is not used in this cell.
import math

# Per-metal residue budget: floor(0.9 * count) + 1, derived from metal_count_df.
# Clusters assigned to trainval in the loop below are charged against it, so despite its
# name metal_count_test tracks the remaining *trainval* budget.
# NOTE(review): the markdown procedure describes a 0.1 test budget with +10 slack, while this
# code uses a 0.9 trainval budget with +5 slack — confirm which is intended.
metal_count_test = metal_count_df.copy()
metal_count_test['count'] = metal_count_test['count']*0.9//1+1

# metal_test collects the metals seen in clusters assigned to trainval (the name is historical).
metal_test = set()
test = []
trainval = []
for i in reps:
  print(f'----- decide for cluster {i} -----')

  # Per-metal residue counts of cluster i. After reset_index the columns are
  # [index, Rep, ChEBI-ID, count], so .iloc[row, 3] below reads 'count'.
  temp = rep_metal[rep_metal['Rep'] == i].reset_index()
  print(f'cluster {i} has annotations:\n{temp}')
  # flag stays 1 (-> trainval) unless one of the cluster's metals sends it to test.
  flag = 1
  for m in temp['ChEBI-ID']:
    # num: residues of metal m in this cluster; cur: remaining budget for m
    # (column 1 of metal_count_test is 'count'); max_num: all residues of m in metal_count_df.
    num = temp[temp['ChEBI-ID']==m].iloc[0,3]
    cur = metal_count_test.loc[metal_count_test['ChEBI-ID']==m].iloc[0,1]
    max_num = metal_count_df.loc[metal_count_df['ChEBI-ID']==m].iloc[0,1]
    # Send to test if the cluster overshoots the remaining budget by more than 5 residues,
    # or if it holds every residue of metal m.
    # NOTE(review): the markdown says the "holds every residue" case should go to trainval.
    if cur + 5 < num or num == max_num:
      print(f'metal {m} exceed maximum, cluster {i} -> test')
      test.append(i)
      flag = 0
      break
  if flag:
    trainval.append(i)
    print(f'cluster {i} -> trainval')
    # Charge each of the cluster's metals against the remaining budget (it may go negative
    # because of the +5 slack above).
    for n in temp['ChEBI-ID']: 
        print(f"the cluster has metal {n}, number {temp[temp['ChEBI-ID']==n].iloc[0,3]}")
        print(metal_count_test)
        metal_count_test.loc[metal_count_test['ChEBI-ID']==n, 'count'] -= temp[temp['ChEBI-ID']==n].iloc[0,3]
        print(metal_count_test)
        metal_test.add(n)
        print(f"update metal presented in trainval set: {metal_test}")
  print('\n\n')

# Clusters without any metal-binding sequence: a random 10% go to test, the rest to trainval.
# - A dedicated seeded generator makes this reproducible; random.seed(42) above seeds Python's
#   `random` module only, not numpy, so the old np.random.randint draw changed on every run.
# - Sampling without replacement gives exactly 10%. randint drew indices with replacement, so
#   duplicates silently shrank the test share to ~9.5%.
# - A set of indices makes the membership test O(1) instead of scanning an array per cluster.
rng = np.random.default_rng(42)
n_test_non_metal = int(0.1*len(non_metal_reps))
test_idx = set(rng.choice(len(non_metal_reps), size=n_test_non_metal, replace=False).tolist())
for i, v in enumerate(non_metal_reps):
  if i in test_idx:
    test.append(v)
  else:
    trainval.append(v)
# Summaries instead of printing tens of thousands of accessions.
print(f'test clusters: {len(test)}, first 10: {test[:10]}')
print(f'trainval clusters: {len(trainval)}, first 10: {trainval[:10]}')



[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
1    CHEBI:29108  149.0
2    CHEBI:18420   55.0
3    CHEBI:29035    3.0
4    CHEBI:24875    1.0
5    CHEBI:49883   17.0
6    CHEBI:60240   92.0
7    CHEBI:23378   91.0
8   CHEBI:190135   16.0
9    CHEBI:29103   18.0
10   CHEBI:29101   -3.0
11   CHEBI:29034   -3.0
12   CHEBI:49786   -3.0
13   CHEBI:48828    9.0
14   CHEBI:48775   -4.0
15   CHEBI:21137   -3.0
16   CHEBI:29036    5.0
17   CHEBI:49552   -2.0
18   CHEBI:29033    1.0
19   CHEBI:47739    5.0
20  CHEBI:177874    6.0
21   CHEBI:49415    4.0
22   CHEBI:16793    2.0
        ChEBI-ID  count
0    CHEBI:29105   64.0
1    CHEBI:29108  149.0
2    CHEBI:18420   55.0
3    CHEBI:29035    3.0
4    CHEBI:24875    1.0
5    CHEBI:49883   17.0
6    CHEBI:60240   92.0
7    CHEBI:23378   91.0
8   CHEBI:190135   16.0
9    CHEBI:29103   18.0
10   CHEBI:29101   -3.0
11   CHEBI:29034   -3.0
12   CHEBI:49786   -3.0
13   CHEBI:48828    9.0
14   CHEBI:48775   -4.0
15   CHEBI:21137   -3.0
16   CHEBI:29036    5.0

In [None]:
# Number of clusters in the test split.
n_test_clusters = len(test)
print(n_test_clusters)

3148


In [None]:
# Number of clusters in the trainval split.
n_trainval_clusters = len(trainval)
print(n_trainval_clusters)

29923


In [None]:
# trainval and test must not share any cluster representative (asserted instead of eyeballing an empty list).
overlap = set(test) & set(trainval)
assert not overlap, f'clusters present in both splits: {sorted(overlap)[:10]}'

[]

## Helper function 
`check_Allmetal` counts, for each metal type, how many residues in the given clusters (here: the test set) are bound to that metal, and the proportion: #residues bound by the metal in the test set / #residues bound by the metal in the full set.

In [None]:
# for i in cluster_label[cluster_label['Rep'] == 'Q8HUH0']['ChEBI-ID']:
#       print(i)
def check_Allmetal(reps, rep_metal_df=None, metal_counts=None):
  """Per-metal binding-residue count and share for a set of clusters.

  Args:
    reps: cluster representative accessions (e.g. the `test` list). Treated as a set,
      so a representative listed twice is counted once.
    rep_metal_df: table with columns [Rep, ChEBI-ID, count]; defaults to the
      notebook-global `rep_metal`.
    metal_counts: table with columns [ChEBI-ID, count] giving each metal's full-dataset
      residue count; defaults to the notebook-global `metal_count_df`. Its metals define
      which rows are reported (metals missing from it are ignored).

  Returns:
    dict mapping ChEBI-ID -> [#residues in `reps`, #residues in `reps` / full-dataset count].
    Each entry is also printed.
  """
  if rep_metal_df is None:
    rep_metal_df = rep_metal
  if metal_counts is None:
    metal_counts = metal_count_df
  # One vectorised filter + groupby instead of re-scanning rep_metal once per representative.
  in_reps = rep_metal_df[rep_metal_df['Rep'].isin(reps)]
  per_metal = in_reps.groupby('ChEBI-ID')['count'].sum()
  stat = {}
  for metal in metal_counts['ChEBI-ID'].unique():
    num = int(per_metal.get(metal, 0))
    # .iloc[0] instead of int(<Series>), which is deprecated in recent pandas.
    total = int(metal_counts.loc[metal_counts['ChEBI-ID'] == metal, 'count'].iloc[0])
    per = num / total
    stat[metal] = [num, per]
    print(f'metal: {metal}, num: {num}, percentage: {per}')
  return stat

def check_metal(seqs, metal):
  """Fraction of ``metal``'s total count contributed by the proteins in ``seqs``.

  Counts the rows of the module-level ``cluster_label`` whose ``Accession``
  is in ``seqs`` and whose ``ChEBI-ID`` equals ``metal``. It then divides by
  that metal's ``count`` in ``metal_count_df``.

  Args:
    seqs: collection of protein accessions (e.g. the test split).
    metal: a ChEBI-ID string such as 'CHEBI:29105'.

  Returns:
    The fraction as a float; 0.0 when the metal's total count is 0.
  """
  in_seqs = cluster_label['Accession'].isin(seqs)
  # Count matches directly rather than via value_counts().to_frame()
  # .reset_index(). That chain only has an 'index' column on pandas < 2.0;
  # on pandas >= 2.0 the columns are ['ChEBI-ID', 'count'].
  cnt = int((cluster_label.loc[in_seqs, 'ChEBI-ID'] == metal).sum())
  # .iloc[0]: calling int() on a one-element Series is deprecated in pandas.
  total = int(metal_count_df.loc[metal_count_df['ChEBI-ID'] == metal, 'count'].iloc[0])
  return cnt / total if total else 0.0


In [None]:
# Per-metal [count, fraction] for the initial (cluster-level) test split;
# consumed by the refinement step below.
init_stat = check_Allmetal(reps=test)

metal: CHEBI:29105, num: 569, percentage: 0.09830684174153421
metal: CHEBI:29108, num: 468, percentage: 0.09815436241610738
metal: CHEBI:18420, num: 203, percentage: 0.09485981308411215
metal: CHEBI:29035, num: 103, percentage: 0.09066901408450705
metal: CHEBI:24875, num: 83, percentage: 0.09031556039173014
metal: CHEBI:49883, num: 78, percentage: 0.08813559322033898
metal: CHEBI:60240, num: 64, percentage: 0.08672086720867209
metal: CHEBI:23378, num: 50, percentage: 0.08912655971479501
metal: CHEBI:190135, num: 38, percentage: 0.10919540229885058
metal: CHEBI:29103, num: 18, percentage: 0.11042944785276074
metal: CHEBI:29101, num: 8, percentage: 0.056338028169014086
metal: CHEBI:29034, num: 8, percentage: 0.058823529411764705
metal: CHEBI:49786, num: 0, percentage: 0.0
metal: CHEBI:48828, num: 10, percentage: 0.2
metal: CHEBI:48775, num: 0, percentage: 0.0
metal: CHEBI:21137, num: 0, percentage: 0.0
metal: CHEBI:29036, num: 8, percentage: 0.24242424242424243
metal: CHEBI:49552, num: 0

## Replace clusters by their member proteins in the trainval and test sets

In [None]:
# Expand each split from cluster representatives to all member proteins.
test_seqs = clusters.loc[clusters['Rep'].isin(test), 'Accession'].tolist()
trainval_seqs = clusters.loc[clusters['Rep'].isin(trainval), 'Accession'].tolist()
# Show sizes and a short preview instead of dumping every accession.
print(len(test_seqs), test_seqs[:10])
print(len(trainval_seqs), trainval_seqs[:10])


['C0HKX3', 'Q92374', 'Q9CS74', 'O95905', 'A0A2H3CSB7', 'P35220', 'Q6GLP0', 'A4IGI7', 'B7ZC77', 'P26232', 'P30997', 'Q5R416', 'Q61301', 'Q3MHM6', 'Q59I72', 'P35221', 'P26231', 'Q9UI47', 'Q65CL1', 'P90947', 'Q5AL27', 'Q6BP80', 'B7M7P4', 'C4ZXB6', 'B6I5B4', 'B1XB16', 'Q83QJ8', 'B7NRJ0', 'B7N6C8', 'Q31XU9', 'Q8XA72', 'B5Z114', 'B1IVT6', 'P0CI31', 'Q3YZ12', 'B1LNJ7', 'A8A347', 'P0CI32', 'Q0T1X8', 'B7LDD3', 'A7ZPY4', 'P47227', 'P50206', 'P08694', 'P72220', 'Q46381', 'Q9WXG7', 'Q7N4V7', 'P08088', 'P47230', 'O04547', 'Q50036', 'P9WHF9', 'A0R4C9', 'Q7TX80', 'Q72JV2', 'D4GYM0', 'P13597', 'Q00238', 'Q95132', 'Q9UMF0', 'Q60625', 'Q28730', 'Q5NKU6', 'P32942', 'Q28125', 'Q5NKV4', 'Q5NKV6', 'Q5NKV9', 'Q28806', 'P05362', 'P33729', 'Q38042', 'A0A1U9YI02', 'Q0CS62', 'Q4WMJ5', 'Q2UPB3', 'Q2UPA6', 'A5GTL6', 'Q8DLG1', 'Q5N128', 'Q31KU1', 'B0CAE7', 'B1WUV7', 'B0JRV0', 'B7K970', 'B7JZG7', 'P73554', 'B1XL12', 'Q8YZT2', 'Q3M9A3', 'A9KNV9', 'C4Z030', 'Q08334', 'Q61190', 'Q13428', 'O08784', 'Q58997', 'P52986', '

Proportion of proteins in the test set (#proteins in test set / #proteins in full set); this should be roughly 0.1.

In [None]:
# Rows of `clusters` whose accession landed in the test split.
in_test = clusters['Accession'].isin(test_seqs)
cnt = int(in_test.sum())

cnt / len(clusters)

0.07944742900997698

In [None]:
# No protein may appear in both the test and trainval splits.
seq_overlap = set(test_seqs) & set(trainval_seqs)
assert not seq_overlap, f'{len(seq_overlap)} proteins are in both test and trainval'
list(seq_overlap)  # check no intersection

[]

In [None]:
# check if the two sets contain all data: compare the combined split size with
# the number of rows in `clusters`. They are equal when every cluster is
# assigned to one split (assuming one row per protein in `clusters`; confirm).
n_split = len(test_seqs) + len(trainval_seqs)
n_split, len(clusters)

177794

In [None]:
# Overlap check against `test_` (defined in an earlier cell; presumably
# another test accession list, so confirm). Looking each trainval protein up
# in a set is O(1); scanning the `test_` list each time is quadratic overall.
test_lookup = set(test_)
[i for i in trainval_seqs if i in test_lookup]

[]

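# Refine the test set
Make sure every metal label appears in the test set: for each metal with no binding residues in the initial test split, move one trainval protein that binds it into the test set.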

In [None]:
# Start from the cluster-based split, then patch it so every metal label has
# at least one binding protein in the test set.
new_test_seqs = test_seqs.copy()
new_trainval_seqs = trainval_seqs.copy()
# Set mirror of new_trainval_seqs for O(1) membership checks; kept in sync below.
trainval_lookup = set(new_trainval_seqs)


for metal, (num, per) in init_stat.items():
  print(f'check for {metal}')
  print(f'initial number is {num}, percentage is {per}')
  # init_stat values are [num, percentage], and a metal is missing from the
  # test set exactly when num == 0. (Requiring percentage == 0 together with
  # a non-zero num can never hold, so such a test would never move anything.)
  if num == 0:
    # Accessions annotated with this metal in seq_metal.
    binders = set(seq_metal.loc[seq_metal['ChEBI-ID'] == metal, 'Accession'])
    # `seqs` comes from an earlier cell (presumably all protein accessions;
    # confirm). Its order decides which protein gets moved.
    for seq in seqs:
      if seq in trainval_lookup and seq in binders:
        new_trainval_seqs.remove(seq)
        trainval_lookup.discard(seq)
        new_test_seqs.append(seq)
        print(f'{seq} -> test')
        new_per = check_metal(new_test_seqs, metal)
        print(f'updated percentage is {new_per}')
        break

check for CHEBI:29105
initial number is 573, percentage is 0.09899792674498964
check for CHEBI:29108
initial number is 471, percentage is 0.09878355704697987
check for CHEBI:18420
initial number is 208, percentage is 0.09719626168224299
check for CHEBI:29035
initial number is 108, percentage is 0.09507042253521127
check for CHEBI:24875
initial number is 88, percentage is 0.0957562568008705
check for CHEBI:49883
initial number is 84, percentage is 0.09491525423728814
check for CHEBI:60240
initial number is 68, percentage is 0.0921409214092141
check for CHEBI:23378
initial number is 60, percentage is 0.10695187165775401
check for CHEBI:190135
initial number is 29, percentage is 0.08333333333333333
check for CHEBI:29103
initial number is 24, percentage is 0.147239263803681
check for CHEBI:29101
initial number is 22, percentage is 0.15492957746478872
check for CHEBI:29034
initial number is 20, percentage is 0.14705882352941177
check for CHEBI:49786
initial number is 18, percentage is 0.233

In [None]:
# The refined splits must still be disjoint after moving proteins to test.
new_seq_overlap = set(new_test_seqs) & set(new_trainval_seqs)
assert not new_seq_overlap, f'{len(new_seq_overlap)} proteins are in both refined splits'
list(new_seq_overlap)  # check no intersection

[]

## Final proportions

In [None]:
# Metal-binding summary of the pre-refinement trainval proteins; the trailing
# ';' suppresses display of the call's return value.
dataset_metal_binding_summary(trainval_seqs, source='POS_TRAIN_FULL.tsv');

total seq in the set: 162266
CHEBI:29105  |Zn(2+)                        |#p:        630|#residue:   5777|%residue/all: 0.0432
CHEBI:18420  |Mg(2+)                        |#p:        454|#residue:   2160|%residue/all: 0.0242
CHEBI:49883  |[4Fe-4S] cluster              |#p:         96|#residue:    832|%residue/all: 0.0161
CHEBI:29108  |Ca(2+)                        |#p:        346|#residue:   4647|%residue/all:  0.11
CHEBI:29035  |Mn(2+)                        |#p:        164|#residue:   1072|%residue/all: 0.0478
CHEBI:60240  |a divalent metal cation       |#p:         74|#residue:    664|%residue/all: 0.0365
CHEBI:24875  |Fe cation                     |#p:        149|#residue:    945|%residue/all: 0.0542
CHEBI:190135 |[2Fe-2S] cluster              |#p:         47|#residue:    327|%residue/all: 0.0362
CHEBI:23378  |Cu cation                     |#p:         66|#residue:    434|%residue/all: 0.0587
CHEBI:29103  |K(+)                          |#p:         21|#residue:    155|%residue/all:

In [None]:
# Metal-binding summary of the pre-refinement test proteins; the trailing
# ';' suppresses display of the call's return value.
dataset_metal_binding_summary(test_seqs, source='POS_TRAIN_FULL.tsv');

total seq in the set: 15528
CHEBI:29105  |Zn(2+)                        |#p:         80|#residue:    680|%residue/all: 0.00508
CHEBI:18420  |Mg(2+)                        |#p:         54|#residue:    263|%residue/all: 0.00294
CHEBI:49883  |[4Fe-4S] cluster              |#p:         14|#residue:    102|%residue/all: 0.00198
CHEBI:29108  |Ca(2+)                        |#p:         53|#residue:    522|%residue/all: 0.0124
CHEBI:29035  |Mn(2+)                        |#p:         20|#residue:    119|%residue/all: 0.0053
CHEBI:60240  |a divalent metal cation       |#p:          5|#residue:     92|%residue/all: 0.00506
CHEBI:24875  |Fe cation                     |#p:         14|#residue:    103|%residue/all: 0.0059
CHEBI:190135 |[2Fe-2S] cluster              |#p:          6|#residue:     54|%residue/all: 0.00598
CHEBI:23378  |Cu cation                     |#p:         11|#residue:    129|%residue/all: 0.0174
CHEBI:29103  |K(+)                          |#p:          1|#residue:     12|%residue

In [None]:
# Rows of `clusters` whose accession is in the refined test split.
in_new_test = clusters['Accession'].isin(new_test_seqs)
cnt = int(in_new_test.sum())

cnt / len(clusters)

NameError: ignored

In [None]:
# Size of the refined test split.
n_new_test = len(new_test_seqs)
n_new_test

16838

In [None]:
# Size of the refined trainval split.
n_new_trainval = len(new_trainval_seqs)
n_new_trainval

178612

# Write trainval and test sets in FASTA format
TEST_POS_NEG1.fasta: test proteins \
TRAIN_POS_NEG1.fasta: trainval proteins

In [None]:
# Metal-binding summary of the refined test proteins; the trailing ';'
# suppresses display of the call's return value.
dataset_metal_binding_summary(new_test_seqs, source='POS_TRAIN_FULL.tsv');

total seq in the set: 16838
CHEBI:29105  |Zn(2+)                        |#p:         86|#residue:    738|%residue/all: 0.00552
CHEBI:18420  |Mg(2+)                        |#p:         68|#residue:    286|%residue/all: 0.0032
CHEBI:49883  |[4Fe-4S] cluster              |#p:         16|#residue:    110|%residue/all: 0.00213
CHEBI:29108  |Ca(2+)                        |#p:         31|#residue:    575|%residue/all: 0.0136
CHEBI:29035  |Mn(2+)                        |#p:         22|#residue:    140|%residue/all: 0.00624
CHEBI:60240  |a divalent metal cation       |#p:         12|#residue:    102|%residue/all: 0.00561
CHEBI:24875  |Fe cation                     |#p:         17|#residue:    115|%residue/all: 0.00659
CHEBI:190135 |[2Fe-2S] cluster              |#p:          8|#residue:     51|%residue/all: 0.00564
CHEBI:23378  |Cu cation                     |#p:          9|#residue:     78|%residue/all: 0.0105
CHEBI:29103  |K(+)                          |#p:          3|#residue:     24|%residu

In [None]:
# Write the refined test proteins to FASTA. The third argument is the FASTA
# the helper reads sequences from (presumably all positive+negative proteins).
COMBINED_FASTA = 'combined.fasta'
file_out = 'TEST_POS_NEG1.fasta'
write_seq_ls2fasta(file_out, new_test_seqs, COMBINED_FASTA)

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
writing Q5UQ35 to train fasta file.
writing Q5UQC3 to train fasta file.
writing Q5UQG2 to train fasta file.
writing Q5UQW2 to train fasta file.
writing Q5UR69 to train fasta file.
writing Q5UX65 to train fasta file.
writing Q5UYF3 to train fasta file.
writing Q5UYQ8 to train fasta file.
writing Q5UZ63 to train fasta file.
writing Q5UZW0 to train fasta file.
writing Q5V0S8 to train fasta file.
writing Q5V1N9 to train fasta file.
writing Q5V2D3 to train fasta file.
writing Q5V2S1 to train fasta file.
writing Q5V3P6 to train fasta file.
writing Q5V474 to train fasta file.
writing Q5V4R4 to train fasta file.
writing Q5V518 to train fasta file.
writing Q5V5G2 to train fasta file.
writing Q5VVJ2 to train fasta file.
writing Q5W283 to train fasta file.
writing Q5WBJ6 to train fasta file.
writing Q5WCF2 to train fasta file.
writing Q5WDF8 to train fasta file.
writing Q5WDH1 to train fasta file.
writing Q5WDX3 to train fasta file.
writing Q5WDZ8 to train

In [None]:
# Metal-binding summary of the refined trainval proteins; the trailing ';'
# suppresses display of the call's return value.
dataset_metal_binding_summary(new_trainval_seqs, source='POS_TRAIN_FULL.tsv');

total seq in the set: 246025
CHEBI:29105  |Zn(2+)                        |#p:      22505|#residue: 106903|%residue/all: 0.799
CHEBI:18420  |Mg(2+)                        |#p:      26076|#residue:  71517|%residue/all:   0.8
CHEBI:49883  |[4Fe-4S] cluster              |#p:       8541|#residue:  41221|%residue/all: 0.799
CHEBI:29108  |Ca(2+)                        |#p:       4584|#residue:  33716|%residue/all: 0.799
CHEBI:29035  |Mn(2+)                        |#p:       4002|#residue:  17914|%residue/all: 0.798
CHEBI:60240  |a divalent metal cation       |#p:       3582|#residue:  14543|%residue/all: 0.799
CHEBI:24875  |Fe cation                     |#p:       3965|#residue:  13948|%residue/all: 0.799
CHEBI:190135 |[2Fe-2S] cluster              |#p:       2218|#residue:   7218|%residue/all: 0.799
CHEBI:23378  |Cu cation                     |#p:       1160|#residue:   5905|%residue/all: 0.799
CHEBI:29103  |K(+)                          |#p:       1358|#residue:   5002|%residue/all: 0.798
C

In [None]:
# Write the refined trainval proteins to FASTA. The third argument is the
# FASTA the helper reads sequences from (presumably all positive+negative proteins).
COMBINED_FASTA = 'combined.fasta'
file_out = 'TRAIN_POS_NEG1.fasta'
write_seq_ls2fasta(file_out, new_trainval_seqs, COMBINED_FASTA)

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
writing Q97QS2 to train fasta file.
writing Q97QT2 to train fasta file.
writing Q97QW0 to train fasta file.
writing Q97QW8 to train fasta file.
writing Q97R22 to train fasta file.
writing Q97R24 to train fasta file.
writing Q97R31 to train fasta file.
writing Q97R46 to train fasta file.
writing Q97R66 to train fasta file.
writing Q97RC6 to train fasta file.
writing Q97RI6 to train fasta file.
writing Q97RS9 to train fasta file.
writing Q97RW2 to train fasta file.
writing Q97S25 to train fasta file.
writing Q97S28 to train fasta file.
writing Q97S34 to train fasta file.
writing Q97S93 to train fasta file.
writing Q97SC7 to train fasta file.
writing Q97SD7 to train fasta file.
writing Q97SR2 to train fasta file.
writing Q97SR4 to train fasta file.
writing Q97T27 to train fasta file.
writing Q97T80 to train fasta file.
writing Q97T98 to train fasta file.
writing Q97TX9 to train fasta file.
writing Q97TZ9 to train fasta file.
writing Q97U21 to train

In [None]:
# Bundle every *1.fasta in the working directory (the two split FASTA files
# written above) into an uncompressed tar archive named `full_data_split`
# (note: no .tar extension).
!tar cvf full_data_split *1.fasta

TEST_POS_NEG1.fasta
TRAIN_POS_NEG1.fasta


# Filter negative data

In [None]:
# Count the records in the clustered representative FASTA (one record per
# cluster, as the printed label assumes) without materialising them in a list.
n_clusters = sum(1 for _ in SeqIO.parse("assembly_clustered_rep_seq.fasta", "fasta"))
print('number of clusters:', n_clusters)

number of clusters: 33071


In [None]:
# Keep the accessions from POS_TRAIN.fasta; the second return value
# (presumably the sequences) is unused. It is not bound to `_` because
# IPython stops updating its last-output variable once the user assigns `_`.
acc_train, _train_seqs_unused = fasta2acc_seq_ls("POS_TRAIN.fasta")

In [None]:
# Accumulator for the accessions kept by the negative-data filter.
neg_filtered = list()

In [None]:
# Collect the accessions to keep: every member protein of each cluster in
# `reps` (in `reps` order), followed by every entry of `non_metal_reps`.
# The list is rebuilt here so re-running the cell does not append duplicates.
neg_filtered = []
# Map each representative to its member accessions once (O(N)) instead of
# re-scanning `clusters` for every rep. groupby preserves row order within
# each group, so the resulting list order matches a per-rep filter.
members_by_rep = {
    rep: grp.tolist()
    for rep, grp in clusters.groupby('Rep', sort=False)['Accession']
}
for rep in reps:
  neg_filtered.extend(members_by_rep.get(rep, []))
neg_filtered.extend(non_metal_reps)

In [None]:
# Number of accessions kept by the filter.
n_neg_filtered = len(neg_filtered)
n_neg_filtered

40368

In [None]:
# Save the kept accessions, one per line.
with open('/content/filtered_seqs.txt', 'w') as out:
  out.writelines(acc + '\n' for acc in neg_filtered)

In [None]:
# Load the trainval list, one entry per line. Note: this rebinds `trainval`,
# which earlier in the notebook held cluster representatives.
with open('/content/trainval.txt', 'r') as f:
  # rstrip('\n') rather than [:-1]: slicing chops the last character of the
  # final entry when the file does not end with a newline.
  trainval = [line.rstrip('\n') for line in f]

In [None]:
# Trainval entries that survived the filter.
neg_filtered_set = set(neg_filtered)
trainval_filtered = neg_filtered_set.intersection(trainval)

In [None]:
# Trainval size before filtering.
n_trainval = len(trainval)
n_trainval

177794

In [None]:
# Trainval size after filtering.
n_trainval_filtered = len(trainval_filtered)
n_trainval_filtered

36457

In [None]:
# Load the test list, one entry per line. Note: this rebinds `test`, which
# earlier in the notebook held cluster representatives.
with open('/content/test.txt', 'r') as f:
  # rstrip('\n') rather than [:-1]: slicing chops the last character of the
  # final entry when the file does not end with a newline.
  test = [line.rstrip('\n') for line in f]

In [None]:
# Test entries that survived the filter.
neg_filtered_set = set(neg_filtered)
test_filtered = neg_filtered_set.intersection(test)

In [None]:
# Test size before filtering.
n_test = len(test)
n_test

17656

In [None]:
# Test size after filtering.
n_test_filtered = len(test_filtered)
n_test_filtered

3911

In [None]:
# Save the filtered trainval entries, one per line. Sorting makes the file
# reproducible: iteration order of a set of strings depends on per-process
# hash randomization, so it can change between kernel sessions.
with open('/content/trainval_filtered40.txt', 'w') as f:
  for i in sorted(trainval_filtered):
    f.write(i + '\n')

In [None]:
# Save the filtered test entries, one per line. Sorting makes the file
# reproducible: iteration order of a set of strings depends on per-process
# hash randomization, so it can change between kernel sessions.
with open('/content/test_filtered40.txt', 'w') as f:
  for i in sorted(test_filtered):
    f.write(i + '\n')