In [2]:
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import tensorflow.keras as keras
import tensorflow as tf

from sklearn.model_selection import GroupShuffleSplit
from sklearn.utils import shuffle

from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Dataset root directory. It must end with "/" because file paths below are
# built by plain string concatenation (root_dir + "train.csv",
# root_dir + "train_images/" + <file name>).
root_dir = "shopee-product-matching/"

# Попытка разобраться с генераторами данных (`keras.utils.Sequence`)

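**TODO:** возможно, генератор возвращает данные не в том формате: Keras ожидает, что `Sequence.__getitem__` вернёт пару `(X, y)` из массивов NumPy, а не из списков отдельных изображений.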

In [5]:
# Helper: Keras Sequence that yields batches built from image triplets.
class DataGenerator(keras.utils.Sequence):
    """Yield ``(X, y)`` batches of anchor / positive / negative images.

    For every row of ``df`` (the anchor), a positive is drawn from the same
    ``label_group`` and a negative from any other ``label_group``. Both are
    sampled from the whole generator frame ``self.df``, not only the batch.

    Args:
        df: frame with columns ``image`` (path to an image file) and
            ``label_group``.
        batch_size: anchors per batch; the trailing partial batch is dropped
            (see ``__len__``).
        input_size: ``(height, width[, channels])``; only height and width are
            used as the resize target.
        shuffle: reshuffle the rows on construction and after every epoch.

    NOTE(review): ``y`` packs positives and negatives together. Whether this
    matches what the model / triplet loss expects is not visible here; confirm
    against the model definition.
    """

    def __init__(self, df, batch_size=32, input_size=(224, 224, 3), shuffle=True):
        self.df = df.copy()
        self.input_size = input_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def on_epoch_end(self):
        """Reshuffle the rows (called here once and by Keras after each epoch)."""
        if self.shuffle:
            # Module-level sklearn.utils.shuffle; returns a shuffled copy.
            self.df = shuffle(self.df)

    def load_image(self, path):
        """Load ``path`` resized to ``input_size`` as a float array scaled to [0, 1]."""
        # load_img documents target_size as (height, width); drop any channel entry.
        img = load_img(path, target_size=tuple(self.input_size[:2]))
        return img_to_array(img) / 255.0

    def load_triplet(self, df, ind):
        """Return ``(anchor, positive, negative)`` images for row ``ind`` of ``df``.

        ``df`` is usually the current batch; positives and negatives come from
        the full ``self.df``. If the anchor's group has no other distinct image
        path, the anchor image itself is used as the positive instead of
        raising. An empty set of negatives (a single group) still raises
        ``ValueError`` from ``np.random.choice``.
        """
        value = df.iloc[ind]
        same_group = self.df["label_group"] == value["label_group"]
        # Combine masks with & on the same frame: chained boolean indexing
        # (df[mask1][mask2]) makes pandas reindex the second mask and warn.
        similar = self.df.loc[same_group & (self.df["image"] != value["image"]), "image"]
        different = self.df.loc[~same_group, "image"]

        anchor = self.load_image(value["image"])
        # A group may contain no *other* image path (single-row group or
        # duplicated paths); np.random.choice would fail on an empty input.
        positive_path = np.random.choice(similar) if len(similar) else value["image"]
        positive = self.load_image(positive_path)
        negative = self.load_image(np.random.choice(different))

        return anchor, positive, negative

    def __get_data(self, batches):
        """Build one batch from the rows of ``batches``.

        Returns:
            X: array of anchors, shape ``(len(batches), *image_shape)``.
            y: array of shape ``(len(batches), 2, *image_shape)``; ``y[:, 0]``
               are the positives and ``y[:, 1]`` the negatives.
        """
        anchors, pairs = [], []
        for i in range(len(batches)):
            a, p, n = self.load_triplet(batches, i)
            anchors.append(a)
            pairs.append([p, n])
        # Keras reads a Python list of arrays as several separate model
        # inputs, so stack the per-sample images into single batch arrays.
        return np.asarray(anchors), np.asarray(pairs)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        start = index * self.batch_size
        batch = self.df.iloc[start:start + self.batch_size]
        return self.__get_data(batch)

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.df) // self.batch_size

In [4]:
# The handler itself: reads train.csv, splits it by label_group, builds generators.
class DatasetHandler:
    """Read the training table and expose train / validation triplet generators.

    The split is grouped by ``label_group``, so no group appears in both the
    training and the validation part.

    Args:
        root_dir: dataset directory; must end with a path separator because
            paths are built by string concatenation.
        batch_size: batch size for both generators.
        image_size: ``(height, width, channels)`` forwarded to ``DataGenerator``.
        shuffle: whether both generators reshuffle their rows every epoch.
    """

    def __init__(self, root_dir, batch_size=32, image_size=(300, 300, 3), shuffle=True):
        all_train_df = pd.read_csv(root_dir + "train.csv")
        all_train_df['image'] = root_dir + 'train_images/' + all_train_df['image']

        # Only the first split is consumed, so request exactly one; for a fixed
        # random_state this yields the same first split as n_splits=2.
        splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_idx, val_idx = next(splitter.split(all_train_df, groups=all_train_df.label_group))
        self.train_df = all_train_df.iloc[train_idx]
        self.val_df = all_train_df.iloc[val_idx]
        # Sanity check: a group leaking into both parts would inflate validation scores.
        assert set(self.train_df.label_group).isdisjoint(self.val_df.label_group)

        self.image_size = image_size

        self.train_generator = DataGenerator(
            self.train_df, batch_size=batch_size, input_size=image_size, shuffle=shuffle
        )
        self.val_generator = DataGenerator(
            self.val_df, batch_size=batch_size, input_size=image_size, shuffle=shuffle
        )

    def get_train_gen(self):
        """Return the training ``DataGenerator``."""
        return self.train_generator

    def get_val_gen(self):
        """Return the validation ``DataGenerator``."""
        return self.val_generator

In [6]:
# Generator-based handler (defined above) producing batches of 5 anchors.
dh = DatasetHandler(root_dir, batch_size=5)

In [7]:
# Training and validation generators built by the handler.
tg, vg = dh.get_train_gen(), dh.get_val_gen()

In [8]:
def visualize_triplet(a, p, n, titles=("anchor", "positive", "negative")):
    """Show an (anchor, positive, negative) triplet side by side.

    Args:
        a, p, n: images accepted by ``Axes.imshow``.
        titles: captions for the three panels, in order; pass ``None`` (or an
            empty sequence) to leave the panels unlabeled.
    """
    _, axs = plt.subplots(1, 3, figsize=(9, 3))

    for i, (ax, image) in enumerate(zip(axs, (a, p, n))):
        ax.imshow(image)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Label panels so the reader can tell which image plays which role.
        if titles and i < len(titles) and titles[i]:
            ax.set_title(titles[i])

# Старая версия: все изображения заранее загружаются в память

In [11]:
class DatasetHandler:
    """Legacy handler: preloads every training image, then builds triplets eagerly.

    NOTE(review): this cell reuses the name ``DatasetHandler`` and shadows the
    generator-based handler defined earlier in the notebook. Which class
    ``DatasetHandler(...)`` refers to depends on which cell ran last.

    Args:
        root_dir: dataset directory ending with a path separator (paths are
            built by string concatenation).
        image_size: ``(height, width)`` passed to ``load_img`` as the resize target.
    """

    def __init__(self, root_dir, image_size=(300, 300)):
        all_train_df = pd.read_csv(root_dir + "train.csv")
        all_train_df['image'] = root_dir + 'train_images/' + all_train_df['image']

        # Group-wise split (no label_group in both parts); only one split is used.
        splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_idx, val_idx = next(splitter.split(all_train_df, groups=all_train_df.label_group))
        self.train_df = all_train_df.iloc[train_idx]
        self.val_df = all_train_df.iloc[val_idx]

        self.image_size = image_size
        # Loads images for BOTH splits up front; this is the slow, memory-heavy step.
        self.images = self.load_images(all_train_df, True)

    def load_image(self, path):
        """Load ``path`` resized to ``image_size`` as a float array scaled to [0, 1]."""
        img = load_img(path, target_size=self.image_size)
        return img_to_array(img) / 255.0

    def load_images(self, df, verbose=False, log_every=1000):
        """Load every distinct path in ``df["image"]`` into a ``{path: array}`` dict.

        Each distinct path is loaded once, even if it appears in several rows.
        With ``verbose``, a progress line is printed every ``log_every`` images
        and after the last one, instead of one line per image.
        """
        paths = df["image"].unique()  # first-appearance order, no duplicates
        total = len(paths)
        step = max(int(log_every), 1)
        img_dict = {}
        for i, path in enumerate(paths, start=1):
            img_dict[path] = self.load_image(path)
            if verbose and (i % step == 0 or i == total):
                print(f"Loaded img {i} out of {total}")
        return img_dict

    def load_triplet(self, df, ind, visualize=False):
        """Return preloaded ``(anchor, positive, negative)`` images for row ``ind`` of ``df``.

        The positive comes from the same ``label_group`` with a different image
        path; if there is none, the anchor image is reused as the positive.
        The negative comes from any other ``label_group``.
        """
        value = df.iloc[ind]
        same_group = df["label_group"] == value["label_group"]
        similar = df.loc[same_group & (df["image"] != value["image"]), "image"]
        different = df.loc[~same_group, "image"]

        anchor = self.images[value["image"]]
        # np.random.choice raises on an empty input (single-row group or
        # duplicated paths); fall back to the anchor itself.
        positive_path = np.random.choice(similar) if len(similar) else value["image"]
        positive = self.images[positive_path]
        negative = self.images[np.random.choice(different)]

        if visualize:
            self.visualize_triplet(anchor, positive, negative)

        return anchor, positive, negative

    def visualize_triplet(self, a, p, n):
        """Show the three images side by side with the axes hidden."""
        fig = plt.figure(figsize=(9, 3))
        axs = fig.subplots(1, 3)

        for ax, image in zip(axs, (a, p, n)):
            ax.imshow(image)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    def get_train_triplets(self, verbose=True, log_every=1000):
        """Reshuffle ``self.train_df`` and build one ``[a, p, n]`` triplet per row.

        Progress is printed every ``log_every`` triplets and after the last one
        (when ``verbose``), instead of one line per triplet.
        """
        self.train_df = shuffle(self.train_df)
        total = len(self.train_df)
        step = max(int(log_every), 1)
        triplets = []
        for i in range(total):
            a, p, n = self.load_triplet(self.train_df, i)
            triplets.append([a, p, n])
            if verbose and ((i + 1) % step == 0 or i + 1 == total):
                print(f"Loaded triplet {i+1} out of {total}")
        return triplets

**TODO:** непонятно, каким должен быть вывод: списком триплетов `[anchor, positive, negative]` или парами `(X, y)` для Keras.

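Изображения загружаются очень долго: конструктор старого `DatasetHandler` сразу читает в память все изображения из `train.csv` (метод `load_images`).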

In [None]:
# Legacy handler: its constructor loads every training image into memory (slow).
dh = DatasetHandler(root_dir=root_dir)

In [None]:
# Reshuffles dh.train_df and builds one [anchor, positive, negative] triplet per training row.
data = dh.get_train_triplets()

In [None]:
# data[0] is a [anchor, positive, negative] list; unpack it into the three parameters.
visualize_triplet(*data[0])