In [1]:
import ast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import normalize

In [2]:
# Read the paintings table; the path is resolved relative to the kernel's
# working directory.
df = pd.read_csv(filepath_or_buffer="omniart-paintings-filtered-clean.csv")

In [3]:
# Drop rows missing palette info. The explicit copy gives the assignments
# below an independent frame to write into.
df = df.dropna(subset=['color_pallete', 'palette_count']).copy()


def _parse_literal(value):
    """Parse a stringified Python literal; return non-string values unchanged.

    The isinstance guard makes this cell safe to re-run: once a column holds
    parsed objects, ``ast.literal_eval`` would raise on them.
    """
    return ast.literal_eval(value) if isinstance(value, str) else value


# Parse strings as lists
df['color_pallete'] = df['color_pallete'].apply(_parse_literal)
df['palette_count'] = df['palette_count'].apply(_parse_literal)

In [4]:
def hex_to_rgb(hex_color):
    """Convert a hex color string to a list of three 0-255 integers.

    A leading ``#`` and surrounding whitespace are optional. The 6-digit form
    (``"#ff8800"``) and the 3-digit shorthand (``"#f80"``) are both supported;
    an 8-digit ``RRGGBBAA`` value is accepted and its alpha pair is ignored.

    Args:
        hex_color: Hex color string, e.g. ``"#1a2b3c"`` or ``"1a2b3c"``.

    Returns:
        list[int]: ``[red, green, blue]``.

    Raises:
        ValueError: If the value is not a 3-, 6- or 8-digit hex color.
    """
    digits = hex_color.strip().lstrip("#")
    if len(digits) == 3:
        # Expand shorthand by doubling each digit: "f80" -> "ff8800".
        digits = "".join(ch * 2 for ch in digits)
    # int(..., 16) tolerates signs and whitespace, so validate the characters
    # explicitly instead of relying on it to reject malformed input.
    valid_digits = set("0123456789abcdefABCDEF")
    if len(digits) not in (6, 8) or not set(digits) <= valid_digits:
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]

In [5]:
# Build color-feature vectors: a flattened RGB vector in which each color is weighted by its frequency.

def get_color_vector(colors, counts, top_n=10):
    """Build a fixed-length, frequency-weighted RGB feature vector.

    The first ``top_n`` colors are paired with their counts. Each color's RGB
    channels are scaled by that color's share of the paired counts and the
    scaled triples are concatenated in palette order. The result is padded
    with zeros up to ``top_n * 3`` entries.

    Args:
        colors: Sequence of hex color strings.
        counts: Sequence of numeric counts, parallel to ``colors``.
        top_n: Maximum number of colors to encode.

    Returns:
        list[float]: Vector of length ``top_n * 3``. It is all zeros when
        there are no colors or when the paired counts sum to zero.
    """
    # Pair first so the normalizer only includes counts that actually have a
    # matching color; otherwise a longer `counts` would skew the weights.
    pairs = list(zip(colors[:top_n], counts[:top_n]))
    size = max(top_n, 0) * 3
    total = sum(count for _, count in pairs)
    if total == 0:
        # No usable frequency information: avoid dividing by zero.
        return [0.0] * size

    vec = []
    for color, count in pairs:
        weight = count / total  # normalized frequency
        vec.extend(channel * weight for channel in hex_to_rgb(color))

    # Pad with zeros if fewer than top_n colors were available.
    vec.extend([0.0] * (size - len(vec)))
    return vec

In [6]:
# Build one feature vector per row. Zipping the two columns avoids row-wise
# DataFrame.apply(axis=1), which is slow and whose return type depends on the
# shape of the frame (an empty frame does not yield a Series of lists).
color_vectors = [
    get_color_vector(palette, counts)
    for palette, counts in zip(df['color_pallete'], df['palette_count'])
]
df['color_vector'] = pd.Series(color_vectors, index=df.index, dtype=object)

# Feature matrix, one row per painting. Forcing a float dtype makes ragged
# vectors raise here instead of silently producing an object array.
X = np.array(color_vectors, dtype=float)

In [11]:
# Project the color feature matrix down to two dimensions for plotting.
tsne = TSNE(
    random_state=42,
    perplexity=30,
    n_components=2,
)
X_embedded = tsne.fit_transform(X)

In [None]:
# Clustering (KMeans): fit on the color feature matrix and store each row's
# cluster label on the frame.
kmeans = KMeans(
    random_state=42,
    n_clusters=100,
)
df['cluster'] = kmeans.fit_predict(X)


In [None]:
# Cluster ids are categorical labels, so give every cluster its own hue; the
# 'tab10' palette only has 10 distinct colors and cannot separate more
# clusters than that.
cluster_ids = df['cluster'].to_numpy()  # positional, matches X_embedded rows
n_clusters = len(np.unique(cluster_ids))

fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(
    x=X_embedded[:, 0],
    y=X_embedded[:, 1],
    hue=cluster_ids,
    palette=sns.color_palette('husl', n_colors=n_clusters),
    alpha=0.7,
    legend='full',
    ax=ax,
)
ax.set_title("Clustering of Paintings by Color Pattern")
ax.set_xlabel("t-SNE Dim 1")
ax.set_ylabel("t-SNE Dim 2")

# Keep the legend off the data: place it outside the axes and wrap it into
# columns of at most 25 entries each.
legend_cols = max(1, -(-n_clusters // 25))
ax.legend(
    title="Cluster",
    loc="upper left",
    bbox_to_anchor=(1.02, 1),
    borderaxespad=0,
    ncol=legend_cols,
    fontsize="x-small",
)
fig.tight_layout()
plt.show()

In [None]:
# Tunable parameters for the proximity graph.
TOP_COLOR_COUNT = 30       # number of most frequent colors to show
NEIGHBOR_DISTANCE = 80     # max Euclidean RGB distance for drawing an edge

# Top N unique colors across dataset
all_colors = (
    df['color_pallete']
    .explode()
    .value_counts()
    .head(TOP_COLOR_COUNT)
    .index
    .tolist()
)
rgb_colors = np.array([hex_to_rgb(c) for c in all_colors])
dists = squareform(pdist(rgb_colors))

# Plot color proximity. Distances are measured in full RGB space, while the
# points are positioned by their red and green channels only.
fig, ax = plt.subplots(figsize=(10, 10))
for i, color in enumerate(all_colors):
    x, y = rgb_colors[i][0], rgb_colors[i][1]
    # Color the marker from the parsed RGB values so it renders whether or
    # not the stored hex string carries a leading '#'.
    ax.scatter(x, y, color=[tuple(rgb_colors[i] / 255.0)], s=300, zorder=2)
    ax.text(x + 2, y, color, fontsize=8, zorder=3)
    # Visit only j > i so each undirected edge is drawn once; drawing both
    # (i, j) and (j, i) would stack two translucent lines on top of each other.
    for j in range(i + 1, len(all_colors)):
        if dists[i, j] < NEIGHBOR_DISTANCE:
            x2, y2 = rgb_colors[j][0], rgb_colors[j][1]
            # zorder=1 keeps edges underneath the markers and labels.
            ax.plot([x, x2], [y, y2], color='gray', alpha=0.3, zorder=1)

ax.set_title("Color Proximity Graph (RGB Space)")
ax.set_xlabel("Red")
ax.set_ylabel("Green")
ax.grid(True)
plt.show()