In [2]:
#!pip install renumics-spotlight datasets


In [153]:
import pandas as pd
from PIL import Image
import numpy as np
from tqdm import tqdm

def load_images_from_dataframe(dataframe, image_column, label_column=None, image_size=(224, 224)):
    images = []
    labels = []
    for index, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Loading Images"):
        image_path = row[image_column]
        label = None
        if label_column:
            label = row[label_column]
        # Load image
        image = Image.open(image_path)
        # Resize image
        image = image.resize(image_size)
        # Convert image to numpy array
        image = np.array(image)
        # Normalize if necessary
        # Append to list
        images.append(image)
        if label:
            labels.append(label)
    if label_column:
        return {"images": images, "labels": labels}
    else:
        return {"images": images, "artist_name": dataframe['artist_name'].values, "influenced_by": dataframe['influenced_by'].values, "path": dataframe['relative_path'].values}


# Example usage:
df = pd.read_pickle('DATA/Dataset/wikiart_full_combined_no_artist.pkl')
df = df[df['mode'] == 'val'].reset_index(drop=True)
df['relative_path'] = df['relative_path'].apply(lambda x: 'wikiart/' + x)
dataset = load_images_from_dataframe(df, image_column="relative_path")


Loading Images: 100%|██████████| 3942/3942 [01:05<00:00, 60.28it/s] 


In [159]:
from datasets import Dataset

ds = Dataset.from_dict(dataset)


In [164]:
ds['images']

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x11028b550>>
Traceback (most recent call last):
  File "/Users/traopia/miniconda3/envs/artsagenet_new/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


In [3]:
import datasets
# load dataset containing raw data (images and labels)
#ds = datasets.load_dataset("cifar10", split="test")


Downloading readme:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [156]:
# load model and define inference functions
import torch
import transformers

device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")
model_name = "google/vit-base-patch16-224-in21k"
processor = transformers.ViTImageProcessor.from_pretrained(model_name)
cls_model = transformers.ViTForImageClassification.from_pretrained(model_name).to(
    device
)
fe_model = transformers.ViTModel.from_pretrained(model_name).to(device)


def infer(batch):
    images = [image.convert("RGB") for image in batch]
    inputs = processor(images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = cls_model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
        embeddings = fe_model(**inputs).last_hidden_state[:, 0].cpu().numpy()
    return {"embedding": embeddings}


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [162]:
# enrich dataset with predictions and embeddings
ds_enrichments = ds.map(infer, input_columns="images", batched=True, batch_size=32)


Map:   0%|          | 0/3942 [00:00<?, ? examples/s]

AttributeError: 'list' object has no attribute 'convert'

In [65]:
ds_enrichments.features

{'embedding': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)}

In [66]:
from renumics import spotlight
ds_enriched = datasets.concatenate_datasets([ds, ds_enrichments], axis=1)
spotlight.show(ds_enriched, dtype={'embedding':spotlight.Embedding})


ValueError: Expected a list of Dataset objects or a list of IterableDataset objects, but element at position 0 is a module.

In [62]:
import pandas as pd
from datasets import Dataset

df = pd.read_pickle('DATA/Dataset/wikiart_full_combined_no_artist.pkl')
df.drop(columns=['title','additional_styles', 'artist_school', 'tags', 'influenced_by','image', 'relative_path', 'artist_attribution',
       'timeframe_estimation', 'tag_prediction', 'mode', 'date','concatenated_text'], inplace=True)
#df['influenced_by'] = df['influenced_by'].apply(lambda x: ' '.join(x))
df.image_features = df.image_features.apply(lambda x: x.numpy())
df.text_features = df.text_features.apply(lambda x: x.numpy())
df.image_text_features = df.image_text_features.apply(lambda x: x.numpy())


dataset = Dataset.from_pandas(df)




In [63]:
spotlight.show(dataset, dtype={'image_features':spotlight.Embedding})


VBox(children=(Label(value='Spotlight running on http://127.0.0.1:55674/'), HBox(children=(Button(description=…

In [137]:
#visualization with the trained model
from Triplet_Network import TripletResNet_features
import torch

df = pd.read_pickle('DATA/Dataset/wikiart_full_combined_no_artist.pkl')
all_artist_names = set(df['artist_name'])
df['influenced_by'] = df['influenced_by'].apply(lambda artists_list: [artist for artist in artists_list if artist in all_artist_names])
# df.drop(columns=['title','additional_styles', 'artist_school', 'tags', 'influenced_by','image', 'relative_path', 'artist_attribution',
#        'timeframe_estimation', 'tag_prediction', 'mode', 'date','concatenated_text'], inplace=True)
df.drop(columns = ['title'],inplace=True)
df['influenced_by'] = df['influenced_by'].apply(lambda x: ' '.join(x))
df = df[df['mode'] == 'val'].reset_index(drop=True)
feature = 'image_text_features'
model_path = 'trained_models/TripletResNet_image_text_features_posfaiss_negrandom_100_margin10/model.pth'
model = TripletResNet_features(df.loc[0,feature].shape[0])
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
df[f'trained_{feature}'] = df[feature].apply(lambda x: model.forward_once(x).detach())

In [138]:
import seaborn as sns
df.trained_image_text_features = df.trained_image_text_features.apply(lambda x: x.numpy())
df.image_features = df.image_features.apply(lambda x: x.numpy())
df.text_features = df.text_features.apply(lambda x: x.numpy())
df.image_text_features = df.image_text_features.apply(lambda x: x.numpy())


# unique_values = pd.Series(dataset['artist_name']).unique()

# color_palette = sns.color_palette('husl', len(unique_values))

# # Create a dictionary mapping each unique value to its corresponding color
# color_mapping_artist = dict(zip(unique_values, color_palette))

# # Create a function to map each artist to its corresponding color from the palette
# def map_artist_to_color(artist):
#     return color_mapping_artist[artist]

# # Apply this function to each list of artists for each observation to obtain a list of colors
# aggregate_colors = []
# for artists_list in df['influenced_by']:
#     colors_list = [map_artist_to_color(artist) for artist in artists_list]
#     aggregate_color = tuple(sum(color_component) for color_component in zip(*colors_list))
#     aggregate_colors.append(aggregate_color)
# df['colors'] = aggregate_colors
# import numpy as np
# df['colors'] = df.colors.apply(lambda x: np.mean(x))
dataset = Dataset.from_pandas(df)


In [139]:
spotlight.show(dataset, dtype={'image_text_features':spotlight.Embedding})


VBox(children=(Label(value='Spotlight running on http://127.0.0.1:54767/'), HBox(children=(Button(description=…

In [142]:
columns_to_visualize = {
    'image_text_features': spotlight.Embedding,
    'artist_name': spotlight.Category,  # Example for another column with text data

}

spotlight.show(dataset, dtype=columns_to_visualize)

VBox(children=(Label(value='Spotlight running on http://127.0.0.1:52066/'), HBox(children=(Button(description=…