In [None]:
import pandas as pd
import spacy
# Small English spaCy pipeline; its part-of-speech tags are what
# get_adjectives_spacy() filters on further down.
nlp = spacy.load("en_core_web_sm")

# Raw wine reviews, read from a path relative to the notebook's working directory.
# The `country`, `variety` and `description` columns are the ones used below.
data = pd.read_csv('data/winemag-data_first150k.csv')

In [None]:
# Keep at most 10 reviews per (country, variety) pair.
# Sampling is done WITHOUT replacement and capped at the group size: sampling
# with replacement pads small groups with duplicate reviews (and can repeat
# reviews in large groups too), which distorts the term frequencies computed
# later and makes spaCy re-parse identical text.
# random_state is fixed so the sample is reproducible.
data = data.groupby(['country', 'variety'], group_keys=False).apply(
    lambda group: group.sample(n=min(len(group), 10), random_state=42)
)
data = data.reset_index(drop=True)

In [None]:
# Lower-case the grouping keys so later lookups can use lower-case values.
data = data.assign(
    variety=data['variety'].str.lower(),
    country=data['country'].str.lower(),
)

# Merge all reviews of each (variety, country) pair into one space-separated document.
combined_data = data.groupby(['variety', 'country'])['description'].agg(' '.join)

In [None]:
# Terms that get_adjectives_spacy() drops from its output, grouped by theme.
ignored_words = [
    "%",
    "wine", "wines",
    "fruit", "fruits",
    "flavor", "flavour", "flavors", "flavours",
    "finish", "palate", "aromas", "tannins", "nose", "notes",
    "structure", "variety", "feel", "hint", "little", "nice",
    "color", "colors", "colour", "colours",
    "note",
]
def get_adjectives_spacy(words):
    """Extract the descriptive terms (adjectives and nouns) from a text.

    The text is parsed with the module-level spaCy pipeline ``nlp``; every
    token tagged ``ADJ`` or ``NOUN`` is kept unless it appears in the
    module-level ``ignored_words`` list.

    Args:
        words: The text to analyse, as a single string.

    Returns:
        A list of token texts in document order. Tokens keep their original
        casing and duplicates are preserved, so the list can be used for
        frequency counts.
    """
    # Compare case-insensitively: a sentence-initial "Aromas" or "Wine" must be
    # filtered just like its lower-case form. A set also gives O(1) lookups.
    ignored = {word.lower() for word in ignored_words}
    return [
        token.text
        for token in nlp(words)
        if token.pos_ in ("ADJ", "NOUN") and token.text.lower() not in ignored
    ]
# Run the extractor once per combined description. Series.apply is the explicit
# element-wise call; Series.agg with an element-wise callable is deprecated in
# recent pandas and is slated to pass the whole Series to the callable instead.
# reset_index() turns the (variety, country) index into ordinary columns.
adj = combined_data.apply(get_adjectives_spacy).reset_index()


In [None]:
# Preview the per-(variety, country) term lists computed above.
# (`all_user_input_adj` is only defined in a later cell, so displaying it here
# raises NameError on a fresh kernel.)
adj.head()

In [None]:
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used below to encode individual terms.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# The wine to highlight. Values are lower-case to match the lower-cased
# `variety` / `country` columns of `adj`.
user_input_variety = 'chenin blanc'
user_input_country = 'france'
user_input_adj = adj[(adj['variety']==user_input_variety) & 
                     (adj['country']==user_input_country)]
# One row per extracted term. reset_index() keeps the originating `adj` row
# label in an `index` column next to `description`.
all_user_input_adj = user_input_adj['description'].explode().reset_index()
# value_counts() on the two-column frame counts (index, description) pairs.
# NOTE(review): this equals per-term counts only if a single `adj` row matched — confirm.
top_user_input_adj = all_user_input_adj.value_counts().head(10) # these will be text plotted



# Row positions (in the exploded, pre-aggregation frame) of every occurrence of a top-10 term.
# Must be computed before `all_user_input_adj` is rebound below.
# NOTE(review): `top_n_index` is not read again in this cell or the later visible cells.
top_n_index = all_user_input_adj.index[all_user_input_adj['description'].isin(top_user_input_adj.reset_index()['description'])]

