# Overlap
Some sequences from the training data can significantly overlap with data from the kinases dataset. Let's fix this data leakage.

## Prepare data for mmseqs2
Extract sequences from the training data and from the kinases data

In [2]:
# Read the training data sequences and input them into a .fasta file
import csv 
with open('/home/vit/Projects/LBS-pLM/data/full-LIGYSIS/full-train.txt', 'r') as f:
    reader = csv.reader(f, delimiter=';')
    with open('/home/vit/Projects/LBS-pLM/data/filtering/sequences.fasta', 'w') as fasta_file:
        for row in reader:
            seq_id = row[0]
            sequence = row[4]
            fasta_file.write(f'>{seq_id}-TRAIN\n{sequence}\n')

### REMARK
We manually removed:
```
6M11,A
6M11,B
6M12,A
6M12,B
6M13,A
6M13,B
```
from `kinase_pdb_chain_list.csv`.

In [10]:
# Read each PDB ID, take the sequence from the MMCIF file and input it into a .fasta file
CIF_FILES_PATH = '/home/vit/Projects/deeplife-project/data/cif_files'

mapping = {'Aba': 'A', 'Ace': 'X', 'Acr': 'X', 'Ala': 'A', 'Aly': 'K', 'Arg': 'R', 'Asn': 'N', 'Asp': 'D', 'Cas': 'C',
           'Ccs': 'C', 'Cme': 'C', 'Csd': 'C', 'Cso': 'C', 'Csx': 'C', 'Cys': 'C', 'Dal': 'A', 'Dbb': 'T', 'Dbu': 'T',
           'Dha': 'S', 'Gln': 'Q', 'Glu': 'E', 'Gly': 'G', 'Glz': 'G', 'His': 'H', 'Hse': 'S', 'Ile': 'I', 'Leu': 'L',
           'Llp': 'K', 'Lys': 'K', 'Men': 'N', 'Met': 'M', 'Mly': 'K', 'Mse': 'M', 'Nh2': 'X', 'Nle': 'L', 'Ocs': 'C',
           'Pca': 'E', 'Phe': 'F', 'Pro': 'P', 'Ptr': 'Y', 'Sep': 'S', 'Ser': 'S', 'Thr': 'T', 'Tih': 'A', 'Tpo': 'T',
           'Trp': 'W', 'Tyr': 'Y', 'Unk': 'X', 'Val': 'V', 'Ycm': 'C', 'Sec': 'U', 'Pyl': 'O', 'Mhs': 'H', 'Snm': 'S',
           'Mis': 'S', 'Seb': 'S', 'Hic': 'H', 'Fme': 'M', 'Asb': 'D', 'Sah': 'C', 'Smc': 'C', 'Tpq': 'Y', 'Onl': 'X',
           'Tox': 'W', '5x8': 'X', 'Ddz': 'A'}


def three_to_one(three_letter_code):
    if three_letter_code[0].upper() + three_letter_code[1:].lower() not in mapping:
        return 'X'
    return mapping[three_letter_code[0].upper() + three_letter_code[1:].lower()]

def get_sequence(pdb_id: str, chain_id: str) -> str:
    """
    Get the amino acid sequence of a specific chain from a PDB structure.
    """
    import biotite.database.rcsb as rcsb
    import biotite.structure.io.pdbx as pdbx
    from biotite.structure.io.pdbx import get_structure
    from biotite.structure import get_residues

    cif_file_path = rcsb.fetch(pdb_id, "cif", CIF_FILES_PATH)
    cif_file = pdbx.CIFFile.read(cif_file_path)
    
    protein = get_structure(cif_file, model=1)
    protein = protein[(protein.atom_name == "CA") 
                        & (protein.element == "C") 
                        & (protein.chain_id == chain_id) ]
    _, residue_types = get_residues(protein)

    sequence = ''
    for i in range(len(residue_types)):
        amino_acid = three_to_one(residue_types[i])

        sequence += amino_acid

    return sequence

skip = True
with open('/home/vit/Projects/LBS-pLM/data/kinase_pdb_chain_list.csv', 'r') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader)  # Skip header
    with open('/home/vit/Projects/LBS-pLM/data/filtering/sequences.fasta', 'a') as fasta_file:
        for row in reader:
            pdb_id = row[0]
            chain_id = row[1]

            if pdb_id  == '6T28':  # Skip problematic structure
                skip = False
                continue
            if skip:
                continue

            print(f'Processing {pdb_id} chain {chain_id}')
            try:
                sequence = get_sequence(pdb_id, chain_id)
            except Exception as e:
                print(f'Error processing {pdb_id} chain {chain_id}: {e}')
                continue
            fasta_file.write(f'>{pdb_id}{chain_id}-TEST\n{sequence}\n')

