In [None]:
# Goal for 4/26: implement training with a custom training loop or the Keras API (verify that a Keras Applications model works correctly under the Subclassing API)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import time

In [None]:
# Seed NumPy's global RNG up front.
random_seed = 2021
np.random.seed(random_seed)

# Optimiser step size and triplet-loss margin.
learning_rate, alpha = 1e-5, 0.2

# Loop sizes: triplets per batch, batches per epoch, number of epochs.
batch, num_batches_per_step, epochs = 32, 100, 10

# Directory handed to the checkpoint manager.
ckpt_dir = "./train/"

# Total number of images drawn over the whole run (three per triplet).
print(epochs * num_batches_per_step * batch * 3)

In [None]:
# Load the training metadata table, then collect the distinct label groups.
train_df = pd.read_csv("../input/shopee-product-matching/train.csv")
group_names = train_df["label_group"].unique()

In [None]:
def _load_triplet_image(path, size=(224, 224)):
    """Read one image file as a float32 RGB array of shape (H, W, 3) in [0, 1]."""
    # The context manager closes the underlying file once the pixels are copied out.
    with Image.open(path) as img:
        # Force three channels: grayscale, palette, RGBA or CMYK files would
        # otherwise not fit the (224, 224, 3) slots of the batch buffer.
        rgb = img.convert("RGB").resize(size)
        return np.asarray(rgb, dtype=np.float32) / 255.0


def triplet_gen(df, group_names, dir_prefix, batch=1):
    """Endlessly yield batches of (anchor, positive, negative) image triplets.

    Each triplet takes its anchor and positive from one label group and its
    negative from a different group. Rows of the yielded array are laid out as
    anchor, positive, negative, anchor, positive, negative, ...

    Args:
        df: DataFrame with ``label_group`` and ``image`` columns.
        group_names: Label groups to sample from; groups absent from ``df``
            are ignored.
        dir_prefix: Format string with one ``{}`` placeholder that turns an
            image name into a file path.
        batch: Number of triplets per yielded batch.

    Yields:
        float32 array of shape ``(batch*3, 224, 224, 3)`` with values in [0, 1].

    Raises:
        ValueError: On the first ``next()`` if fewer than two groups are
            available, or no group has the two images an anchor/positive
            pair needs.

    NOTE(review): pixels are scaled by 1/255, while Keras documents
    ``resnet50.preprocess_input`` as the preprocessing for its ImageNet
    ResNet50 weights. Confirm the mismatch is intended.
    """
    # Build the group -> image-name lookup once instead of re-filtering the
    # whole DataFrame twice for every triplet.
    selected = df[df.label_group.isin(group_names)]
    images_by_group = {
        group: names.to_numpy()
        for group, names in selected.groupby("label_group")["image"]
    }
    groups = list(images_by_group)
    # Only groups with at least two images can supply an anchor/positive pair.
    anchor_slots = [
        slot for slot, group in enumerate(groups)
        if len(images_by_group[group]) >= 2
    ]
    if len(groups) < 2 or not anchor_slots:
        raise ValueError(
            "triplet_gen needs at least two label groups, one of which has "
            "two or more images"
        )

    while True:
        data = np.empty((batch*3, 224, 224, 3), np.float32)
        for triplet in range(batch):
            # Positive group: uniform over the groups that can form a pair.
            p_slot = anchor_slots[np.random.randint(len(anchor_slots))]
            # Negative group: uniform over every other group. Draw from one
            # fewer slot and step over the positive group's position.
            n_slot = np.random.randint(len(groups) - 1)
            if n_slot >= p_slot:
                n_slot += 1

            p_images = images_by_group[groups[p_slot]]
            n_images = images_by_group[groups[n_slot]]

            anchor, positive = np.random.choice(p_images, size=2, replace=False)
            negative = np.random.choice(n_images)

            for offset, name in enumerate((anchor, positive, negative)):
                data[3*triplet + offset] = _load_triplet_image(dir_prefix.format(name))
        yield data