# From here on `all_user_input_adj` is the frequency table of the 20 most
# common terms (one row per term), no longer the one-row-per-occurrence frame.
all_user_input_adj = all_user_input_adj.value_counts().head(20).reset_index() # take top 20 words only to save space on vis
# Terms are encoded by calling model.encode(); one embedding per row of the table above.
user_input_emb = model.encode(all_user_input_adj['description'])
# top_user_input_emb = model.encode(top_user_input_adj.index)

# Background vocabulary: only the top n words for each wine & country (otherwise the vis is too busy)
adj_exploded = adj.explode('description')
adj_counts = adj_exploded.groupby(['variety', 'country'])['description'].value_counts()
# head(2) keeps the first two rows of each (variety, country) group; value_counts
# sorts by descending count by default, so these should be the two most frequent terms.
adj_top = adj_counts.groupby(level=[0,1]).head(2).reset_index()

# Unique background terms; set() de-duplicates, so the list order is not guaranteed.
all_adj = list(set(adj_top['description'].explode().to_list()))
# Drop missing values, which stringify as 'nan'.
all_adj = [word for word in all_adj if str(word) != 'nan']
all_adj_emb = model.encode(all_adj)

In [None]:
# Display the frequency table of the selected wine's terms (capped at 20 rows by head(20) in the cell above).
all_user_input_adj

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# Fit the 3-D projection once, on the background vocabulary embeddings.
pca = PCA(n_components=3)
all_adj_embed_3d = pca.fit_transform(all_adj_emb)
# Project the selected wine's terms with the SAME fitted components.
# Calling fit_transform() again would re-fit PCA on these few points and put
# them in a different coordinate system from the background cloud they are
# plotted against. It would also fail outright with fewer than 3 terms.
user_input_emb_3d = pca.transform(user_input_emb)


In [None]:
# The ten terms that receive a text label in the plot below.
all_user_input_adj['description'].head(10)

In [None]:
import plotly.graph_objects as go

def _padded_range(lo, hi, pad=0.25):
    """Return ``(lo, hi)`` widened on each side by ``pad`` times the span."""
    margin = (hi - lo) * pad
    return (lo - margin, hi + margin)


# Split the projected points into per-axis coordinate tuples.
x_coords, y_coords, z_coords = zip(*all_adj_embed_3d)
selected_x_coords, selected_y_coords, selected_z_coords = zip(*user_input_emb_3d)

# Only the first 10 selected terms get a permanent text label; the rest are hover-only.
text_x_coords, text_y_coords, text_z_coords = (
    axis[:10] for axis in (selected_x_coords, selected_y_coords, selected_z_coords)
)
text_labels = all_user_input_adj['description'][:10]

# Extent of the background cloud along each axis.
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
z_min, z_max = min(z_coords), max(z_coords)

# Axis ranges: the background extent padded by 25% of its span on both sides.
x_range = _padded_range(x_min, x_max)
y_range = _padded_range(y_min, y_max)
z_range = _padded_range(z_min, z_max)

# Background vocabulary: faint grey markers with no hover text.
all_trace = go.Scatter3d(
    x=x_coords,
    y=y_coords,
    z=z_coords,
    mode='markers',
    hoverinfo='none',
    marker=dict(color='lightgrey', opacity=0.05),
)
# Selected wine's terms: markers that reveal the term on hover.
selected_trace = go.Scatter3d(
    x=selected_x_coords,
    y=selected_y_coords,
    z=selected_z_coords,
    mode='markers',
    text=all_user_input_adj['description'],
    hoverinfo='text',
)
# Always-visible labels for the first 10 selected terms.
selected_text = go.Scatter3d(
    x=text_x_coords,
    y=text_y_coords,
    z=text_z_coords,
    mode='text',
    text=text_labels,
    hoverinfo='none',
)

# Create 3D scatter plot
fig = go.Figure(data=[all_trace, selected_trace, selected_text])

# Settings shared by all three axes: no background, grid, tick labels or spikes.
axis_style = dict(
    showbackground=False,
    showgrid=False,
    zeroline=True,
    showticklabels=False,
    showspikes=False,
)
fig.update_layout(
    title="",
    scene=dict(
        xaxis=dict(range=x_range, showline=False, **axis_style),
        yaxis=dict(range=y_range, showline=True, **axis_style),
        zaxis=dict(range=z_range, showline=True, **axis_style),
        xaxis_title='',
        yaxis_title='',
        zaxis_title='',
    ),
    margin=dict(b=0, t=0, l=0, r=0),
    showlegend=False,
)

fig.show()