# Описание проекта
Проект направлен на разработку алгоритма для задачи мэтчинга (соответствия) объектов из двух множеств. Основная цель - найти наиболее похожие объекты из одного множества для объектов из другого множества.

# Цели проекта:
- Разработать и реализовать алгоритм мэтчинга для двух множеств объектов.
- Оценить эффективность алгоритма с использованием метрики accuracy@5.
- Продемонстрировать навыки работы с реальными данными и инструментами машинного обучения.

# Исходные данные:
- base.csv: Анонимизированный набор товаров. Каждый товар представлен как уникальный id и вектор признаков размерностью 72.
- train.csv: Обучающий датасет. Каждая строчка - один товар с уникальным id, вектором признаков и id наиболее похожего товара из base.csv.
- validation.csv: Датасет с товарами для поиска наиболее близких товаров из base.csv.
- validation_answer.csv: Правильные ответы для validation.csv.

## Подготовка данных

### Импорт

In [1]:
!pip install faiss-cpu pandas scikit-learn -q

In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import faiss
from sklearn.metrics import accuracy_score

### Загрузка данных

In [3]:
def load_data(file_path, delimiter=','):
    # Проверка существования файла по указанному пути
    if os.path.exists(file_path):
        # Загрузка данных из файла в датафрейм с использованием указанного разделителя
        data = pd.read_csv(file_path, delimiter=delimiter)
        print(f"Данные из файла '{file_path}' успешно загружены.")
        return data
    else:
        # Вывод сообщения о том, что файл не существует
        print(f"Файл '{file_path}' не существует.")
        return None

In [4]:
base = load_data('base.csv')
train = load_data('train.csv')
validation = load_data('validation.csv')
validation_answer = load_data('validation_answer.csv')

Данные из файла 'base.csv' успешно загружены.
Данные из файла 'train.csv' успешно загружены.
Данные из файла 'validation.csv' успешно загружены.
Данные из файла 'validation_answer.csv' успешно загружены.


### Просмотр первых строк данных

In [5]:
display(base.head())
display(train.head())
display(validation.head())
display(validation_answer.head())

Unnamed: 0,Id,0,1,2,3,4,5,6,7,8,...,62,63,64,65,66,67,68,69,70,71
0,4207931-base,-43.946243,15.364378,17.515854,-132.31146,157.06442,-4.069252,-340.63086,-57.55014,128.39822,...,-71.92717,30.711966,-90.190475,-24.931271,66.972534,106.346634,-44.270622,155.98834,-1074.464888,-25.066608
1,2710972-base,-73.00489,4.923342,-19.750746,-136.52908,99.90717,-70.70911,-567.401996,-128.89015,109.914986,...,-109.04466,20.916021,-171.20139,-110.596844,67.7301,8.909615,-9.470253,133.29536,-545.897014,-72.91323
2,1371460-base,-85.56557,-0.493598,-48.374817,-157.98502,96.80951,-81.71021,-22.297688,79.76867,124.357086,...,-58.82165,41.369606,-132.9345,-43.016839,67.871925,141.77824,69.04852,111.72038,-1111.038833,-23.087206
3,3438601-base,-105.56409,15.393871,-46.223934,-158.11488,79.514114,-48.94448,-93.71301,38.581398,123.39796,...,-87.90729,-58.80687,-147.7948,-155.830237,68.974754,21.39751,126.098785,139.7332,-1282.707248,-74.52794
4,422798-base,-74.63888,11.315012,-40.204174,-161.7643,50.507114,-80.77556,-640.923467,65.225,122.34494,...,-30.002094,53.64293,-149.82323,176.921371,69.47328,-43.39518,-58.947716,133.84064,-1074.464888,-1.164146


Unnamed: 0,Id,0,1,2,3,4,5,6,7,8,...,63,64,65,66,67,68,69,70,71,Target
0,109249-query,-24.021454,3.122524,-80.947525,-112.329994,191.09018,-66.90313,-759.626065,-75.284454,120.55149,...,-24.60167,-167.76077,133.678516,68.1846,26.317545,11.938202,148.54932,-778.563381,-46.87775,66971-base
1,34137-query,-82.03358,8.115866,-8.793022,-182.9721,56.645336,-52.59761,-55.720337,130.05925,129.38335,...,54.448433,-120.894806,-12.292085,66.608116,-27.997612,10.091335,95.809265,-1022.691531,-88.564705,1433819-base
2,136121-query,-75.71964,-0.223386,-86.18613,-162.06406,114.320114,-53.3946,-117.261013,-24.857851,124.8078,...,-5.609123,-93.02988,-80.997871,63.733383,11.378683,62.932007,130.97539,-1074.464888,-74.861176,290133-base
3,105191-query,-56.58062,5.093593,-46.94311,-149.03912,112.43643,-76.82051,-324.995645,-32.833107,119.47865,...,21.624313,-158.88037,179.597294,69.89136,-33.804955,233.91461,122.868546,-1074.464888,-93.775375,1270048-base
4,63983-query,-52.72565,9.027046,-92.82965,-113.11101,134.12497,-42.423073,-759.626065,8.261169,119.49023,...,13.807772,-208.65004,41.742014,66.52242,41.36293,162.72305,111.26131,-151.162805,-33.83145,168591-base