In [None]:
# Subclassing API
class MetricLearningModel(tf.keras.Model):
    """ResNet50-based embedding network with a triplet loss.

    The loss expects embeddings laid out as consecutive
    (anchor, positive, negative) triplets along the batch axis: rows
    0, 3, 6, ... are anchors, rows 1, 4, 7, ... positives and rows
    2, 5, 8, ... negatives.

    Args:
        alpha: Margin of the triplet loss.
        batch: Number of triplets per batch; only used to precompute the
            ``*_idx`` index tensors.
    """

    def __init__(self, alpha=0.2, batch=1):
        super().__init__()
        # ImageNet-pretrained backbone without its classification head.
        self.resnet50 = ResNet50(include_top=False, weights='imagenet', input_shape=None)
        self.flatten = layers.Flatten()
        # Projects the flattened backbone features to a 128-d embedding.
        self.dense = layers.Dense(128)
        # Row indices of each triplet role for a batch of `batch` triplets.
        # Kept as attributes for backward compatibility; loss() itself uses
        # strided slicing so it is not tied to this batch size.
        self.anchors_idx = tf.range(start=0, limit=batch*3, delta=3)
        self.positives_idx = tf.range(start=1, limit=batch*3+1, delta=3)
        self.negatives_idx = tf.range(start=2, limit=batch*3+2, delta=3)
        self.margin = alpha

    def call(self, x, training=None):
        """Embed a batch of images.

        Args:
            x: Image batch fed to the ResNet50 backbone.
            training: Optional Keras training flag, forwarded to the backbone.

        Returns:
            Tensor of 128-dimensional embeddings, one row per input image.
        """
        features = self.resnet50(x, training=training)
        return self.dense(self.flatten(features))

    @staticmethod
    def _l2_distance(diff, eps=1e-12):
        """Row-wise Euclidean norm with a finite gradient at zero.

        ``tf.norm`` takes a square root whose gradient is undefined at 0, so
        two identical embeddings would produce NaN gradients. Adding a tiny
        constant under the root keeps the gradient finite.
        """
        return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=1) + eps)

    def loss(self, inputs):
        """Mean triplet loss over a batch of interleaved embeddings.

        Args:
            inputs: Embeddings of shape ``(3 * n_triplets, dim)`` in
                anchor / positive / negative order.

        Returns:
            Scalar tensor: mean of ``max(d(a, p) - d(a, n) + margin, 0)``.
        """
        # Every third row starting at 0 / 1 / 2 selects the three roles.
        anchors = inputs[0::3]
        positives = inputs[1::3]
        negatives = inputs[2::3]
        p_dists = self._l2_distance(anchors - positives)
        n_dists = self._l2_distance(anchors - negatives)
        loss_matrix = tf.maximum(p_dists - n_dists + self.margin, 0.0)
        return tf.reduce_mean(loss_matrix)
        
        
# # Functional API
# def create_model():
#     image_tensor = layers.Input((224,224,3))
#     model = ResNet50(include_top=False, weights='imagenet', input_tensor=image_tensor)
#     x = model.layers[-1].output
#     flatten = layers.Flatten()(x)
#     vector = layers.Dense(512)(flatten)
#     return tf.keras.Model(image_tensor, vector)

In [None]:
@tf.function
def train_step(x):
    """Run one optimisation step on a batch of triplets.

    Uses the module-level ``model`` and ``optimizer``.

    Args:
        x: Image batch in anchor / positive / negative row order.

    Returns:
        The scalar loss tensor for this batch.
    """
    with tf.GradientTape() as tape:
        # Invoke the model through __call__ (not .call) so Keras runs its
        # regular call machinery, and mark the pass as a training pass.
        vectors = model(x, training=True)
        loss_value = model.loss(vectors)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value

def train(model, data_gen, epochs, log_every=50, save_every=10):
    """Run the custom training loop, resuming from the latest checkpoint if any.

    Uses the module-level ``ckpt``, ``manager``, ``train_step``,
    ``num_batches_per_step`` and ``batch``.

    Args:
        model: Accepted for interface compatibility. The forward/backward
            pass is done by ``train_step``, which reads the module-level
            ``model``.
        data_gen: Iterator yielding one batch of triplet images per ``next()``.
        epochs: Number of epochs; each runs ``num_batches_per_step`` batches.
        log_every: Print the loss every this many steps of the current run.
        save_every: Save a checkpoint whenever ``ckpt.step`` is a multiple
            of this value.
    """
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    for epoch in range(1, epochs+1):
        print("Start of epoch {}".format(epoch))
        start_time = time.perf_counter()

        for step_in_epoch in range(1, num_batches_per_step+1):
            # Step count within this call (restarts at 1 even after a restore).
            total_step = num_batches_per_step*(epoch-1) + step_in_epoch
            loss_value = train_step(next(data_gen))
            ckpt.step.assign_add(1)

            if total_step % log_every == 0:
                print("Training loss (for one batch) at step {}: {:.5f}".format(total_step, float(loss_value)))
                # total_step batches of 3*batch images have been consumed.
                print("Seen so far: {} samples".format(total_step*3*batch))

            if int(ckpt.step) % save_every == 0:
                save_path = manager.save()
                print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))

        print("Time taken: {:.2f}s".format(time.perf_counter()-start_time))

In [None]:
# Build the model and print its summary.
model = MetricLearningModel(alpha=alpha, batch=batch)
model.build(input_shape=(batch*3, 224, 224, 3))
# Freeze the pretrained backbone so its parameters are not updated.
# Referenced by attribute name rather than by position in model.layers.
model.resnet50.trainable = False
model.summary()
# Triplet generator, optimizer and checkpointing.
gen = triplet_gen(train_df, group_names, '../input/shopee-product-matching/train_images/{}', batch)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Track the optimizer as well, so its state is saved and restored together
# with the model weights when training resumes from a checkpoint.
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=5)

In [None]:
# Run training.
train(model, gen, epochs)