In [None]:
import pandas as pd
import plotly.express as px

Results from *Michael Heinzinger, Konstantin Weissenow, Joaquin Gomez Sanchez, Adrian Henkel, Milot Mirdita, Martin Steinegger, Burkhard Rost, Bilingual language model for protein sequence and structure, NAR Genomics and Bioinformatics, Volume 6, Issue 4, December 2024, lqae150, https://doi.org/10.1093/nargab/lqae150*

**CATH**

CATH is a hierarchical classification scheme for protein domain structures, based on **C**lass, **A**rchitecture, **T**opology (or fold), and **H**omologous superfamily.

Class (C): Broad structural categorization based on secondary structure composition.
Main classes:
- Mainly alpha (α-helices dominate)
- Mainly beta (β-sheets dominate)
- Alpha-beta (mixed α/β structures)
- Few secondary structures (little α-helix or β-sheet content)

Architecture (A): Arrangements of secondary structures in 3D space, independent of connectivity.

Topology (T) / Fold: Describes the connectivity of secondary structures, defining unique folds.

Homologous Superfamily (H): Groups protein domains that share a common ancestor and have structural and functional similarities.

In [None]:
# Source: Supplements Table S1 "ProstT5 embeddings for fold classification (contrastive learning)".
# Every model maps to ten numbers: for each of C, A, T, H and the mean, the
# score followed by its "*_ci" value, in the order of the column names below.
contrastive_learning_results = {
    "ESM-1b":        [87, 4, 68, 6, 59, 7, 70, 7, 71, 6],
    "Ankh":          [87, 4, 76, 6, 64, 6, 75, 7, 75, 6],
    "ProtT5":        [89, 4, 75, 6, 64, 6, 76, 6, 76, 6],
    "ProstT5(AA)":   [89, 4, 81, 5, 70, 6, 83, 6, 81, 5],
    "ProstT5(p3Di)": [88, 4, 77, 5, 69, 6, 80, 6, 79, 5],
    "ProstT5(3Di)":  [90, 4, 78, 5, 69, 6, 74, 7, 78, 6],
    # concatenation of ProstT5(AA) and ProstT5(p3Di)
    "ProstT5(cat)":  [91, 4, 79, 5, 72, 6, 84, 6, 81, 5],
}

# "ctl_" prefix = contrastive learning; the first entry names the model column.
contrastive_learning_results_column_names = [
    "Model",
    "ctl_C", "ctl_C_ci",
    "ctl_A", "ctl_A_ci",
    "ctl_T", "ctl_T_ci",
    "ctl_H", "ctl_H_ci",
    "ctl_Mean", "ctl_Mean_ci",
]

# Build one row per model (dict keys become the index), then move the model
# names out of the index into a regular "Model" column.
contrastive_learning_results_df = (
    pd.DataFrame.from_dict(
        contrastive_learning_results,
        orient="index",
        columns=contrastive_learning_results_column_names[1:],
    )
    .rename_axis(contrastive_learning_results_column_names[0])
    .reset_index()
)
contrastive_learning_results_df

In [None]:
# Source: Table 1 "Classification of proteins into CATH hierarchy (folds)*".
# Every model maps to ten numbers: for each of C, A, T, H and the mean, the
# score followed by its "*_ci" value, in the order of the column names below.
unsupervised_results = {
    "ESM-1b":        [79, 5, 61, 6, 50, 7, 57, 8, 62, 7],
    "Ankh":          [84, 5, 69, 6, 60, 7, 67, 8, 70, 6],
    "ProtT5":        [84, 5, 67, 6, 57, 6, 64, 8, 68, 6],
    "ProstT5(AA)":   [85, 5, 74, 6, 64, 6, 69, 7, 73, 6],
    "ProstT5(p3Di)": [85, 5, 71, 6, 60, 7, 73, 7, 72, 6],
    "ProstT5(3Di)":  [90, 4, 77, 6, 65, 6, 75, 7, 77, 6],
    "ProstT5(cat)":  [88, 4, 74, 6, 65, 7, 74, 7, 75, 6],
}

