### Импорты

In [4]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from pathlib import Path
import numpy as np
from annoy import AnnoyIndex
import pickle
import os
from PIL import Image
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

### Image Encoder

In [5]:
class ImageEncoder:
    """Turn image files into unit-length feature vectors with ConvNeXt-Large.

    The pretrained torchvision network is loaded and its last child module
    (the classifier head in torchvision's ConvNeXt) is dropped, so a forward
    pass yields pooled features instead of class logits.
    """

    def __init__(self):
        """Load the pretrained backbone, move it to GPU if available, build the preprocessing pipeline."""
        print("Loading ConvNeXT model...")
        # ``pretrained=True`` is deprecated in newer torchvision releases in
        # favour of the ``weights=`` enum; IMAGENET1K_V1 is the checkpoint the
        # legacy flag maps to. Fall back to the legacy flag when the enum does
        # not exist (older torchvision).
        weights_enum = getattr(models, "ConvNeXt_Large_Weights", None)
        if weights_enum is not None:
            backbone = models.convnext_large(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.convnext_large(pretrained=True)

        # Keep every child module except the last one.
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        self.model.eval()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.transform = transforms.Compose([
            transforms.Resize(236, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        print(f"Model loaded successfully on {self.device}!")

    def get_embedding(self, image_path):
        """Return the L2-normalised embedding of one image.

        Args:
            image_path: Path of the image file to encode.

        Returns:
            A NumPy array with unit L2 norm, or ``None`` when the image cannot
            be read or encoded, or when its embedding has a zero or non-finite
            norm (which cannot be normalised). Failures are reported on stdout
            and never raised, so a batch loop can skip bad files.
        """
        try:
            # The context manager releases the file handle right away instead
            # of waiting for garbage collection; ``convert`` returns an image
            # that stays usable after the file is closed.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            batch = self.transform(image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                features = self.model(batch)

            embedding = features.squeeze().cpu().numpy()
            norm = np.linalg.norm(embedding)
            # Dividing by a zero/NaN norm would produce a NaN vector that
            # silently ends up in the search index.
            if not np.isfinite(norm) or norm == 0:
                print(f"Error processing {image_path}: embedding has zero or non-finite norm")
                return None
            return embedding / norm
        except Exception as e:
            # Deliberate best-effort: report the file and let the caller skip it.
            print(f"Error processing {image_path}: {str(e)}")
            return None

In [6]:
# Build the encoder once: the constructor loads the ConvNeXt model and moves it to the selected device.
encoder = ImageEncoder()

Loading ConvNeXT model...




Model loaded successfully on cuda!


### Обработка датасета

In [7]:
def process_dataset(dataset_path, encoder, save_dir="./data"):
    """Embed every image of a class-per-folder dataset and persist the result.

    Each immediate sub-directory of ``dataset_path`` is treated as one class
    (the directory name is the label). Files directly inside it with a
    ``.jpg``/``.jpeg``/``.png`` suffix (case-insensitive) are passed to
    ``encoder.get_embedding``; files for which it returns ``None`` are skipped.

    Directories and files are visited in sorted order so the integer ids
    assigned to images are reproducible between runs.

    Args:
        dataset_path: Root directory containing one sub-directory per class.
        encoder: Object exposing ``get_embedding(path) -> vector | None``.
        save_dir: Directory (created if missing) that receives
            ``processed_data.pkl`` with all mappings plus per-class counts.

    Returns:
        Tuple ``(embeddings_dict, file_mapping, class_mapping,
        reverse_class_mapping)`` where the first three are keyed by the
        integer id and the last maps file name -> class name. File names
        that occur in several classes keep only the last class seen.
    """
    dataset_path = Path(dataset_path)
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    embeddings_dict = {}
    file_mapping = {}
    class_mapping = {}
    reverse_class_mapping = {}
    class_stats = defaultdict(int)
    idx = 0
    print("Processing dataset...")

    # iterdir()/glob() yield entries in arbitrary filesystem order; sorting
    # keeps the id <-> file assignment stable across runs and machines.
    class_dirs = sorted(entry for entry in dataset_path.iterdir() if entry.is_dir())
    for class_dir in tqdm(class_dirs):
        class_name = class_dir.name
        for image_file in sorted(class_dir.glob("*.*")):
            if image_file.suffix.lower() not in image_suffixes:
                continue
            embedding = encoder.get_embedding(str(image_file))
            if embedding is None:
                continue
            embeddings_dict[idx] = embedding
            file_mapping[idx] = image_file.name
            class_mapping[idx] = class_name
            reverse_class_mapping[image_file.name] = class_name
            class_stats[class_name] += 1
            idx += 1

    output_dir = Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "processed_data.pkl", "wb") as f:
        pickle.dump({
            'embeddings': embeddings_dict,
            'file_mapping': file_mapping,
            'class_mapping': class_mapping,
            'reverse_class_mapping': reverse_class_mapping,
            'class_stats': dict(class_stats)
        }, f)

    return embeddings_dict, file_mapping, class_mapping, reverse_class_mapping

In [8]:
# NOTE(review): machine-specific absolute path; adjust it before running on another machine.
dataset_path = "/home/moo/Downloads/train_dataset_train_data_rkn/train_data_rkn/dataset"  # Path to the dataset
# Embed the whole dataset; per the signature default, results are also pickled under ./data.
embeddings_dict, file_mapping, class_mapping, reverse_class_mapping = process_dataset(dataset_path, encoder)

Processing dataset...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 105/105 [08:18<00:00,  4.75s/it]


### Индекс

In [9]:
def build_index(embeddings_dict, save_dir="./data", n_trees=100):
    """Build an Annoy index over the embeddings and save it to disk.

    Args:
        embeddings_dict: Mapping of integer item id -> embedding vector. All
            vectors are expected to share the length of the first one.
        save_dir: Directory (created if missing) that receives
            ``image_index.ann``.
        n_trees: Number of trees passed to ``AnnoyIndex.build``.

    Returns:
        The built ``AnnoyIndex`` using the ``'angular'`` metric.

    Raises:
        ValueError: If ``embeddings_dict`` is empty, because the vector
            dimension cannot be determined.
    """
    if not embeddings_dict:
        raise ValueError("embeddings_dict is empty; nothing to index")

    # The index dimension is taken from the first stored vector.
    first_embedding = next(iter(embeddings_dict.values()))
    embedding_dim = len(first_embedding)

    index = AnnoyIndex(embedding_dim, 'angular')
    for item_id, embedding in embeddings_dict.items():
        index.add_item(item_id, embedding)

    print(f"Building index with {n_trees} trees...")
    index.build(n_trees)

    # Do not rely on another step having created the directory already.
    os.makedirs(save_dir, exist_ok=True)
    index.save(os.path.join(save_dir, "image_index.ann"))

    return index

# Build the Annoy index over all dataset embeddings; build_index also saves it under its default save_dir.
index = build_index(embeddings_dict)

Building index with 100 trees...


### Поиск

In [54]:
def find_similar(query_image_path, index, encoder, file_mapping, class_mapping, n_results=10,
                 n_candidates=30, similarity_threshold=0.1, vote_window=10, min_votes=6):
    """Return file names of the indexed images most similar to a query image.

    Nearest neighbours are fetched from ``index``; any neighbour whose file
    name equals the query's base name is dropped (the query itself may be in
    the index). A class is then assigned to the query in one of two ways:

    1. If the closest remaining neighbour is nearer than
       ``similarity_threshold``, its class is used.
    2. Otherwise the classes of the first ``vote_window`` neighbours are
       counted and the most frequent one is used if it has at least
       ``min_votes`` votes.

    When a class is assigned, neighbours of that class come first (in
    distance order) and the remaining neighbours fill up the list. When no
    class is assigned, the plain nearest neighbours are returned.

    Args:
        query_image_path: Path of the query image.
        index: Object exposing ``get_nns_by_vector(vector, n, include_distances=True)``.
        encoder: Object exposing ``get_embedding(path) -> vector | None``.
        file_mapping: Mapping of index id -> file name.
        class_mapping: Mapping of index id -> class name.
        n_results: Maximum number of file names to return.
        n_candidates: Minimum number of neighbours to fetch. It is raised to
            ``n_results + 1`` when smaller, so large ``n_results`` values are
            not silently capped and one dropped self-match is compensated.
        similarity_threshold: Distance below which the top neighbour alone
            decides the class.
        vote_window: Number of top neighbours that take part in the vote.
        min_votes: Votes the winning class needs inside the vote window.

    Returns:
        List of at most ``n_results`` file names; empty if the query cannot
        be embedded or no neighbour remains after dropping the self-match.
    """
    query_embedding = encoder.get_embedding(query_image_path)
    if query_embedding is None:
        return []

    n_candidates = max(n_candidates, n_results + 1)
    neighbour_ids, neighbour_distances = index.get_nns_by_vector(
        query_embedding, n_candidates, include_distances=True)

    query_image_name = os.path.basename(query_image_path)
    # (file name, class name, distance) per neighbour, self-match removed.
    candidates = [
        (file_mapping[idx], class_mapping[idx], dist)
        for idx, dist in zip(neighbour_ids, neighbour_distances)
        if file_mapping[idx] != query_image_name
    ]
    if not candidates:
        return []

    candidate_files = [name for name, _, _ in candidates]

    class_assigned = False
    assigned_class = None
    if candidates[0][2] < similarity_threshold:
        # Check 1: the nearest neighbour is close enough to trust its class.
        class_assigned = True
        assigned_class = candidates[0][1]
    else:
        # Check 2: majority vote among the closest neighbours.
        class_counts = {}
        for _, cls, _ in candidates[:vote_window]:
            class_counts[cls] = class_counts.get(cls, 0) + 1
        if class_counts:
            # max() keeps the first-seen class on ties (dict insertion order).
            most_common_class = max(class_counts, key=class_counts.get)
            if class_counts[most_common_class] >= min_votes:
                class_assigned = True
                assigned_class = most_common_class

    if class_assigned:
        same_class = [name for name, cls, _ in candidates if cls == assigned_class]
        similar_images = same_class[:n_results]
    else:
        similar_images = candidate_files[:n_results]

    # Top up with the remaining neighbours when the class had too few members.
    if len(similar_images) < n_results:
        already_chosen = set(similar_images)
        extras = [name for name in candidate_files if name not in already_chosen]
        similar_images.extend(extras[:n_results - len(similar_images)])

    return similar_images

### Генерация предикта

In [55]:
def generate_recommendations(test_dir, index, encoder, file_mapping, class_mapping, output_file="submission.csv"):
    """Write a CSV of recommended similar images for every test image.

    All ``.jpg``/``.jpeg``/``.png`` files (case-insensitive) found recursively
    under ``test_dir`` are processed in sorted path order. For each one,
    ``find_similar`` is asked for 10 recommendations; images for which it
    returns nothing are left out of the output.

    Args:
        test_dir: Directory that is searched recursively for query images.
        index, encoder, file_mapping, class_mapping: Passed through unchanged
            to ``find_similar``.
        output_file: Destination CSV path. Columns are ``image`` (query file
            name) and ``recs`` (comma-joined recommended file names).

    Returns:
        List of ``{'image': ..., 'recs': ...}`` dicts, one per written row.
    """
    test_path = Path(test_dir)
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    results = []
    print("Generating recommendations for test images...")
    # Sorted so the row order of the CSV does not depend on the filesystem.
    for image_file in sorted(test_path.rglob("*.*")):
        if image_file.suffix.lower() not in image_suffixes:
            continue
        similar_images = find_similar(
            str(image_file), index, encoder, file_mapping, class_mapping, n_results=10)
        if not similar_images:
            continue
        results.append({
            'image': image_file.name,
            # No manual quotes here: to_csv already quotes fields containing
            # commas, and pre-quoting made it emit """a,b""" (escaped literal
            # quotes) instead of "a,b".
            'recs': ",".join(similar_images)
        })

    # Explicit columns keep the header row even when there are no results.
    df = pd.DataFrame(results, columns=['image', 'recs'])
    df.to_csv(output_file, index=False)
    print(f"\nRecommendations saved to {output_file}")

    return results


In [56]:
# NOTE(review): machine-specific absolute path; adjust it before running on another machine.
test_dir = "/home/moo/Downloads/train_dataset_train_data_rkn/train_data_rkn/test"

# Produce recommendations for every test image and write them to submission.csv.
results = generate_recommendations(
    test_dir=test_dir,
    index=index,
    encoder=encoder,
    file_mapping=file_mapping,
    class_mapping=class_mapping,
    output_file="submission.csv"
)

Generating recommendations for test images...

Recommendations saved to submission.csv
