In [1]:
import utils
import loss
import batchgen
import models
import numpy as np
from tensorflow import keras
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

import random 
from itertools import permutations

In [2]:
# class TripletGenerator(keras.utils.Sequence):
#         def __init__(self, X, Y, batch_size = 64):
#             self.batch_size = batch_size
#             self.anchor = X[0]
#             self.positive = X[1]
#             self.negative = X[2]
#             self.labels = Y

#         def __len__(self):
#             return int(len(self.labels) / self.batch_size)

#         def __getitem__(self, i):
#             ret = []
#             for i in range(self.batch_size):
#                 apn = [self.anchor[i],
#                        self.positive[i],
#                        self.negative[i]]
#                 ret.append(apn)
#             ret = np.array(ret)
#             anchor, positive, negative = ret[:, 0], ret[:, 1], ret[:, 2]
#             dummy = np.zeros((self.batch_size, 1))
#             return [anchor, positive, negative], dummy

class TripletGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that serves (anchor, positive, negative) image batches.

    For every distinct label, up to ``ap_pairs`` ordered (anchor, positive)
    pairs are sampled from the images carrying that label and up to
    ``an_pairs`` negatives are sampled from the images carrying any other
    label. Every sampled pair is combined with every sampled negative, and the
    resulting index triplets are shuffled once at construction and again after
    each epoch.

    Args:
        X: Indexable collection of images; ``X[k]`` is the image for label ``Y[k]``.
        Y: One label per image (converted with ``np.asarray``).
        ap_pairs: Maximum number of anchor-positive pairs sampled per class.
        an_pairs: Maximum number of negatives sampled per class.
        batch_size: Number of triplets per batch.
    """

    def __init__(self, X, Y, ap_pairs=10, an_pairs=10, batch_size=64):
        super().__init__()
        self.batch_size = batch_size
        self.ap_pairs = ap_pairs
        self.an_pairs = an_pairs
        self.images = X
        # Element-wise label comparison below needs an ndarray, not a list.
        self.labels = np.asarray(Y)
        self.unique_labels = np.unique(self.labels)
        self.triplet_index = self.__generate_triplet_index()
        np.random.shuffle(self.triplet_index)

    def __generate_triplet_index(self):
        """Return an ``(n_triplets, 3)`` int array of (anchor, positive, negative) indices."""
        triplets = []
        for class_id in self.unique_labels:
            same_class_idx = list(np.where(self.labels == class_id)[0])
            diff_class_idx = list(np.where(self.labels != class_id)[0])
            same_class_perms = list(permutations(same_class_idx, 2))

            # Never ask random.sample for more items than exist.
            n_pairs = min(self.ap_pairs, len(same_class_perms))
            n_negatives = min(self.an_pairs, len(diff_class_idx))
            if n_pairs == 0 or n_negatives == 0:
                # A class with a single image has no positive; a dataset with a
                # single class has no negative. Neither can form a triplet.
                continue

            anchor_positive = random.sample(same_class_perms, k=n_pairs)
            negatives = random.sample(diff_class_idx, k=n_negatives)
            for aid, pid in anchor_positive:
                for nid in negatives:
                    triplets.append((aid, pid, nid))

        # reshape keeps the (n, 3) layout even when no triplet could be built.
        return np.array(triplets, dtype=np.int64).reshape(-1, 3)

    def __len__(self):
        """Number of batches per epoch.

        An epoch is capped at roughly one pass over the dataset size
        (``len(labels)`` triplets) and never exceeds the number of triplets
        actually generated; reshuffling in ``on_epoch_end`` exposes a different
        subset of triplets each epoch.
        """
        usable = min(len(self.labels), len(self.triplet_index))
        return usable // self.batch_size

    def on_epoch_end(self):
        """Reshuffle the triplets so the next epoch sees a different ordering."""
        np.random.shuffle(self.triplet_index)

    def __getitem__(self, i):
        """Return batch number ``i`` as ``([anchors, positives, negatives], dummy)``.

        ``dummy`` is an all-zero ``(batch, 1)`` target array built only to
        satisfy the (inputs, targets) batch format.

        Raises:
            IndexError: if ``i`` is not a valid batch number.
        """
        if not 0 <= i < len(self):
            raise IndexError("batch index out of range")
        # Batch i covers its own disjoint slice of the triplet list.
        start = i * self.batch_size
        batch = self.triplet_index[start:start + self.batch_size]
        # Gather each role separately instead of building one 5-D array first.
        anchors = np.array([self.images[aid] for aid in batch[:, 0]])
        positives = np.array([self.images[pid] for pid in batch[:, 1]])
        negatives = np.array([self.images[nid] for nid in batch[:, 2]])
        dummy = np.zeros((len(batch), 1))
        return [anchors, positives, negatives], dummy


In [5]:
# Per-class sampling budget; passed below as both ap_pairs and an_pairs of TripletGenerator.
NUM_PAIRS = 150
# Shape of one model input; matches the per-image shape of the loaded array X.
# NOTE(review): presumably (height, width, channels) — confirm against utils.load_data.
INPUT_SHAPE = (400, 300, 3)
# Length of the embedding vector produced by the shared network.
FEATURE_SIZE = 64
# Number of triplets per generator batch.
BATCH_SIZE = 32

In [4]:
# Load the images (X), their labels (Y) and an index from the label CSV.
# NOTE(review): resize is given as (300, 400) while the loaded array is (N, 400, 300, 3),
# so the argument is presumably (width, height) — confirm in utils.load_data.
# NOTE(review): limit = None presumably means "load every row" — confirm.
X, Y, index = utils.load_data("data/labels/lcwaikiki_labels.csv", resize = (300, 400), limit = None)

100%|██████████| 9995/9995 [00:27<00:00, 364.83it/s]


In [6]:
# Sanity check: the per-image shape must equal INPUT_SHAPE.
X.shape

(9987, 400, 300, 3)

In [7]:
# Hold out 20% of the samples for validation; the fixed seed keeps the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=42,
)

In [8]:
# Build the training and validation generators (in that order) with identical sampling settings.
tripletgen_train, tripletgen_test = (
    TripletGenerator(
        images,
        labels,
        ap_pairs=NUM_PAIRS,
        an_pairs=NUM_PAIRS,
        batch_size=BATCH_SIZE,
    )
    for images, labels in ((X_train, Y_train), (X_test, Y_test))
)

In [9]:
# Prepare network: one image input per triplet member
anchor_input = keras.layers.Input(INPUT_SHAPE, name='anchor_input')
positive_input = keras.layers.Input(INPUT_SHAPE, name='positive_input')
negative_input = keras.layers.Input(INPUT_SHAPE, name='negative_input')

# Shared embedding network: the same model instance (and therefore the same
# weights) encodes the anchor, the positive and the negative image
Shared_DNN = models.minixception(INPUT_SHAPE, feature_size = FEATURE_SIZE)

# Individual outputs
encoded_anchor = Shared_DNN(anchor_input)
encoded_positive = Shared_DNN(positive_input)
encoded_negative = Shared_DNN(negative_input)

# Merged output layer: the three embeddings side by side along the last axis.
# NOTE(review): loss.triplet_loss presumably slices this vector back into its
# three FEATURE_SIZE-wide parts — confirm in loss.py
merged_vector = keras.layers.concatenate([encoded_anchor, encoded_positive, encoded_negative], axis=-1, name='merged_layer')

# Define optimizer (`learning_rate` replaces the deprecated `lr` keyword)
adam_optim = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999)

# Setup and compile model
model = keras.models.Model(inputs=[anchor_input,positive_input, negative_input], outputs=merged_vector)
model.compile(loss=loss.triplet_loss, optimizer=adam_optim)

In [10]:
# Print the layer-by-layer overview of the compiled three-input model.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
anchor_input (InputLayer)       [(None, 400, 300, 3) 0                                            
__________________________________________________________________________________________________
positive_input (InputLayer)     [(None, 400, 300, 3) 0                                            
__________________________________________________________________________________________________
negative_input (InputLayer)     [(None, 400, 300, 3) 0                                            
__________________________________________________________________________________________________
model (Model)                   (None, 64)           61224       anchor_input[0][0]               
                                                                 positive_input[0][0]       

In [11]:
# Train model.
# Model.fit accepts keras.utils.Sequence objects directly; fit_generator is
# deprecated. The returned History is kept so the loss curves can be inspected later.
history = model.fit(
    tripletgen_train,
    validation_data=tripletgen_test,
    epochs=3,
    verbose=1,
)

Epoch 1/3
  1/249 [..............................] - ETA: 1:58:03 - loss: 9.4297

KeyboardInterrupt: 