# The first entry names the model column; the rest label the ten values per row.
unsupervised_results_column_names = [
    "Model",
    "C", "C_ci",
    "A", "A_ci",
    "T", "T_ci",
    "H", "H_ci",
    "Mean", "Mean_ci",
]

# Build one row per model (dict keys become the index), then move the model
# names out of the index into a regular "Model" column.
unsupervised_results_df = (
    pd.DataFrame.from_dict(
        unsupervised_results,
        orient="index",
        columns=unsupervised_results_column_names[1:],
    )
    .rename_axis(unsupervised_results_column_names[0])
    .reset_index()
)
unsupervised_results_df

In [None]:
# Combine the two result tables into one row per model.
# The join is keyed on the model name instead of gluing the frames together by
# row position, so each model's contrastive-learning ("ctl_*") columns are
# always paired with that same model's unsupervised columns, even if the two
# tables list the models in a different order. "Model" appears once in the
# result, followed by the ctl_* columns and then the unsupervised columns.
# validate="one_to_one" raises if a model name is duplicated in either table.
combined_results_df = contrastive_learning_results_df.merge(
    unsupervised_results_df,
    on="Model",
    how="inner",
    validate="one_to_one",
)
# An inner join would silently drop a model that is missing from one table;
# fail loudly instead so the plots never show an incomplete comparison.
assert (
    len(combined_results_df)
    == len(contrastive_learning_results_df)
    == len(unsupervised_results_df)
), "both result tables must contain exactly the same set of models"
combined_results_df

In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the mean score; one labelled point per model.
axis_range = [60, 85]  # same range on both axes

