# Disclaimer
This material was prepared as an account of work sponsored by an agency of the United States Government.  Neither the United States Government nor the United States Department of Energy, nor Battelle, nor any of their employees, nor any jurisdiction or organization that has cooperated in the development of these materials, makes any warranty, express or implied, or assumes any legal liability or responsibility for the accuracy, completeness, or usefulness of any information, apparatus, product, software, or process disclosed, or represents that its use would not infringe privately owned rights. Reference herein to any specific commercial product, process, or service by trade name, trademark, manufacturer, or otherwise does not necessarily constitute or imply its endorsement, recommendation, or favoring by the United States Government or any agency thereof, or Battelle Memorial Institute. The views and opinions of authors expressed herein do not necessarily state or reflect those of the United States Government or any agency thereof.

PACIFIC NORTHWEST NATIONAL LABORATORY operated by BATTELLE for the UNITED STATES DEPARTMENT OF ENERGY under Contract DE-AC05-76RL01830.

In [None]:
from pathlib import Path

import datasets
import matplotlib.pyplot as plt
import numpy as np
import umap

from nukelm.analyze.BERTopic import BERTopic
from nukelm.analyze.umap_comparisons import PLOT_KWARGS, UMAP_KWARGS, plot_points


# Assumes the notebook is launched from a directory one level below the project root.
PROJECT_DIR = Path.cwd().parent
output_dir = PROJECT_DIR / "data" / "08_reporting" / "bertopic"
# parents=True: on a fresh checkout ``data/08_reporting`` may not exist yet; without it
# mkdir raises FileNotFoundError instead of creating the missing intermediate folders.
output_dir.mkdir(parents=True, exist_ok=True)

# Column of the model-output datasets holding the document embeddings that are fed to
# UMAP and BERTopic below (presumably the [CLS] vector -- confirm against the producer).
AGG_METHOD = "CLS"

In [None]:
# Load the saved model outputs (text plus embeddings) for the continued-pre-training
# model ("trained") and the off-the-shelf model ("ots").
model_output_dir = PROJECT_DIR / "data" / "07_model_output"
dataset_trained = datasets.load_from_disk(str(model_output_dir / "roberta-large-trained-1"))
dataset_ots = datasets.load_from_disk(str(model_output_dir / "roberta-large-ots-1"))

In [None]:
# Fit a UMAP projection on the trained-model embeddings (shared UMAP_KWARGS settings).
mapper_trained = umap.UMAP(**UMAP_KWARGS)
mapper_trained.fit(dataset_trained[AGG_METHOD])

In [None]:
# Fit a separate UMAP projection on the off-the-shelf-model embeddings.
mapper_ots = umap.UMAP(**UMAP_KWARGS)
mapper_ots.fit(dataset_ots[AGG_METHOD])

In [None]:
# Project each model's embeddings with its own fitted mapper; the plots below use
# columns 0 and 1 of the result as x/y coordinates.
points_trained, points_ots = (
    mapper_trained.transform(dataset_trained[AGG_METHOD]),
    mapper_ots.transform(dataset_ots[AGG_METHOD]),
)

# BERTopic applied to the model with continued pre-training (NukeLM)


In [None]:
# Keyword arguments for the project's BERTopic wrapper. Judging by the names, the first
# group presumably configures the UMAP reduction and the second the HDBSCAN clustering
# -- confirm against nukelm.analyze.BERTopic.
BERTOPIC_KWARGS = dict(
    # presumably UMAP settings
    n_neighbors=15,
    n_components=100,
    min_dist=0.1,
    umap_metric="euclidean",
    random_state=42,  # fixed seed; assumed to make the reduction reproducible
    # presumably HDBSCAN settings
    min_cluster_size=25,
    min_samples=None,
    cluster_selection_epsilon=0.0,
    hdbscan_metric="euclidean",
    alpha=1.0,
    cluster_selection_method="eom",
    verbose=True,
)

In [None]:
# Cluster the trained-model documents with BERTopic, reusing the precomputed embeddings.
model_trained = BERTopic(**BERTOPIC_KWARGS)
embeddings_trained = np.array(dataset_trained[AGG_METHOD])
labels_trained, _ = model_trained.fit_transform(dataset_trained["text"], embeddings_trained)
# Color both embedding panels by the same (trained-model) cluster assignment.
labels_ots = labels_trained

In [None]:
# Cluster ids present in either labeling. A set union is used instead of
# ``labels_trained + labels_ots``: that concatenates lists, but would add element-wise
# (corrupting the ids) if the labels ever came back as numpy arrays.
labels_set_trained = set(labels_trained) | set(labels_ots)
# Legend names for clusters 0..max; ``default=-1`` avoids a ValueError on empty labels.
n_clusters_trained = max(labels_set_trained, default=-1) + 1
label_map_trained = {i: f"Cluster {i + 1: 2d}" for i in range(n_clusters_trained)}
# -1 is the outlier label; show it as "None" in the legend.
label_map_trained[-1] = "None"