Processing 6T29 chain A
Error processing 6T29 chain A: index 0 is out of bounds for axis 0 with size 0
Processing 6T2W chain A
Processing 6T41 chain A
Processing 6T6A chain A
Processing 6T6A chain B
Processing 6T6A chain C
Processing 6T6D chain A
Processing 6T6D chain B
Processing 6T6D chain C
Processing 6T6D chain D
Processing 6T6F chain A
Processing 6T6F chain B
Processing 6T8N chain A
Processing 6T8N chain B
Processing 6T8X chain A
Processing 6T8X chain B
Processing 6T8X chain C
Processing 6T8X chain D
Processing 6T8X chain E
Processing 6T8X chain F
Processing 6TCA chain B
Processing 6TCA chain D
Processing 6TCA chain F
Processing 6TCA chain H
Processing 6TCU chain A
Processing 6TD3 chain B
Processing 6TD3 chain E
Processing 6TD3 chain H
Processing 6TE2 chain A
Processing 6TEI chain A
Processing 6TEI chain B
Processing 6TEW chain A
Processing 6TFP chain A
Processing 6TFP chain B
Processing 6TFP chain C
Processing 6TFP chain D
Processing 6TFP chain E
Processing 6TFU chain A
Processin

## CATION: MMSEQS2  
Here, the `run-mmseq.sh` script needs to be run. Here, the `min-seq-id` was set to 0.3:

```
bash ./run-mmseq.sh /home/vit/Projects/LBS-pLM/data/filtering 0.3
```

In [11]:
import csv

cluster_dict = {}
with open('/home/vit/Projects/LBS-pLM/data/filtering/clusterRes_cluster.tsv', 'r') as f:
    reader = csv.reader(f, delimiter='\t')
    cluster_dict = {}
    for row in reader:
        cluster_id = row[0]
        sequence_id = row[1]
        if cluster_id not in cluster_dict:
            cluster_dict[cluster_id] = []
        cluster_dict[cluster_id].append(sequence_id)

In [36]:
overlapping_train_sequences = []
overall_count = 0
for cluster_id, sequences in cluster_dict.items():
    sequences_from_train = [seq.split('-')[0] for seq in sequences if seq.endswith('-TRAIN')]
    if any(seq.endswith('-TEST') for seq in sequences):
        if len(sequences_from_train) > 0:
            overall_count += sum([seq.endswith('-TEST') for seq in sequences])
            overlapping_train_sequences.extend(sequences_from_train)

print(f'Number of overlapping training sequences: {len(overlapping_train_sequences)}')
print('Overlapping training sequences:')
print(overlapping_train_sequences)
print(f'Overall overlapping test sequences: {overall_count}')

Number of overlapping training sequences: 57
Overlapping training sequences:
['P34947', 'P32298', 'P43250', 'Q13976', 'Q13237', 'Q14012', 'Q16644', 'Q96S44', 'P11309', 'Q9UQB9', 'Q96GD4', 'P17612', 'P23458', 'O60674', 'P52333', 'Q9UHD2', 'Q9UIK4', 'O14920', 'O15111', 'Q5S007', 'Q9P1W9', 'Q05823', 'Q15208', 'Q99986', 'P68400', 'P19784', 'P54646', 'Q13131', 'Q02750', 'Q00532', 'P43403', 'Q13557', 'Q13555', 'Q9UQM7', 'Q8WU08', 'Q15759', 'O15264', 'Q16539', 'P53778', 'P27361', 'Q8TDX7', 'P49841', 'P24941', 'P06493', 'Q00535', 'Q00534', 'P31751', 'P31749', 'P48729', 'P49674', 'P07948', 'P06239', 'P09769', 'P41240', 'Q13882', 'P06241', 'P12931']
Overall overlapping test sequences: 3645


# Create filtered dataset
Loop over the dataset CSV and filter it. Save the filtered dataset.

In [37]:
with open('/home/vit/Projects/LBS-pLM/data/full-LIGYSIS/full-train.txt', 'r') as f:
    reader = csv.reader(f, delimiter=';')
    with open('/home/vit/Projects/LBS-pLM/data/filtered-LIGYSIS/filtered-train.txt', 'w') as out_f:
        for row in reader:
            seq_id = row[0]
            if seq_id not in overlapping_train_sequences:
                out_f.write(';'.join(row) + '\n')