In [None]:
!pip install transformers
!pip install rdkit-pypi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pandas as pd

def read_smiles_csv(path):
    return pd.read_csv(path,
                       usecols=['SMILES'],
                       squeeze=True).astype(str).tolist()


In [None]:
from rdkit import Chem
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit import RDLogger

def compare_decoded_to_original_smiles(orig_smiles, decoded_smiles, output_file=None):
    """
    Compare decoded to original SMILES strings and output a table of Tanimoto distances, along with
    binary flags for whether the strings are the same and whether the decoded string is valid SMILES.
    orig_smiles and decoded_smiles are lists or arrays of strings.
    If an output file name is provided, the table will be written to it as a CSV file.
    Returns the table as a DataFrame.

    """
    res_df = pd.DataFrame(dict(original=orig_smiles, decoded=decoded_smiles))
    is_valid = []
    is_same = []
    tani_dist = []
    accuracy = []
    count = 0
    data_size = len(orig_smiles)
    for row in res_df.itertuples():
        count = count + 1
        #compute char by char accuracy
        hit = 0
        for x, y in zip(row.original, row.decoded):
            if x == y:
                hit = hit+1
        accuracy.append((hit/len(row.original))*100)

        is_same.append(int(row.decoded == row.original))
        orig_mol = Chem.MolFromSmiles(row.original)
        if orig_mol is None:
          #Note, input may be invalid, if original SMILE string is truncated 
          is_valid.append('x')
          tani_dist.append(-1)
          continue
        dec_mol = Chem.MolFromSmiles(row.decoded)
        RDLogger.DisableLog('rdApp.*')
        if dec_mol is None:
            is_valid.append(0)
            tani_dist.append(1)
        else:
            is_valid.append(1)
            orig_fp = AllChem.GetMorganFingerprintAsBitVect(orig_mol, 2, 1024)
            dec_fp = AllChem.GetMorganFingerprintAsBitVect(dec_mol, 2, 1024)
            tani_sim = DataStructs.FingerprintSimilarity(orig_fp, dec_fp, metric=DataStructs.TanimotoSimilarity)
            tani_dist.append(1.0 - tani_sim)
    res_df['is_valid'] = is_valid
    res_df['is_same'] = is_same
    res_df['smile_accuracy'] = accuracy
    res_df['tanimoto_distance'] = tani_dist
    global_acc  = np.mean(np.array(accuracy))
    res_df['total_avg_accuracy'] = [global_acc]*len(accuracy)
    
    print("Mean global accuracy % ", global_acc)
    print("Validity % ", (is_valid.count(1)/data_size)*100)
    print("Same % ", (is_same.count(1)/data_size)*100)
    valid_tani_dist = [ t for t in tani_dist if t >= 0 ] 
    print("Average tanimoto ", np.mean(np.array(valid_tani_dist)))
    

    if output_file is not None:
        output_columns = ['original', 'decoded', 'is_valid', 'is_same', 'smile_accuracy','tanimoto_distance','total_avg_accuracy']
        res_df.to_csv(output_file, index=False, columns=output_columns)
    return(res_df)

In [None]:
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

model = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

Downloading:   0%|          | 0.00/631 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/6.80k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420 [00:00<?, ?B/s]

In [None]:
#Upload input files
from google.colab import files
uploaded = files.upload()

Saving input.txt to input.txt


In [None]:
out_sm = []
with open("input.txt", 'r') as f:
  f.readline()
  for smiles in f:
    #print("-------")
    
    #print(smiles)
    smiles_tokens = tokenizer.encode(smiles)
    #print(smiles_tokens)
    mask_idx = np.random.randint(0,len(smiles_tokens))
    #print(mask_idx)
    smiles_tokens[mask_idx] = tokenizer.convert_tokens_to_ids("[MASK]")
    #print(smiles_tokens)
    smiles_masked = tokenizer.decode(smiles_tokens)
    masked_smi = fill_mask(smiles_masked)

    out_sm.append(masked_smi[0].get('sequence'))
    #print(masked_smi[0].get('sequence'))
    
with open('output.txt', 'w') as fp:
    fp.write("SMILES\n")
    for sm in out_sm:
        # write each item on a new line
        fp.write("%s\n" % sm)
        #fp.write(sm)
    print('Done')

Done


In [140]:
orig_file = read_smiles_csv("input.txt")
pred_file = read_smiles_csv("output.txt")
diff_file = "smiles_metrics.csv"

print("Input/pred SMILES file sizes ", len(orig_file), " ", len(pred_file))

compare_decoded_to_original_smiles(orig_file, pred_file, diff_file)
print("Input/pred SMILES diff file saved to", diff_file)

Input/pred SMILES file sizes  256   256
Mean global accuracy %  92.25260416666666
Validity %  95.703125
Same %  73.828125
Average tanimoto  0.10577719199025387
Input/pred SMILES diff file saved to smiles_metrics.csv