In [None]:
# Display the topic returned by get_topic for each cluster, keyed by its legend name (outliers -> "None").
{("None" if i < 0 else f"Cluster {i + 1: 2d}"): model_trained.get_topic(i) for i in labels_set_trained}

In [None]:
# Side-by-side embeddings of both models, colored by the trained-model clusters.
panel_points = (points_trained, points_ots)
panel_labels = (labels_trained, labels_ots)
panel_titles = (r"\textsc{NukeLM}", r"\textsc{RoBERTa} Large")
fig_trained = plot_points(panel_points, panel_labels, panel_titles, label_map_trained, True, **PLOT_KWARGS)
fig_trained.savefig(output_dir / "trained-clusters.png", dpi=300)

# BERTopic applied to the off-the-shelf model (RoBERTa Large, without continued pre-training)

In [None]:
# Same BERTopic settings as the trained-model run, so the two clusterings differ only
# in the embeddings they are given.
BERTOPIC_KWARGS = dict(
    n_neighbors=15,
    n_components=100,
    min_dist=0.1,
    umap_metric="euclidean",
    random_state=42,
    min_cluster_size=25,
    min_samples=None,
    cluster_selection_epsilon=0.0,
    hdbscan_metric="euclidean",
    alpha=1.0,
    cluster_selection_method="eom",
    verbose=True,
)

In [None]:
# Cluster the off-the-shelf-model documents with BERTopic, reusing its embeddings.
model_ots = BERTopic(**BERTOPIC_KWARGS)
embeddings_ots = np.array(dataset_ots[AGG_METHOD])
labels_ots, _ = model_ots.fit_transform(dataset_ots["text"], embeddings_ots)
# NOTE: this rebinds labels_trained, discarding the trained-model clusters computed
# earlier. Every later cell, including the final publication figure, therefore colors
# both models' points by the off-the-shelf clustering.
labels_trained = labels_ots

In [None]:
# Cluster ids present in either labeling. A set union is used instead of
# ``labels_trained + labels_ots``: that concatenates lists, but would add element-wise
# (corrupting the ids) if the labels ever came back as numpy arrays.
labels_set_ots = set(labels_trained) | set(labels_ots)
# Legend names for clusters 0..max; ``default=-1`` avoids a ValueError on empty labels.
n_clusters_ots = max(labels_set_ots, default=-1) + 1
label_map_ots = {i: f"Cluster {i + 1: 2d}" for i in range(n_clusters_ots)}
# -1 is the outlier label; show it as "None" in the legend.
label_map_ots[-1] = "None"

In [None]:
# Display the topic returned by get_topic for each off-the-shelf cluster, keyed by its legend name.
{("None" if i < 0 else f"Cluster {i + 1: 2d}"): model_ots.get_topic(i) for i in labels_set_ots}

In [None]:
# Side-by-side embeddings of both models, colored by the off-the-shelf clusters.
panel_points = (points_trained, points_ots)
panel_labels = (labels_trained, labels_ots)
panel_titles = (r"\textsc{NukeLM}", r"\textsc{RoBERTa} Large")
fig_ots = plot_points(panel_points, panel_labels, panel_titles, label_map_ots, True, **PLOT_KWARGS)
fig_ots.savefig(output_dir / "ots-clusters.png", dpi=300)

In [None]:
# fig = plot_points(
#     (points_trained, points_ots),
#     (_labels_trained, _labels_ots),
#     None,
# #     (r"\textsc{NukeLM}", r"\textsc{RoBERTa} Large"),
#     label_map,
#     False,
#     **PLOT_KWARGS
# )

In [None]:
# fig.axes[0].set_xlim(-2.5, 17.5)
# fig.axes[1].set_xlim(-2.5, 17.5)
# fig.axes[0].set_ylim(-2.5, 17.5)
# fig.axes[1].set_ylim(-2.5, 17.5)
# fig.axes[0].set_xticks([0, 5, 10, 15])
# fig.axes[0].set_yticks([0, 5, 10, 15])
# fig.axes[1].set_xticks([0, 5, 10, 15])
# fig.axes[1].set_yticks([0, 5, 10, 15])
# fig

In [None]:
# fig.axes[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))

# fig

In [None]:
# fig.savefig("clusters.png", dpi=300, bbox_inches="tight")

In [None]:
# fig = plot_points(
#     (points_trained, points_ots),
#     (_labels_trained, _labels_ots),
#     None,
# #     (r"\textsc{NukeLM}", r"\textsc{RoBERTa} Large"),
#     label_map,
#     False,
#     **PLOT_KWARGS
# )