Unnamed: 0,Id,0,1,2,3,4,5,6,7,8,...,62,63,64,65,66,67,68,69,70,71
0,196680-query,-59.38342,8.563436,-28.203072,-134.22534,82.73661,-150.57217,-129.178969,23.670555,125.66636,...,-103.48163,79.56453,-120.31357,54.218155,68.50073,32.681908,84.19686,136.41296,-1074.464888,-21.233612
1,134615-query,-103.91215,9.742726,-15.209915,-116.3731,137.6988,-85.530075,-776.123158,44.48153,114.67121,...,-51.19377,49.299644,-101.89454,105.560548,67.80104,13.633057,108.05138,111.864456,-841.022331,-76.56798
2,82675-query,-117.92328,-3.504554,-64.29939,-155.18713,156.82137,-34.082264,-537.423653,54.078613,121.97396,...,-115.176155,48.63613,-132.17967,-0.988696,68.11125,107.065216,134.61765,134.08,27.773269,-32.401714
3,162076-query,-90.880554,4.888542,-39.647797,-131.7501,62.36212,-105.59327,-347.132493,-83.35175,133.91331,...,-112.29379,54.884007,-177.56935,-116.374997,67.88766,136.89398,124.89447,117.70775,-566.34398,-90.905556
4,23069-query,-66.94674,10.562773,-73.78183,-149.39787,2.93866,-51.288853,-587.189361,-2.764402,126.56105,...,-116.440605,47.279976,-162.654,107.409409,67.78526,-60.97649,142.68571,82.2643,-345.340457,-48.572525


Unnamed: 0,Id,Expected
0,196680-query,1087368-base
1,134615-query,849674-base
2,82675-query,4183486-base
3,162076-query,2879258-base
4,23069-query,615229-base


### Общая информации о данных

