## SPEACH_AF_SCAN
#### Sampling Protein Ensembles and Conformational Heterogeneity with Alphafold2
***
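This notebook takes as input an MSA (a3m format) and a protein structure (PDB format) generated by an initial ColabFold run, and generates modified MSAs by alanine mutagenesis. The region between the first and last residue whose CA B-factor is above the mean is divided into sliding windows. For each window, every residue that takes part in an atom-atom contact between the window and a residue more than four positions outside it is mutated to alanine in every sequence of the MSA.
Requires: Biopython (`Bio`), NumPy, and the standard-library `copy` module.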

Inputs needed:
> the a3m file and the PDB file (full path including the file name) <br>
> the output directory (must already exist and must end with a trailing slash) <br>
> the base file name for the output MSAs

Adjustable variables:
> the size of the sliding window (in residues) <br>
> the number of identical MSA copies written for each sliding window <br>
> the atom-atom distance cutoff (in angstroms) used to identify interacting residues <br>

In [None]:
a3m_filename = ''         # path to the input MSA (a3m format)
pdb_filename = ''         # path to the input structure (PDB format)
output_dir = ''           # output directory; must already exist and end with '/' (paths are built by string concatenation)
output_base = ''          # base file name for the output .txt and .a3m files
window = 10               # size of the sliding window, in residues
no_out = 3                # number of identical MSA copies written for each window
proximity = 4             # atom-atom distance cutoff (angstroms) that counts as a contact

In [None]:
from Bio import SeqIO
from Bio.Seq import MutableSeq
from Bio.Seq import Seq
import Bio.PDB as BP
import numpy as np
import copy

# read in pdb: only the first chain of the first model is used
parser = BP.PDBParser()
data = parser.get_structure('mod', pdb_filename)
model = data.get_models()
models = list(model)
chains = list(models[0].get_chains())
residue = list(chains[0].get_residues())
nor = len(residue)  # number of residues in that chain

# Get all atom positions together with the residue number of every atom.
# Both arrays share the atom order: row i of `positions` belongs to residue_list[i].
# Values are collected in Python lists and converted once, because np.append inside
# the loop copies the whole array on every call (quadratic in the number of atoms).
positions = []
atom_residue_numbers = []
for res in residue:
    res_number = res.get_full_id()[3][1]  # residue number from the residue id
    for atom in res.get_atoms():
        atom_residue_numbers.append(res_number)
        positions.append(atom.get_vector().get_array())
residue_list = np.array(atom_residue_numbers, dtype=int)
# (n_atoms, 3) coordinate array; reshape keeps the 2-D shape even with zero atoms
positions = np.array(positions, dtype=float).reshape(-1, 3)

# All-atom Euclidean distance matrix: pdist[i, j] is the distance between atom i and
# atom j, in the same units as the coordinates (angstroms for PDB input).
coords = np.asarray(positions)
nop = len(positions)
pdist = np.zeros([nop, nop])
for row, xyz in enumerate(coords):
    delta = xyz - coords
    pdist[row, :] = np.sqrt(delta[:, 0] ** 2 + delta[:, 1] ** 2 + delta[:, 2] ** 2)

# Determine the residues to scan over and the interacting partners, and write them to file.
#
# Scan region: from the first to the last residue whose CA B-factor is above the mean
# CA B-factor (for ColabFold models the B-factor column presumably holds pLDDT).
# The region is cut into windows of `window` residue numbers plus one final remainder
# window. For each window, record every residue involved in an atom-atom contact
# (distance < proximity) between the window and a residue more than
# `neighbor_exclusion` positions outside the window. res_use[p] holds the list for
# window p; the same lists are written to <output_base>.txt.
neighbor_exclusion = 4  # partners within this many residue numbers of the window are ignored


def _window_contacts(lo, hi, atom_res, dist, cutoff, exclude):
    """Return the sorted residue numbers involved in long-range contacts of window [lo, hi].

    lo, hi   -- first and last residue number of the window (both inclusive)
    atom_res -- residue number of every atom, in the row order of `dist`
    dist     -- all-atom distance matrix
    cutoff   -- contact distance (strict <), same units as `dist`
    exclude  -- partners numbered within [lo - exclude, hi + exclude] are ignored

    For every qualifying contact, both the window residue and its partner are reported.
    An empty list is returned when no atom lies in the window.
    """
    atom_res = np.asarray(atom_res)
    # select by residue number, so all atoms of `hi` are included and numbering gaps are tolerated
    in_window = np.flatnonzero((atom_res >= lo) & (atom_res <= hi))
    rows, cols = np.nonzero(dist[in_window] < cutoff)
    partner = atom_res[cols]
    far = (partner < lo - exclude) | (partner > hi + exclude)
    involved = np.concatenate([atom_res[in_window[rows[far]]], partner[far]])
    return sorted({int(r) for r in involved})


b_factor = []
resindex = []
for res in residue:
    if res.has_id("CA"):
        b_factor.append(res["CA"].get_bfactor())
        resindex.append(res.get_full_id()[3][1])
if not b_factor:
    raise ValueError('no residue with a CA atom in {}'.format(pdb_filename))
b_factor = np.array(b_factor)
mb = np.mean(b_factor)
test = np.where(b_factor > mb)[0]  # positions within resindex, not residue numbers
if test.size == 0:
    raise ValueError('no CA B-factor lies above the mean ({}); nothing to scan'.format(mb))
span = test[-1] - test[0]
start = resindex[test[0]]
res_no, res_rem = divmod(span, window)

# Full windows share their boundary residue (max_res of one window is min_res of the
# next). The remainder window starts right after the last full window, or at `start`
# when there is none, and always ends at start + span. res_use therefore always has
# res_no + 1 entries.
windows = [(start + p * window, start + p * window + window) for p in range(res_no)]
windows.append((start + res_no * window + (1 if res_no else 0), start + span))

res_use = []
with open(output_dir + output_base + '.txt', 'w') as f:
    # NOTE(review): these are indices into the CA list, not residue numbers -- confirm intent
    f.write('{} - {}\n'.format(test[0], test[-1]))
    for p, (min_res, max_res) in enumerate(windows):
        res_to_use = _window_contacts(min_res, max_res, residue_list, pdist,
                                      proximity, neighbor_exclusion)
        res_use.append(res_to_use)
        f.write('{:02} -- {}-{}: {}\n'.format(p + 1, min_res, max_res, res_to_use))

# Read in the MSA.
# Remember the first line: if it is a '#' header line, it is copied verbatim to the
# top of every output MSA.
with open(a3m_filename, 'r') as f:
    first = f.readline()
hash_there = first.startswith('#')  # also safe for an empty file ('' has no [0])
# NOTE(review): depending on the Biopython version, the "fasta" parser may warn about,
# or reject, text before the first '>' record -- confirm with the installed version.
records = list(SeqIO.parse(a3m_filename, "fasta"))
if not records:
    raise ValueError('no sequences found in {}'.format(a3m_filename))

# Write the MSAs: the unmodified MSA as label 00, then one alanine-scanned MSA per
# window (labels 01, 02, ...), each in `no_out` identical copies.
# Every record is wrapped at 100 characters using its OWN length. a3m rows that
# contain insertions are longer than the query, so wrapping by the query length would
# truncate them. Letter case is preserved: lowercase letters are insertion states,
# and upper-casing them would turn them into match states and shift the alignment.


def _format_record(description, seq, width=100):
    """Return one FASTA/a3m entry: the '>' header line, then `seq` wrapped at `width`."""
    chunks = [seq[i:i + width] for i in range(0, len(seq), width)] or ['']
    return '>' + description + '\n' + ''.join(chunk + '\n' for chunk in chunks)


def _alanine_scan(seq, targets):
    """Return `seq` with the residues in the query columns `targets` replaced by 'A'.

    Query columns are counted from 1 over characters that upper() leaves unchanged
    (match-state letters and '-'). Lowercase insertion letters do not advance the
    count and are never changed, and gaps ('-') stay gaps.
    """
    out = []
    column = 0
    for ch in seq:
        if ch.upper() == ch:
            column += 1
            if column in targets and ch != '-':
                ch = 'A'
        out.append(ch)
    return ''.join(out)


def _write_replicates(label, text):
    """Write `no_out` identical copies of `text` as <output_base>_<label>_<copy>.a3m."""
    for rep in range(1, no_out + 1):
        output_file = output_dir + output_base + '_{:02}_{:02}.a3m'.format(label, rep)
        # Always 'w': never append to a file left over from a previous run.
        with open(output_file, 'w') as handle:
            if hash_there:
                handle.write(first)
            handle.write(text)


# save unmodified MSA
_write_replicates(0, ''.join(_format_record(rec.description, str(rec.seq))
                             for rec in records))

# Step over the sliding windows and change the MSA. Every record, including the first
# (query) row, is mutated.
# NOTE(review): res_use holds PDB residue numbers, compared here with 1-based query
# columns. This assumes the PDB numbering starts at 1 and follows the query sequence
# -- confirm for non-ColabFold structures.
for p, window_residues in enumerate(res_use):
    targets = {int(r) for r in window_residues}
    out = ''.join(_format_record(rec.description, _alanine_scan(str(rec.seq), targets))
                  for rec in records)
    _write_replicates(p + 1, out)

#### License
This notebook and its source code are licensed under the MIT License.
***
Aug. 22, 2022 Richard Stein