In [None]:
# fig = plot_points(
#     (points_trained, points_ots),
#     (dataset_trained["label"], dataset_ots["label"]),
#     (r"\textsc{NukeLM}", r"\textsc{RoBERTa} Large"),
#     None,
#     False,
#     **PLOT_KWARGS
# )

In [None]:
# fig.axes[0].set_xlim(-2.5, 17.5)
# fig.axes[1].set_xlim(-2.5, 17.5)
# fig.axes[0].set_ylim(-2.5, 17.5)
# fig.axes[1].set_ylim(-2.5, 17.5)
# fig.axes[0].set_xticks([0, 5, 10, 15])
# fig.axes[0].set_yticks([0, 5, 10, 15])
# fig.axes[1].set_xticks([0, 5, 10, 15])
# fig.axes[1].set_yticks([0, 5, 10, 15])

# fig.axes[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))

# fig

In [None]:
# fig.savefig("umap.png", dpi=300, bbox_inches="tight")

# Final plot for publication

In [None]:
# Legend names for the dataset's "label" column (plotted as "NFC Labels" below).
LABEL_MAP = dict([("nuke", "NFC-Related"), ("not-nuke", "Other")])
# Marker-only scatter style for the final figure. NOTE: this rebinds the PLOT_KWARGS
# imported from nukelm.analyze.umap_comparisons; re-running earlier plotting cells
# after this one will use this style instead.
PLOT_KWARGS = dict(linestyle="None", marker=".", alpha=0.5)

In [None]:
# Remap the negative outlier label to a large sentinel (100000) so that, once the labels
# are sorted for the legend, outliers come after every real cluster.
_labels_trained = [int(1e5) if label < 0 else label for label in labels_trained]
_labels_ots = [int(1e5) if label < 0 else label for label in labels_ots]

In [None]:
# Sentinel the previous cell assigned to outliers.
_OUTLIER_LABEL = int(1e5)
labels_set = set(_labels_trained) | set(_labels_ots)
# Name only the clusters that actually occur. The former ``range(max(labels_set) + 1)``
# ran up to the outlier sentinel whenever outliers existed, building ~100k unused entries.
label_map = {i: f"Cluster {i + 1: 2d}" for i in sorted(labels_set) if i != _OUTLIER_LABEL}

label_map[_OUTLIER_LABEL] = "Outlier"

In [None]:
# Column order for the final figure: off-the-shelf model first, then NukeLM.
points = (points_ots, points_trained)
# Row 0: dataset NFC labels; row 1: BERTopic cluster labels (same column order).
labels = (
    (dataset_ots["label"], dataset_trained["label"]),
    (_labels_ots, _labels_trained),
)
label_maps = (LABEL_MAP, label_map)
titles = (r"\textsc{RoBERTa} Large", r"\textsc{NukeLM}")

In [None]:
from matplotlib import rc


# Render all figure text with LaTeX (the titles use \textsc); needs a local TeX install.
rc("text", usetex=True)

# 2x2 grid: rows are labelings (NFC labels, BERTopic clusters), columns are models.
fig, axes = plt.subplots(2, 2, figsize=(7, 7))
row_names = ("NFC Labels", "BERTopic Cluster Labels")

for row in range(2):
    for col in range(2):
        _points = points[col]
        ax = axes[row, col]
        raw_labels = labels[row][col]
        # One array conversion lets each class be selected with a vectorized boolean
        # mask, instead of rescanning the whole label list in Python for every class.
        label_array = np.asarray(raw_labels)
        # Sorting fixes the legend order (outliers were remapped to sort last).
        unique_labels = sorted(set(raw_labels))
        for class_name in unique_labels:
            mask = label_array == class_name
            ax.plot(
                _points[mask, 0],
                _points[mask, 1],
                label=label_maps[row][class_name],
                **PLOT_KWARGS,
            )
        if row == 0:
            ax.set_title(titles[col])
        # Alternative placement outside the axes: loc="center left", bbox_to_anchor=(1, 0.5)
        ax.legend(loc="upper right")
        if col == 0:
            ax.set_ylabel(row_names[row])
        # Identical limits/ticks on every panel so the four plots are directly comparable.
        ax.set_xlim(-2.5, 25)
        ax.set_ylim(-2.5, 18)
        ax.set_xticks([0, 5, 10, 15, 20, 25])
        ax.set_yticks([0, 5, 10, 15])

In [None]:
# Save the 2x2 comparison figure at 300 dpi, trimming surrounding whitespace.
combined_path = output_dir / "combined-plots.png"
fig.savefig(combined_path, dpi=300, bbox_inches="tight")