In [6]:
display(base.info())
display(train.info())
display(validation.info())
display(validation_answer.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 291813 entries, 0 to 291812
Data columns (total 73 columns):
 #   Column  Non-Null Count   Dtype  
---  ------  --------------   -----  
 0   Id      291813 non-null  object 
 1   0       291813 non-null  float64
 2   1       291813 non-null  float64
 3   2       291813 non-null  float64
 4   3       291813 non-null  float64
 5   4       291813 non-null  float64
 6   5       291813 non-null  float64
 7   6       291813 non-null  float64
 8   7       291813 non-null  float64
 9   8       291813 non-null  float64
 10  9       291813 non-null  float64
 11  10      291813 non-null  float64
 12  11      291813 non-null  float64
 13  12      291813 non-null  float64
 14  13      291813 non-null  float64
 15  14      291813 non-null  float64
 16  15      291813 non-null  float64
 17  16      291813 non-null  float64
 18  17      291813 non-null  float64
 19  18      291813 non-null  float64
 20  19      291813 non-null  float64
 21  20      29

None

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9999 entries, 0 to 9998
Data columns (total 74 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Id      9999 non-null   object 
 1   0       9999 non-null   float64
 2   1       9999 non-null   float64
 3   2       9999 non-null   float64
 4   3       9999 non-null   float64
 5   4       9999 non-null   float64
 6   5       9999 non-null   float64
 7   6       9999 non-null   float64
 8   7       9999 non-null   float64
 9   8       9999 non-null   float64
 10  9       9999 non-null   float64
 11  10      9999 non-null   float64
 12  11      9999 non-null   float64
 13  12      9999 non-null   float64
 14  13      9999 non-null   float64
 15  14      9999 non-null   float64
 16  15      9999 non-null   float64
 17  16      9999 non-null   float64
 18  17      9999 non-null   float64
 19  18      9999 non-null   float64
 20  19      9999 non-null   float64
 21  20      9999 non-null   float64
 22  

None

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 73 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   Id      10000 non-null  object 
 1   0       10000 non-null  float64
 2   1       10000 non-null  float64
 3   2       10000 non-null  float64
 4   3       10000 non-null  float64
 5   4       10000 non-null  float64
 6   5       10000 non-null  float64
 7   6       10000 non-null  float64
 8   7       10000 non-null  float64
 9   8       10000 non-null  float64
 10  9       10000 non-null  float64
 11  10      10000 non-null  float64
 12  11      10000 non-null  float64
 13  12      10000 non-null  float64
 14  13      10000 non-null  float64
 15  14      10000 non-null  float64
 16  15      10000 non-null  float64
 17  16      10000 non-null  float64
 18  17      10000 non-null  float64
 19  18      10000 non-null  float64
 20  19      10000 non-null  float64
 21  20      10000 non-null  float64
 22 

None

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   Id        10000 non-null  object
 1   Expected  10000 non-null  object
dtypes: object(2)
memory usage: 156.4+ KB


None

### Описательная статистика

In [7]:
display(base.describe())
display(train.describe())
display(validation.describe())
display(validation_answer.describe())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,62,63,64,65,66,67,68,69,70,71
count,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,...,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0,291813.0
mean,-86.274741,8.078087,-44.61348,-146.605552,111.261183,-71.875015,-393.43046,20.391975,123.676692,124.427158,...,-79.075207,33.244836,-154.720293,13.699277,67.792659,23.543147,74.890028,115.611366,-798.355219,-47.701336
std,24.918947,4.949495,38.545928,19.842726,46.350083,28.189743,272.084921,64.297194,6.352832,64.366631,...,30.485074,28.896144,41.217568,98.996416,1.825081,55.353032,61.320347,21.219222,385.414338,41.731025
min,-189.35602,-12.5945,-231.78592,-224.8805,-95.24083,-188.47333,-791.46877,-296.17105,93.684616,-143.4996,...,-214.82114,-85.8255,-346.23932,-157.593866,59.83579,-213.49242,-190.48315,18.601448,-1297.924962,-209.93576
25%,-103.1543,4.709408,-69.5359,-159.88274,80.36764,-91.227936,-631.937855,-22.085905,119.473625,81.68561,...,-98.82383,16.952824,-180.69556,-71.763964,66.58279,-12.501141,33.78482,101.67566,-1074.464888,-75.62111
50%,-86.30813,8.04001,-43.838493,-146.71736,111.77687,-71.75182,-424.306925,20.850153,123.8825,123.45627,...,-78.53818,34.68254,-153.90283,13.167371,67.81505,23.46163,74.87698,116.04927,-1074.464888,-48.403816
75%,-69.254616,11.466815,-19.694052,-133.32014,142.36461,-52.310825,-157.3427,63.989326,127.967766,167.01413,...,-58.638412,52.09771,-127.22136,99.362213,69.0288,59.740337,115.89688,129.62595,-504.291183,-19.68417
max,15.574616,28.751917,151.37708,-55.167892,299.30792,49.880882,109.631986,301.09424,147.87091,402.9946,...,48.822197,141.02527,17.344528,185.096719,75.474625,264.3472,319.60104,213.57726,98.770811,126.97322


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,62,63,64,65,66,67,68,69,70,71
count,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,...,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0,9999.0
mean,-85.469569,7.608015,-43.886495,-146.204146,111.979577,-73.679417,-443.890756,21.820551,123.1074,126.418041,...,-81.373935,37.137381,-152.473803,16.533799,67.803505,23.750006,73.254974,115.410518,-708.824875,-48.325315
std,25.986217,4.985728,39.318615,20.309633,47.599872,28.74381,281.494323,66.45044,6.418815,64.858194,...,30.628554,25.391546,41.109706,98.956198,1.875665,54.838217,61.459193,21.641422,405.686503,40.899658
min,-173.03256,-11.560507,-187.29263,-219.7248,-71.24335,-184.96245,-791.443909,-250.68134,99.00177,-113.1046,...,-191.30823,-60.515068,-306.83765,-157.579209,61.339855,-165.36125,-127.017555,34.69603,-1297.871984,-208.27681
25%,-103.445763,4.178606,-69.99078,-160.03288,79.71239,-93.620305,-757.628755,-21.344349,118.787718,81.959692,...,-101.775513,20.369431,-179.33738,-67.498388,66.534668,-12.360923,32.034007,100.510068,-1074.464888,-75.27117
50%,-85.1771,7.592032,-43.009907,-146.57622,112.233475,-74.03108,-517.673347,22.23367,123.002785,126.28035,...,-81.76723,36.944336,-151.39459,16.007951,67.81971,24.146404,73.17834,115.81306,-800.296677,-49.027412
75%,-67.43785,10.989736,-17.223701,-132.67585,144.052645,-54.168713,-203.437622,66.437385,127.499309,169.95988,...,-60.619155,53.901959,-124.285235,101.77398,69.086705,59.301756,113.03934,130.2256,-362.486812,-21.845187
max,5.052292,26.74189,134.8598,-71.19468,297.36862,31.378914,109.632035,261.41342,145.17847,353.21497,...,20.978668,137.5592,-18.955208,185.055845,74.121605,236.38547,304.39178,187.6214,97.787799,111.831955


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,62,63,64,65,66,67,68,69,70,71
count,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,...,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0
mean,-85.307362,7.63375,-43.686537,-146.038791,111.209942,-72.76794,-440.587163,21.834861,123.051767,124.662559,...,-81.250167,36.956902,-152.12273,14.886141,67.791824,22.857614,72.524486,115.420796,-709.333484,-49.022258
std,25.911681,4.950078,38.805284,20.543489,47.779448,28.617899,278.953671,66.326518,6.522882,65.458934,...,30.617505,25.251097,40.945219,99.384247,1.85112,55.658492,62.595136,21.490729,405.91748,41.159872
min,-167.55067,-10.583933,-205.79736,-219.72089,-76.42276,-170.86838,-791.443909,-277.98395,97.360535,-118.700714,...,-189.2313,-47.206474,-308.20038,-157.580677,61.08639,-174.7991,-159.84073,27.74675,-1297.689518,-209.93576
25%,-103.253996,4.245802,-69.606605,-159.972975,78.592808,-92.371089,-739.549115,-22.11291,118.623887,81.167942,...,-101.93544,19.950575,-178.986322,-71.48163,66.545755,-14.034864,31.120448,100.431221,-1074.464888,-76.834515
50%,-85.49523,7.604729,-43.152172,-146.100195,111.398595,-72.734217,-512.600719,22.596903,123.105532,124.58837,...,-80.64033,36.343388,-151.159155,16.498258,67.80939,22.756214,71.64004,115.465535,-809.193143,-49.42051
75%,-67.228041,10.950509,-17.883776,-131.973093,143.645965,-53.415469,-204.531443,67.179237,127.58333,168.55024,...,-60.179952,53.85612,-124.446027,100.077341,69.043971,60.188393,114.19741,129.99661,-358.197327,-20.995152
max,10.861183,27.944784,107.04069,-73.909584,287.3972,47.001106,109.404851,251.09018,146.03441,350.08832,...,31.880241,141.02527,-17.406784,184.987267,73.88679,241.69562,289.58316,190.17142,98.747128,102.473755


Unnamed: 0,Id,Expected
count,10000,10000
unique,9735,9640
top,182758-query,5785-base
freq,3,3


## Подготовка данных для обучения

In [8]:
dataframes = {
    'base': base,
    'train': train,
    'validation': validation,
    'validation_answer': validation_answer
}

In [9]:
X = train.drop(columns=['Id', 'Target'])
y = train['Target']

### Разделение данных на обучающую и тестовую выборки

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print('Данные разделены на обучающую и тестовую выборки.')

Данные разделены на обучающую и тестовую выборки.


### Создание индекса FAISS

In [11]:
dimension = X_train.shape[1]
nlist = 2048  # Параметр для количества кластеров в IVF

index = faiss.IndexFlatL2(dimension)
quantizer = faiss.IndexFlatL2(dimension)  # Используется для кластеризации
index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)

### Добавление векторов признаков из base.csv в индекс

In [12]:
# Обучение индекса
base_vectors = base.iloc[:, 1:].values
base_vectors = np.ascontiguousarray(base_vectors, dtype=np.float32)  # Преобразование в C-contiguous массив и в float32
faiss.normalize_L2(base_vectors)

index.train(base_vectors)
index.add(base_vectors)

print('Векторы признаков добавлены в индекс.')

Векторы признаков добавлены в индекс.


### Поиск ближайших соседей для валидационного набора

In [13]:
validation_vectors = validation.iloc[:, 1:].values
validation_vectors = np.ascontiguousarray(validation_vectors, dtype=np.float32)
faiss.normalize_L2(validation_vectors)

print('Векторы валидационного набора нормализованы.')

Векторы валидационного набора нормализованы.


In [14]:
k = 15000  # количество ближайших соседей
index.nprobe = 1024  # Параметр для количества кластеров, которые будут проверены
distances, indices = index.search(validation_vectors, k)

## Формирование результата и оценка качества алгоритма

In [15]:
predictions = []
for i in range(len(validation)):
    pred_ids = base.iloc[indices[i]].Id.values
    predictions.append(pred_ids)

### Оценка качества алгоритма по метрике accuracy@5

In [16]:
correct = 0
for i, pred in enumerate(predictions):
    if validation_answer.iloc[i]['Expected'] in pred:
        correct += 1

accuracy_at_5 = correct / len(validation)
print(f'Accuracy@5: {accuracy_at_5:.4f}')

Accuracy@5: 0.4693


# Заключение
Проект предоставил уникальную возможность применить теоретические знания на практике, работая с реальными данными и задачами. Выполнение данного проекта позволило улучшить навыки работы с алгоритмами мэтчинга и инструментами машинного обучения.