In [5]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests


model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    # low_cpu_mem_usage=True,
    # load_in_4bit=True
)

model.to("cuda:0")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlavaNextForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn

In [4]:
import pandas as pd
import fsspec

file_url = 'https://huggingface.co/datasets/huggan/wikiart/resolve/main/data/train-00000-of-00072.parquet'
fs = fsspec.filesystem('https')

# Читаем данные прямо с URL, не загружая на локальный диск
with fs.open(file_url) as f:
    df = pd.read_parquet(f)

df.drop(df[df.genre == 10].index, inplace=True)
print(len(df))

import requests
import json

# URL файла JSON с метаданными
file_url = 'https://huggingface.co/datasets/huggan/wikiart/resolve/main/dataset_infos.json'

# Загрузка файла JSON с помощью requests
response = requests.get(file_url)

# Преобразуем содержимое в формат JSON
dataset_info = json.loads(response.content)

# Выводим содержимое JSON для анализа
# print(json.dumps(dataset_info, indent=4))

genres_info = dataset_info['huggan--wikiart']['features']['genre']
genres = genres_info['names']
artists =  dataset_info['huggan--wikiart']['features']['artist']['names']
artists_dict = {i: author for i, author in enumerate(artists)}
styles =  dataset_info['huggan--wikiart']['features']['style']['names']
styles_dict = {i: s for i, s in enumerate(styles)}

print(artists_dict)
print(styles_dict)
genres_str = '\n'.join([f'{i}: {g}' for i, g in enumerate(genres)])
genres_str

