<a href="https://colab.research.google.com/github/sgbaird-mwe/mat_discover/blob/main/examples/elmd_densmap_cluster_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ElMD DensMAP Clustering Visualization
This example shows how to plot only the ElMD/DensMAP embeddings (colored by cluster label or by a property of your choice) rather than create all of the plots. It produces both an interactive Plotly figure and a "paper-ready" PNG image. A Plotly `.html` file and the `.png` image are exported automatically to the `./figures/` directory, which is created relative to your current working directory if it doesn't already exist. Note that if you are using Google Colab, the built-in storage doesn't persist, so to keep the figures you'll need to, e.g., mount Google Drive and save them there.
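One way to keep the exported figures on Colab is to mount Google Drive and copy the `./figures/` directory into it once the notebook has finished. The sketch below uses only the standard library. It assumes Drive is mounted at Colab's usual `/content/drive` location, and the destination folder `mat_discover_figures` and the helper name `copy_figures` are hypothetical choices, not part of `mat_discover`.

```python
import shutil
from pathlib import Path


def copy_figures(src="figures", dest="/content/drive/MyDrive/mat_discover_figures"):
    """Copy every exported figure (.html, .png, ...) from `src` into `dest`.

    `dest` defaults to a hypothetical folder on a mounted Google Drive.
    On Colab, mount Drive first:
        from google.colab import drive
        drive.mount("/content/drive")
    """
    src_path = Path(src)
    if not src_path.is_dir():
        raise FileNotFoundError(f"No figures directory found at {src_path.resolve()}")
    # dirs_exist_ok=True (Python 3.8+) lets repeated runs merge into an existing folder
    shutil.copytree(src_path, dest, dirs_exist_ok=True)
    # Return the copied file names so you can confirm what was saved
    return sorted(p.name for p in Path(dest).iterdir())
```

Call `copy_figures()` in a final cell, after the plotting cell has written its files.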

To clarify some of the acronyms:
- [Element Mover's Distance (ElMD)](https://github.com/lrcfmd/ElMD)
- [Density-preserving Uniform Manifold Approximation and Projection (DensMAP)](https://umap-learn.readthedocs.io/en/latest/densmap_demo.html)

[HDBSCAN* is used for clustering](https://umap-learn.readthedocs.io/en/latest/clustering.html).
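Conceptually, ElMD treats each composition as a distribution of element fractions placed along a one-dimensional chemical scale and measures the earth mover's distance between two such distributions (the `disc.predict` output further down reports fitting a `mod_petti`, i.e. modified Pettifor, kernel). With an absolute-difference ground distance on a 1D scale, the earth mover's distance equals the area between the two cumulative distributions. The numpy sketch below illustrates only that idea. The positions in `SCALE` are hypothetical placeholders (not the real modified Pettifor values), formula parsing is omitted, and `emd_1d` is an illustrative helper; use the [ElMD](https://github.com/lrcfmd/ElMD) package for real distances.

```python
import numpy as np

# Hypothetical integer positions on a 1D chemical scale.
# These are placeholders, NOT the modified Pettifor values used by ElMD.
SCALE = {"Li": 0, "Na": 1, "Cu": 2, "Ni": 3, "S": 4, "O": 5}


def emd_1d(comp_a, comp_b, scale=SCALE):
    """Earth mover's distance between two compositions on a 1D scale.

    comp_a, comp_b: dicts mapping element symbol -> amount,
    e.g. {"Li": 2, "Cu": 1, "S": 2} for Li2CuS2.
    """
    n = max(scale.values()) + 1
    hist_a = np.zeros(n)
    hist_b = np.zeros(n)
    for el, amt in comp_a.items():
        hist_a[scale[el]] += amt
    for el, amt in comp_b.items():
        hist_b[scale[el]] += amt
    # Normalize amounts to fractional compositions
    hist_a /= hist_a.sum()
    hist_b /= hist_b.sum()
    # With unit spacing and |i - j| ground distance, the 1D EMD is the
    # L1 distance between the cumulative distributions
    return float(np.abs(np.cumsum(hist_a) - np.cumsum(hist_b)).sum())


print(emd_1d({"Li": 2, "Cu": 1, "S": 2}, {"Na": 2, "Cu": 1, "S": 2}))
```

Swapping Li for its (hypothetical) neighbor Na moves 40% of the composition by one scale step, so the distance is small; the same composition at a different scale (e.g. Li4Cu2S4) gives zero because only fractions matter.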

The basic workflow is:
- Install dependencies (⚠ WARNING: you may need to restart the runtime after installation ⚠)
- Load some data
- Fit `Discover()`
- Predict on validation data
- Plot

In [1]:
%%time
!pip install -U pip
!pip install -UI mat_discover # -UI is not necessary outside of Colab

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-22.2.2-py3-none-any.whl (2.0 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.2.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mat_discover
  Downloading mat_discover-2.2.7-py2.py3-none-any.whl (38 kB)
Collecting plotly>=5.6.0
  Downloading plotly-5.10.0-py2.py3-none-any.whl (15.2 MB)
Collecting cython
  Downloading Cython-0.29.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (1.9 MB)

## Setup

### Imports

In [2]:
from os.path import join
import pandas as pd
from crabnet.data.materials_data import elasticity
from mat_discover.mat_discover_ import Discover
from mat_discover.utils.pareto import pareto_plot

In [3]:
# %% setup
# set dummy to True for a quicker run --> small dataset, MDS instead of UMAP
dummy = False
# pred_weight=0 means no target property regression
disc = Discover(dummy_run=dummy, pred_weight=0)
train_df, val_df = disc.data(elasticity, fname="train.csv", dummy=dummy)
train_df  # only formula and target columns required

| | formula | index | target | count |
|---|---|---|---|---|
| 4002 | Li2CuS2 | (11237, 11254) | 41.0 | 2 |
| 334 | As2O5 | (5126,) | 93.0 | 1 |
| 1613 | Co2B4Mo | (8018,) | 303.0 | 1 |
| 4329 | LiCaF3 | (10205,) | 85.0 | 1 |
| 792 | Be2CoPt | (7684,) | 207.0 | 1 |
| ... | ... | ... | ... | ... |
| 5734 | NaPt2 | (2642,) | 124.0 | 1 |
| 5191 | Mn2GaW | (7848,) | 220.0 | 1 |
| 5390 | MnSnPd2 | (6509,) | 133.0 | 1 |
| 860 | BeNi3 | (11839,) | 63.0 | 1 |
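Since only the `formula` and `target` columns are required, you could swap in your own data in place of the `elasticity` dataset. The pandas sketch below assumes your raw data may list the same formula more than once. It collapses repeats by averaging their targets and records how many rows were merged, loosely mirroring the `count` column in the table above. Averaging is an assumption of this sketch, not necessarily what `disc.data` does, and the helper name `prepare_training_data` and the raw values are hypothetical.

```python
import pandas as pd


def prepare_training_data(raw_df):
    """Return one row per formula with `formula`, `target`, and `count` columns.

    Duplicate formulas are averaged (an assumption made for this sketch);
    `count` records how many raw rows were merged into each formula.
    """
    missing = {"formula", "target"} - set(raw_df.columns)
    if missing:
        raise ValueError(f"missing required column(s): {sorted(missing)}")
    return raw_df.groupby("formula", as_index=False).agg(
        target=("target", "mean"),
        count=("target", "size"),
    )


# Hypothetical raw data (values invented for illustration)
raw = pd.DataFrame(
    {
        "formula": ["Li2CuS2", "Li2CuS2", "As2O5", "BeNi3"],
        "target": [40.0, 42.0, 93.0, 63.0],
    }
)
my_train_df = prepare_training_data(raw)
print(my_train_df)
```

A table like `my_train_df` (and a similar one for validation) is the kind of input you could pass to `disc.fit()` and `disc.predict()` in place of `train_df` and `val_df`.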


In [4]:
# fitting is quick since the property regression model is skipped (`pred_weight=0`)
disc.fit(train_df)

In [5]:
%%time
score = disc.predict(val_df, umap_random_state=42)

val RMSE:  0.0
Fitting mod_petti kernel matrix
Constructing distances
[fit-wasserstein]
Elapsed: 21.00964

using precomputed metric; inverse_transform will be unavailable

[fit-UMAP]
Elapsed: 87.10397

using precomputed metric; inverse_transform will be unavailable

[fit-vis-UMAP]
Elapsed: 42.08381

[DensMAP]
Elapsed: 129.18966

[HDBSCAN*]
Elapsed: 0.14847

[pdf-summation]
Elapsed: 12.271

[gridded-pdf-summation]
Elapsed: 12.42868

[train-val-pdf-summation]
Elapsed: 3.20845

[nearest-neighbor-properties]
Elapsed: 2.93471

CPU times: user 2min 47s, sys: 3.1 s, total: 2min 51s
Wall time: 2min 49s


In [6]:
# Interactive scatter plot colored by clusters
x = "DensMAP Dim. 1"
y = "DensMAP Dim. 2"
umap_df = pd.DataFrame(
    {
        x: disc.std_emb[:, 0],
        y: disc.std_emb[:, 1],
        "cluster ID": disc.labels,
        "formula": disc.all_formula,
    }
)
fig, _ = pareto_plot(
    umap_df,
    x=x,
    y=y,
    color="cluster ID",
    fpath=join(disc.figure_dir, "px-umap-cluster-scatter"),
    pareto_front=False,
    parity_type=None,
)
fig
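HDBSCAN* labels points that it does not assign to any cluster as `-1` (noise). A quick numeric companion to the scatter plot is a per-cluster count from the same DataFrame. The pandas sketch below assumes `disc.labels` follows that HDBSCAN convention; it builds a small stand-in DataFrame with hypothetical labels so it runs on its own, and in the notebook you would pass `umap_df` instead. The helper name `summarize_clusters` is illustrative.

```python
import pandas as pd


def summarize_clusters(df, label_col="cluster ID"):
    """Count points per cluster and report the fraction labeled as noise (-1)."""
    counts = df[label_col].value_counts().sort_index()
    n_noise = int(counts.get(-1, 0))
    noise_frac = n_noise / len(df)
    # Number of real clusters, excluding the noise label
    n_clusters = int((counts.index != -1).sum())
    return counts, n_clusters, noise_frac


# Stand-in for umap_df with hypothetical cluster labels
demo_df = pd.DataFrame(
    {
        "cluster ID": [0, 0, 1, 1, 1, -1, 2, -1],
        "formula": ["Li2CuS2", "As2O5", "Co2B4Mo", "LiCaF3",
                    "Be2CoPt", "NaPt2", "Mn2GaW", "BeNi3"],
    }
)
counts, n_clusters, noise_frac = summarize_clusters(demo_df)
print(counts)
print(f"{n_clusters} clusters, {noise_frac:.1%} noise")
```

A large noise fraction is a hint that many compositions sit in sparse regions of the DensMAP embedding rather than in well-defined clusters.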

## Example Output

![example-plotly-output](https://github.com/sparks-baird/mat_discover/blob/main/examples/figures/cluster-example-output.png?raw=1)

(When you run this notebook, the plot above will be interactive, and a Plotly `.html` file and a [matplotlibified](https://github.com/sparks-baird/mat_discover/blob/4abeea75664d291275900016e8a15d2cacd63838/mat_discover/utils/plotting.py#L250) "paper-ready" `.png` image will also be exported to `./figures/`.)