layout_options = dict(
    title="Protein Fold Classification Accuracy (Mean)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)

fig = px.scatter(
    combined_results_df,
    x="Mean",
    y="ctl_Mean",
    text="Model",
)
# put each model name above its marker
fig.update_traces(textposition="top center")
fig.update_layout(**layout_options)
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()



In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the mean score, with the "*_ci" columns
# drawn as error bars; one labelled point per model.
x_col, x_ci_col = "Mean", "Mean_ci"
y_col, y_ci_col = "ctl_Mean", "ctl_Mean_ci"

# Derive one shared axis range that contains every point *including* its error
# bar. A fixed range of [60, 85] truncates bars at the plot border (e.g. an
# x value of 62 with ci 7 reaches down to 55, a y value of 81 with ci 5 up to 86).
lowest_bar_end = min(
    (combined_results_df[x_col] - combined_results_df[x_ci_col]).min(),
    (combined_results_df[y_col] - combined_results_df[y_ci_col]).min(),
)
highest_bar_end = max(
    (combined_results_df[x_col] + combined_results_df[x_ci_col]).max(),
    (combined_results_df[y_col] + combined_results_df[y_ci_col]).max(),
)
padding = 5  # head-room for the bar caps and the text labels above the markers
# plain floats so the figure does not depend on numpy scalar serialisation
axis_range = [float(lowest_bar_end) - padding, float(highest_bar_end) + padding]

fig = px.scatter(
    combined_results_df,
    x=x_col,
    y=y_col,
    text="Model",
    error_x=x_ci_col,
    error_y=y_ci_col,
)
# light grey error bars keep the markers and labels readable
fig.update_traces(
    textposition="top center",
    error_x=dict(color="lightgrey"),
    error_y=dict(color="lightgrey"),
)
fig.update_layout(
    title="Protein Fold Classification Accuracy (Mean)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)
# identical range on both axes so the two approaches are visually comparable
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()



In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the CATH-A (architecture) level, with the
# "*_ci" columns drawn as error bars; one labelled point per model.
x_col, x_ci_col = "A", "A_ci"
y_col, y_ci_col = "ctl_A", "ctl_A_ci"

# Derive one shared axis range that contains every point *including* its error
# bar. A fixed range of [50, 85] truncates the upper y error bar (a ctl_A value
# of 81 with ci 5 reaches up to 86).
lowest_bar_end = min(
    (combined_results_df[x_col] - combined_results_df[x_ci_col]).min(),
    (combined_results_df[y_col] - combined_results_df[y_ci_col]).min(),
)
highest_bar_end = max(
    (combined_results_df[x_col] + combined_results_df[x_ci_col]).max(),
    (combined_results_df[y_col] + combined_results_df[y_ci_col]).max(),
)
padding = 5  # head-room for the bar caps and the text labels above the markers
# plain floats so the figure does not depend on numpy scalar serialisation
axis_range = [float(lowest_bar_end) - padding, float(highest_bar_end) + padding]

fig = px.scatter(
    combined_results_df,
    x=x_col,
    y=y_col,
    text="Model",
    error_x=x_ci_col,
    error_y=y_ci_col,
)
# light grey error bars keep the markers and labels readable
fig.update_traces(
    textposition="top center",
    error_x=dict(color="lightgrey"),
    error_y=dict(color="lightgrey"),
)
fig.update_layout(
    title="Protein Fold Classification Accuracy (CATH-A)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)
# identical range on both axes so the two approaches are visually comparable
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()



In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the CATH-C (class) level, with the "*_ci"
# columns drawn as error bars; one labelled point per model.
axis_range = [70, 100]  # same range on both axes
error_bar_style = dict(color="lightgrey")

layout_options = dict(
    title="Protein Fold Classification Accuracy (CATH-C)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)

fig = px.scatter(
    combined_results_df,
    x="C",
    y="ctl_C",
    text="Model",
    error_x="C_ci",
    error_y="ctl_C_ci",
)
# model names above the markers, error bars in a muted colour
fig.update_traces(
    textposition="top center",
    error_x=error_bar_style,
    error_y=error_bar_style,
)
fig.update_layout(**layout_options)
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()



In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the CATH-T (topology / fold) level, with the
# "*_ci" columns drawn as error bars; one labelled point per model.
axis_range = [40, 80]  # same range on both axes
error_bar_style = dict(color="lightgrey")

layout_options = dict(
    title="Protein Fold Classification Accuracy (CATH-T)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)

fig = px.scatter(
    combined_results_df,
    x="T",
    y="ctl_T",
    text="Model",
    error_x="T_ci",
    error_y="ctl_T_ci",
)
# model names above the markers, error bars in a muted colour
fig.update_traces(
    textposition="top center",
    error_x=error_bar_style,
    error_y=error_bar_style,
)
fig.update_layout(**layout_options)
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()



In [None]:
# Scatter plot comparing contrastive learning (ctl, y-axis) with the
# unsupervised results (x-axis) on the CATH-H (homologous superfamily) level,
# with the "*_ci" columns drawn as error bars; one labelled point per model.
x_col, x_ci_col = "H", "H_ci"
y_col, y_ci_col = "ctl_H", "ctl_H_ci"

# Derive one shared axis range that contains every point *including* its error
# bar. A fixed range of [50, 90] truncates the lowest x error bar (an H value
# of 57 with ci 8 reaches down to 49) and puts the highest y bar end
# (84 + 6 = 90) exactly on the plot border.
lowest_bar_end = min(
    (combined_results_df[x_col] - combined_results_df[x_ci_col]).min(),
    (combined_results_df[y_col] - combined_results_df[y_ci_col]).min(),
)
highest_bar_end = max(
    (combined_results_df[x_col] + combined_results_df[x_ci_col]).max(),
    (combined_results_df[y_col] + combined_results_df[y_ci_col]).max(),
)
padding = 5  # head-room for the bar caps and the text labels above the markers
# plain floats so the figure does not depend on numpy scalar serialisation
axis_range = [float(lowest_bar_end) - padding, float(highest_bar_end) + padding]

fig = px.scatter(
    combined_results_df,
    x=x_col,
    y=y_col,
    text="Model",
    error_x=x_ci_col,
    error_y=y_ci_col,
)
# light grey error bars keep the markers and labels readable
fig.update_traces(
    textposition="top center",
    error_x=dict(color="lightgrey"),
    error_y=dict(color="lightgrey"),
)
fig.update_layout(
    title="Protein Fold Classification Accuracy (CATH-H)",
    template="plotly_white",
    font=dict(color="black", family="Arial"),
    width=500,
    height=500,
    xaxis_title="Unsupervised Learning",
    yaxis_title="Contrastive Learning",
)
# identical range on both axes so the two approaches are visually comparable
fig.update_xaxes(range=axis_range)
fig.update_yaxes(range=axis_range)
fig.show()
