In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from tklearn.metrics import Accuracy, ArrayAccumulator
from tklearn.nn import Trainer, Evaluator
from tklearn.nn.callbacks import ProgbarLogger, EarlyStopping
from tklearn.nn.transformers import TransformerForSequenceClassification

In [None]:
# Checkpoint name handed to the tokenizer/model `from_pretrained` calls below.
MODEL_NAME_OR_PATH = "google-bert/bert-base-uncased"
# Dataset name handed to `load_dataset` below.
DATASET = "yelp_review_full"

In [None]:
# Load the dataset named by DATASET; it is indexed by split name below.
dataset = load_dataset(DATASET)

# Display a single training record as a quick look at the raw fields.
dataset["train"][100]

In [None]:
# Tokenizer matching the checkpoint that is fine-tuned later in the notebook.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)


def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples``.

    The text is padded to the tokenizer's maximum length and truncated when
    longer, so every encoded example has the same length.
    """
    texts = examples["text"]
    return tokenizer(texts, padding="max_length", truncation=True)

In [None]:
# Tokenize in batches, then rename the target column from "label" to "labels".
tokenized_datasets = (
    dataset
    .map(tokenize_function, batched=True)
    .rename_column("label", "labels")
)
# Switch the output format to torch (done in place on `tokenized_datasets`).
tokenized_datasets.set_format("torch")

In [None]:
# Keep the demo small: shuffle each split with a fixed seed (42) and keep the
# first 100 rows, giving 100 training and 100 evaluation examples.
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

In [None]:
# NOTE(review): num_labels=5 presumably matches the label set of DATASET — confirm.
model = TransformerForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH, num_labels=5)

# Prefer Apple's MPS backend (the original hard-coded target), but fall back to
# CUDA or CPU so the notebook also runs on machines without MPS support.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-5)

# Only the training loader is shuffled; evaluation order stays fixed.
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=16)
valid_dataloader = DataLoader(small_eval_dataset, batch_size=32)

# Validation evaluator; the metric key was previously misspelled "acuracy".
evaluator = Evaluator(
    model,
    valid_dataloader,
    callbacks=[ProgbarLogger()],
    metrics={"accuracy": Accuracy()},
    prefix="valid_",
)

# Train for at most 10 epochs, with the evaluator and early stopping attached.
trainer = Trainer(
    model,
    train_dataloader,
    optimizer=optimizer,
    callbacks=[ProgbarLogger(), EarlyStopping(patience=0)],
    evaluator=evaluator,
    epochs=10,
)

In [None]:
# Run the training loop configured on `trainer`.
trainer.train()

In [None]:
# BREAK

In [None]:
# Second evaluator over the validation loader whose only "metric" is an
# ArrayAccumulator registered under the key "embedding".
# NOTE(review): presumably this gathers the model's "embedding" output across
# batches — confirm against tklearn's ArrayAccumulator.
embedding_accum = Evaluator(model, valid_dataloader, callbacks=[ProgbarLogger()], metrics={
    "embedding": ArrayAccumulator("embedding")
})

In [None]:
# Display the evaluator object.
embedding_accum

In [None]:
# Run the evaluation pass; later cells read `res["embedding"]`.
res = embedding_accum.evaluate()

In [None]:
from __future__ import annotations

import warnings
from typing import Tuple

import matplotlib as mpl
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.backend_bases import RendererBase
from sklearn.manifold import TSNE

try:
    from umap import UMAP
except ImportError:
    UMAP = None

# Apply seaborn's "paper" context with the whitegrid style globally.
sns.set_theme(context="paper", style="whitegrid")

# Identity mapping over every style sheet matplotlib reports as available.
available_styles = {name: name for name in plt.style.available}

# Alias "seaborn" to the first available sheet whose name starts with "seaborn"
# and ends with "whitegrid" (matched by prefix/suffix rather than exact name).
# Fall back to matplotlib's "default" style instead of raising StopIteration
# when no such sheet is installed.
available_styles["seaborn"] = next(
    (
        name
        for name in plt.style.available
        if name.startswith("seaborn") and name.endswith("whitegrid")
    ),
    "default",
)


def get_style(style: str) -> str | None:
    """Resolve ``style`` through ``available_styles``.

    Returns the registered style name, or ``None`` when ``style`` is not a key
    of ``available_styles``.
    """
    try:
        return available_styles[style]
    except KeyError:
        return None


def embed2d(X: np.ndarray, embedder="umap", random_state=None) -> np.ndarray:  # noqa
    """Project ``X`` to two components with UMAP or t-SNE.

    Args:
        X: Array handed to the reducer's ``fit_transform``.
        embedder: ``"umap"``, ``"tsne"`` or ``"t-sne"`` (case-insensitive).
            ``"umap"`` falls back to t-SNE, with a warning, when the optional
            ``umap`` package could not be imported.
        random_state: Optional seed forwarded to the reducer so layouts can be
            reproduced. ``None`` (the default) leaves the reducer unseeded.

    Returns:
        The output of the reducer's ``fit_transform`` with ``n_components=2``.

    Raises:
        ValueError: If ``embedder`` is not one of the supported names.
    """
    name = embedder.lower()
    # Validate first so an unsupported name never touches the reducers.
    if name not in {"umap", "tsne", "t-sne"}:
        msg = f"embedder {name} not supported"
        raise ValueError(msg)
    if name == "umap" and UMAP is not None:
        reducer = UMAP(n_components=2, random_state=random_state)
        return reducer.fit_transform(X)
    if name == "umap":
        msg = "umap is not installed, falling back to t-SNE"
        # stacklevel=2 attributes the warning to the caller of embed2d rather
        # than to this helper's own source line.
        warnings.warn(msg, stacklevel=2)
    reducer = TSNE(n_components=2, random_state=random_state)
    return reducer.fit_transform(X)


def get_renderer(fig: plt.Figure) -> RendererBase:
    """Return a renderer for ``fig``.

    Tries ``fig.canvas.get_renderer()`` first, then ``fig._get_renderer()``.

    Raises:
        AttributeError: If neither accessor exists; the message names the
            active matplotlib backend.
    """
    candidates = ((fig.canvas, "get_renderer"), (fig, "_get_renderer"))
    for owner, accessor in candidates:
        if hasattr(owner, accessor):
            return getattr(owner, accessor)()
    backend = mpl.get_backend()
    msg = f"could not find a renderer for the '{backend}' backend."
    raise AttributeError(msg)


def set_legend(
    handles,
    fig: plt.Figure,
    ax: plt.Axes,
    title: str,
    frameon: bool = True,
    fancybox: bool = True,
    loc: str = "upper left",
    bbox_to_anchor: Tuple[float, float] | None = None,
    ncols: int = 1,
    max_ncols: int | None = None,
):
    """Attach a legend to ``ax`` and widen ``fig`` to make room for it.

    The axes' tight bounding box is measured once up front. After each legend
    is drawn, the figure is resized to ``legend width + plot width`` by the
    original plot height (sizes converted to inches via ``fig.dpi``).

    Args:
        handles: Artists to list; their ``get_label()`` values are the entries.
        fig: Figure to resize.
        ax: Axes that owns the legend.
        title: Legend title.
        frameon, fancybox, loc, bbox_to_anchor: Forwarded to ``ax.legend``.
        ncols: Number of legend columns for the first attempt.
        max_ncols: When given, the legend is rebuilt with one more column at a
            time (up to this limit) for as long as it stays within one inch of
            the plot height. ``None`` disables this and keeps ``ncols``.

    Returns:
        The legend that ends up attached to ``ax``.
    """
    renderer = get_renderer(fig)
    plot_extent = ax.get_tightbbox(renderer)
    plot_height = plot_extent.height / fig.dpi
    plot_width = plot_extent.width / fig.dpi
    labels = [handle.get_label() for handle in handles]

    def _draw(n_columns):
        # Single place that builds the legend, so every attempt uses the same
        # keyword (`ncols`) and the same figure-resizing rule.
        drawn = ax.legend(
            handles,
            labels,
            title=title,
            frameon=frameon,
            fancybox=fancybox,
            loc=loc,
            bbox_to_anchor=bbox_to_anchor,
            ncols=n_columns,
        )
        extent = drawn.get_tightbbox(renderer)
        drawn_height = extent.height / fig.dpi
        drawn_width = extent.width / fig.dpi
        fig.set_size_inches(drawn_width + plot_width, plot_height, forward=True)
        return drawn, drawn_height

    legend, legend_height = _draw(ncols)
    if max_ncols is None:
        return legend
    # Continue from the caller's column count; with the default ncols=1 the
    # first retry still uses 2 columns, as before.
    candidate = ncols + 1
    while (candidate <= max_ncols) and (legend_height + 1 >= plot_height):
        # Restore the plot-only size before measuring the next attempt.
        fig.set_size_inches(plot_width, plot_height, forward=True)
        legend.remove()
        legend, legend_height = _draw(candidate)
        candidate += 1
    return legend


In [None]:
from typing import Any, Tuple

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# from fluxai.plotting.utils import embed2d, get_style, set_legend

# NOTE(review): presumably carried over from the module this cell was adapted
# from (see the commented-out `fluxai.plotting.utils` import above) — confirm
# whether it is still needed in the notebook.
__all__ = [
    "plot_embedding",
]



def plot_embedding(
    data: pd.DataFrame,
    x: str = "embedding",
    y: str = "label",
    style: Any = "seaborn",
    cmap: Any = "rainbow",
    alpha: float = 0.5,
    figsize: Tuple[int, int] = (10, 10),
    embedder: str = "umap",
    dpi: float = 100,
    legend_loc: str = "upper left",
    legend_max_ncols: int = 5,
):
    """Scatter-plot a 2-D projection of embeddings, one colour per label.

    Args:
        data: Frame holding the embedding and label columns.
        x: Column whose entries are stacked into the array passed to
            ``embed2d``.
        y: Column of labels; converted to a categorical and used for colouring
            and as the legend title.
        style: Key looked up with ``get_style`` and applied as a matplotlib
            style context.
        cmap: Colormap name resolved through ``plt.colormaps``.
        alpha: Marker opacity.
        figsize: Initial figure size; ``set_legend`` resizes the figure later.
        embedder: Reducer name forwarded to ``embed2d``.
        dpi: Figure resolution.
        legend_loc: Legend location forwarded to ``set_legend``.
        legend_max_ncols: Column limit forwarded to ``set_legend``.

    Returns:
        The matplotlib figure.
    """
    vectors = np.array(data[x].tolist())
    points = embed2d(vectors, embedder=embedder)
    labels = pd.Series(data[y]).astype("category")
    categories = labels.cat.categories
    style_name = get_style(style)
    colormap = plt.colormaps.get_cmap(cmap)
    num_classes = len(categories)
    with plt.style.context(style=style_name):
        fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
        fig.set_dpi(dpi)
        handles = []
        # Categories are unique, so the enumerate position is the category id.
        for label_id, label in enumerate(categories):
            member_rows = np.where(labels == label)
            xs, ys = points[member_rows].T
            colour = colormap(label_id / num_classes)
            handles.append(
                ax.scatter(
                    x=xs,
                    y=ys,
                    color=colour,
                    label=label,
                    alpha=alpha,
                    edgecolors="none",
                )
            )
        ax.legend().remove()
        set_legend(
            handles,
            fig,
            ax,
            title=y,
            loc=legend_loc,
            max_ncols=legend_max_ncols,
        )
        ax.grid(True)
    return fig


In [None]:
# Number of accumulated embeddings.
# NOTE(review): presumably one per evaluation example — confirm.
len(res["embedding"])

In [None]:
# Stress-test frame for the legend layout: the embeddings are repeated twice
# and every row gets its own label. The label count is derived from the data
# rather than hard-coded, so the cell keeps working if the evaluation subset
# size changes.
embeddings = res["embedding"].tolist()
doubled_embeddings = embeddings + embeddings
df = pd.DataFrame(
    {
        "embedding": doubled_embeddings,
        "label": [f"Label {i}" for i in range(len(doubled_embeddings))],
    }
)

In [None]:
# Render with the legend limited to a single column.
fig = plot_embedding(df, legend_max_ncols=1)

plt.show()

In [None]:
# Same plot, but the legend may grow to as many as 10 columns.
fig = plot_embedding(df, legend_max_ncols=10)

plt.show()

In [None]:
# BREAK

In [None]:
from typing import Literal, get_origin, get_args

In [None]:
# Scratch check: does `get_origin` report `Literal` for a Literal type?
get_origin(Literal["A", "B"]) is Literal

In [None]:
# Scratch check: inspect the arguments carried by a Literal type.
get_args(Literal["A", "B"])

In [None]:
# Scratch check: identity comparison of `Any` with itself.
Any is Any

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Demo data: three curves sampled on a shared grid.
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.tan(x)

# Main figure and axes.
fig, ax = plt.subplots(figsize=(5, 5))

# Draw the three labelled curves.
for series, series_label in ((y1, 'sin(x)'), (y2, 'cos(x)'), (y3, 'tan(x)')):
    ax.plot(x, series, label=series_label)

# Draw a throwaway legend outside the axes just to measure its width, then
# remove it again.
temp_legend = ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
legend_width = temp_legend.get_window_extent(fig.canvas.get_renderer()).width
legend_width_inches = legend_width / fig.dpi
temp_legend.remove()

# Widen the figure by the legend's width while keeping its height.
fig_width, fig_height = fig.get_size_inches()
new_fig_width = fig_width + legend_width_inches
fig.set_size_inches(new_fig_width, fig_height)

# Shrink the axes so its right edge stops where the legend's share of the
# widened figure begins.
main_axes_right = 1 - (legend_width_inches / new_fig_width)
axes_box = ax.get_position()
ax.set_position(
    [axes_box.x0, axes_box.y0, main_axes_right - axes_box.x0, axes_box.height]
)

# Final legend, anchored just outside the right edge of the axes.
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Sample data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.tan(x)

# Create figure and axes for the main plot
fig, ax = plt.subplots(figsize=(5, 5))

# Plot the data
ax.plot(x, y1, label='sin(x)')
ax.plot(x, y2, label='cos(x)')
ax.plot(x, y3, label='tan(x)')

# Baseline for comparison: the measure-legend / widen-figure / shrink-axes
# steps used in the preceding cell are disabled here, so the legend is
# anchored outside the axes without any room being reserved for it.

# Add the legend
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

plt.show()

In [None]:
# model

In [None]:
# Print every parameter name outside the classifier layer; that layer is
# skipped because it dynamically changes.
for name, param in model.named_parameters():
    if not name.startswith("classifier."):
        print(name)