964
{0: 'Unknown Artist', 1: 'boris-kustodiev', 2: 'camille-pissarro', 3: 'childe-hassam', 4: 'claude-monet', 5: 'edgar-degas', 6: 'eugene-boudin', 7: 'gustave-dore', 8: 'ilya-repin', 9: 'ivan-aivazovsky', 10: 'ivan-shishkin', 11: 'john-singer-sargent', 12: 'marc-chagall', 13: 'martiros-saryan', 14: 'nicholas-roerich', 15: 'pablo-picasso', 16: 'paul-cezanne', 17: 'pierre-auguste-renoir', 18: 'pyotr-konchalovsky', 19: 'raphael-kirchner', 20: 'rembrandt', 21: 'salvador-dali', 22: 'vincent-van-gogh', 23: 'hieronymus-bosch', 24: 'leonardo-da-vinci', 25: 'albrecht-durer', 26: 'edouard-cortes', 27: 'sam-francis', 28: 'juan-gris', 29: 'lucas-cranach-the-elder', 30: 'paul-gauguin', 31: 'konstantin-makovsky', 32: 'egon-schiele', 33: 'thomas-eakins', 34: 'gustave-moreau', 35: 'francisco-goya', 36: 'edvard-munch', 37: 'henri-matisse', 38: 'fra-angelico', 39: 'maxime-maufra', 40: 'jan-matejko', 41: 'mstislav-dobuzhinsky', 42: 'alfred-sisley', 43: 'mary-cassatt', 44: 'gustave-loiseau', 45: 'fernand

'0: abstract_painting\n1: cityscape\n2: genre_painting\n3: illustration\n4: landscape\n5: nude_painting\n6: portrait\n7: religious_painting\n8: sketch_and_study\n9: still_life\n10: Unknown Genre'

In [7]:
model.generation_config.pad_token_id = model.generation_config.eos_token_id

In [6]:
prompt_template = """The task is to classify an image into one of the following 11 genres:\n{genres_str}.\nYou are provided with the name of the artist and the painting style of the image.\nArtist: {artist}\nStyle: {style}\nBased on this information, determine the correct genre. Choose the genre from the list by providing only the corresponding number. Say ONLY the number of the genre and do not say anything else."""

In [23]:
torch.cuda.empty_cache()

In [27]:
prompt_template = """There are 11 possible genres:\n{genres_str}.\nWhat is the genre of painting by {artist} in style {style}? Choose one from the given list. If you do not know the genre exactly, choose '10: Unknown Genre'. Say only the number of genre, do not output anything else."""

In [8]:
from PIL import Image
from io import BytesIO


# prompt_template = """There are 11 possible genres:\n{genres_str}.\nWhat is the genre of painting by {artist} in style {style}? Choose from the given list. If you do not know the genre exactly, choose '10: Unknown Genre'. Say only the number of genre, do not output anything else."""
predictions = []
correct_predictions = 0  # Счётчик правильных предсказаний
total_predictions = len(df)
true_genres = df['genre'].tolist()

# model.generation_config.pad_token_id = tokenizer.pad_token_id

for idx, row in df.iterrows():
    if idx % 100 == 0:
        print('-'*50)
        print(idx)
        print('-'*50)

    image_bytes = row['image']['bytes']
    image_stream = BytesIO(image_bytes)
    image = Image.open(image_stream)
    image = image.resize((512, 512))

    artist = artists_dict[row['artist']]
    style = styles_dict[row['style']]

    prompt = prompt_template.format(genres_str=genres_str, artist=artist, style=style)

    conversation = [
      {

        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
          ],
      },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")

    # autoregressively complete prompt
    outputs = model.generate(**inputs, max_new_tokens=2)
    outputs_text = processor.decode(outputs[0], skip_special_tokens=True)
    
    # print(outputs_text)
    try:
        pred = int(eval(outputs_text[-1]))
    except (ValueError, SyntaxError):
        print("Ошибка: результат не является числом")
        pred = 12

    true_genre = row['genre']
    if pred == true_genre:
      correct_predictions += 1

    print(pred, true_genre)

    predictions.append(pred)



Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.


--------------------------------------------------
0
--------------------------------------------------


Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


4 4
7 7
6 6
4 2
4 2
4 4
2 7
6 6
4 4
4 2
4 4
6 6
6 6
9 9
6 8
7 3
4 4
3 3
6 8
4 4
4 1
4 8
6 6
4 1
4 4
4 8
5 5
6 6
6 7
9 9
6 6
4 4
0 7
5 5
6 6
9 9
6 6
4 8
4 4
9 9
5 2
6 6
4 4
8 8
6 2
6 6
3 2
8 8
4 4
6 6
1 1
4 4
5 5
4 4
6 6
3 2
7 3
4 2
3 3
6 6
0 2
4 7
6 6
4 4
8 2
5 7
6 6
6 6
4 1
4 4
4 4
4 1
Ошибка: результат не является числом
12 4
4 2
5 3
4 4
4 1
4 1
4 2
1 1
5 5
4 2
4 4
6 6
4 4
4 4
9 2
4 4
0 9
--------------------------------------------------
100
--------------------------------------------------
2 2
4 4
Ошибка: результат не является числом
12 9
4 4
4 8
5 5
4 4
0 9
4 2
4 4
4 4
6 2
2 2
4 4
4 4
4 8
4 1
4 7
3 3
4 1
3 3
4 4
6 6
4 4
6 8
4 8
0 2
7 3
6 6
4 4
6 2
4 4
0 2
4 4
6 2
4 4
3 3
9 9
4 4
3 8
4 1
4 4
4 8
6 6
4 2
4 4
4 2
4 4
4 4
4 3
1 1
1 1
4 4
4 1
4 4
4 4
4 4
8 8
2 7
6 6
3 7
4 8
Ошибка: результат не является числом
12 2
2 7
4 4
6 8
4 4
5 5
5 9
9 5
Ошибка: результат не является числом
12 2
4 3
2 8
2 2
5 2
4 2
4 5
6 2
0 6
6 2
8 8
6 6
5 5
4 4
4 8
4 1
6 6
--------------------------------------

In [25]:
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 50.35%


второй промпт

In [29]:
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 48.50%


промпт первый, без неизвестных жанров

In [12]:
for p in predictions:
    if p == 12:
        print(p)

12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12


In [13]:
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 58.30%
