In [1]:
!pip install umap-learn plotly

In [2]:
import numpy as np

# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary Python objects,
# which can execute code -- only load this file from a trusted source. It is
# presumably needed because xc_list holds Python objects (batches of (id, vector)
# entries, per the flattening loop below) rather than a plain numeric array.
xc_list = np.load("xc_list_torch.npy", allow_pickle=True)

# The query vectors load without pickling, i.e. they are a plain numeric array.
xq_list = np.load("xq_list_torch.npy")

In [3]:
# One row per query vector (15 x 768 in the run shown).
np.shape(xq_list)

(15, 768)

In [4]:
# Flatten the batched contexts into three parallel lists, one entry per context:
#   xc_values -- the context vector      (element 1 of each entry)
#   xc_idx    -- the context identifier  (element 0 of each entry)
#   batch     -- position of the batch the context came from in xc_list
# NOTE(review): assumes each entry is an (id, vector) pair -- confirm against the producer.
xc_values = [entry[1] for group in xc_list for entry in group]
xc_idx = [entry[0] for group in xc_list for entry in group]
batch = [group_no for group_no, group in enumerate(xc_list) for _ in group]

In [5]:
# Turn the list of per-context vectors into one 2-D matrix (160 x 768 in the run shown).
xc_values = np.asarray(xc_values)
np.shape(xc_values)

(160, 768)

In [6]:
# Contexts first, then queries: the plotting cells split u_xc by this row order.
x_all = np.concatenate((xc_values, xq_list), axis=0)
np.shape(x_all)

(175, 768)

In [7]:
# Number of query vectors; they occupy the last `xq_len` rows of `x_all`.
xq_len = xq_list.shape[0]

# Extend `batch` with one label per query: query i gets label i.
# NOTE(review): assumes query i was issued against context batch i -- confirm.
# Rebuild from the context labels (one per entry in xc_idx) instead of appending
# in place, so re-running this cell cannot keep growing `batch` and misalign the
# `batch[...]` slices used by the plotting cells.
batch = batch[:len(xc_idx)] + list(range(xq_len))
xq_len

15

---

### Alternative Dim Reduction Techniques

#### PCA

```python
from sklearn.decomposition import PCA

# Linear baseline with 3 components, matching the 3 columns the 3-D plot indexes.
# For data this shape the default svd_solver='auto' picks randomized SVD,
# so fix random_state for reproducible output.
pca = PCA(n_components=3, random_state=42)
u_xc = pca.fit_transform(x_all)
```

#### tSNE

```python
from sklearn.manifold import TSNE

# n_components must be 3: the plotting cells index u_xc[:, 2].
# `max_iter` replaced `n_iter` in scikit-learn 1.5 (`n_iter` was later removed);
# use n_iter=4000 on older versions. t-SNE is stochastic, so seed it.
tsne = TSNE(n_components=3, perplexity=10, max_iter=4000, random_state=42)
u_xc = tsne.fit_transform(x_all)
```

Neither of these seems to perform as well as UMAP on this dataset.

---

In [8]:
import umap

# 3-D UMAP projection of all contexts + queries.
# - n_neighbors=90: roughly half of the 175 points, so the layout favours
#   global structure over local neighbourhoods.
# - min_dist=0.99: spreads points out rather than packing them into tight clumps.
# - metric="cosine": compare embeddings by angle, ignoring vector magnitude.
# - random_state: UMAP is stochastic; a fixed seed keeps the layout identical
#   across Restart & Run All (umap-learn then runs single-threaded, which is
#   irrelevant at this size).
fit = umap.UMAP(
    n_neighbors=90,
    n_components=3,
    metric="cosine",
    min_dist=0.99,
    random_state=42,
)

u_xc = fit.fit_transform(x_all)

In [9]:
import plotly.graph_objects as go

# Context rows come first in u_xc; the last `xq_len` rows are queries.
# Use an explicit split index: `[:-xq_len]` silently yields an empty slice
# when xq_len == 0.
n_ctx = u_xc.shape[0] - xq_len

contexts = go.Scatter3d(
    x=u_xc[:n_ctx, 0],
    y=u_xc[:n_ctx, 1],
    z=u_xc[:n_ctx, 2],
    # Scatter3d defaults to 'lines+markers'; draw points only.
    mode='markers',
    name='contexts',
    marker=dict(
        # Colour each context by the index of the batch it came from.
        color=batch[:n_ctx],
    ),
    # Hover label "<batch>: <context id>".
    text=[f"{b}: {i}" for i, b in zip(xc_idx, batch[:n_ctx])],
)

In [10]:
# Queries are the last `xq_len` rows of u_xc. Use an explicit split index:
# `[-xq_len:]` would select *every* row when xq_len == 0.
n_ctx = u_xc.shape[0] - xq_len

queries = go.Scatter3d(
    x=u_xc[n_ctx:, 0],
    y=u_xc[n_ctx:, 1],
    z=u_xc[n_ctx:, 2],
    # Scatter3d defaults to 'lines+markers'; draw points only.
    mode='markers',
    name='queries',
    marker=dict(
        # A scalar colour applies to every query. A one-element list is a
        # per-point array that only covers the first query.
        color='#000000',
    ),
    # Hover label: the batch label assigned to each query.
    text=[f"{b}" for b in batch[n_ctx:n_ctx + xq_len]],
)

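Also add the origin point to give a rough sense of the angular similarity between points. UMAP is a non-linear projection, so treat this as a visual reference only: angles in the 3-D plot do not exactly reflect cosine similarity in the original space.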

In [13]:
# Project the zero vector (the origin of the original embedding space) through
# the fitted UMAP model, and plot it at the projected coordinates. The point
# (0, 0, 0) in UMAP space has no relation to the original origin.
# NOTE(review): with metric="cosine" the zero vector has no direction, so its
# projected position is only a rough reference -- interpret with care.
origin_point = fit.transform(np.zeros((1, x_all.shape[1])))

origin = go.Scatter3d(
    x=origin_point[:, 0],
    y=origin_point[:, 1],
    z=origin_point[:, 2],
    mode='markers',
    name='origin',
    marker=dict(
        size=10,
        color='#000000',
        symbol='cross'
    ),
    text=['origin']
)

In [14]:
# Queries, contexts and the origin marker in one interactive 3-D scene, with a
# title and axis labels so the figure stands on its own.
fig = go.Figure(data=[queries, contexts, origin])
fig.update_layout(
    title="UMAP projection of context and query embeddings",
    scene=dict(
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        zaxis_title="UMAP 3",
    ),
)
fig

---