# Notebook for training the dual network

To keep the experimental settings consistent, we use:
- N=512 (256*2) for Normal
- N=300 (150*2) for Improved

## Importing libraries

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random
from numba import jit
from tqdm import tqdm
import os
import pickle
%matplotlib inline

Depending on the environment, you may need to specify which GPU to use.

In [2]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES']='0'

Keras-related imports

In [3]:
import keras
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [4]:
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input,Layer,Lambda
from keras.layers import Flatten,BatchNormalization
from keras.layers import Dense,Dropout
from keras.layers import concatenate
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras import backend as K

In [5]:
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.keras.backend.set_session(tf.Session(config=config))

In [6]:
# import tensorflow as tf
# from keras.backend.tensorflow_backend import set_session
# config = tf.ConfigProto(
#     gpu_options=tf.GPUOptions(
#         visible_device_list="0", # specify GPU number
#         allow_growth=True
#     )
# )
# set_session(tf.Session(config=config))

## Hyperparameter settings
- Set these carefully, because they determine the output directory and file names.
- For Improved Triplet Loss, `alpha > beta` must hold (per the original paper).

### Input image information

In [7]:
imheight = 128
imwidth = 128
channels = 3
category = 'Pants'

### Output dimension
- The dual network concatenates its two branches, so the output dimension is dense_num*2

In [8]:
dense_num = 150
vec_length=dense_num*2

### Triplet Loss settings
- `ALPHA = constrains the relative distance between D(a,p) and D(a,n)`, `BETA = controls the absolute distance D(a,p)`

In [9]:
ALPHA=0.1
BETA=0.05

## Loading VGG16

In [10]:
from keras.applications.vgg16 import VGG16
# include_top=False => the fully connected (Dense) head is not needed
base_model = VGG16(include_top=False, weights='imagenet', input_tensor=Input(shape=(imwidth, imheight, channels)), input_shape=None)

Instructions for updating:
Colocations handled automatically by placer.


Freeze the weights of the first 15 layers so that they are not trained

In [11]:
for layer in base_model.layers[:15]:
    layer.trainable=False
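
As a rough check of which layers the slice above freezes, the sketch below writes out the layer names of Keras VGG16 with `include_top=False` by hand, without loading the model. The input layer name `input_1` is taken from the `con_embNet.summary()` output in this notebook. The per-block conv counts are the standard VGG16 layout and are an assumption about the installed Keras version.

```python
# Layer names of VGG16(include_top=False) in order, written out by hand.
# Conv layers per block (2, 2, 3, 3, 3) follow the standard VGG16 layout.
convs_per_block = [2, 2, 3, 3, 3]
layer_names = ['input_1']
for block, n_conv in enumerate(convs_per_block, start=1):
    layer_names += ['block{}_conv{}'.format(block, i) for i in range(1, n_conv + 1)]
    layer_names.append('block{}_pool'.format(block))

frozen = layer_names[:15]      # same slice as base_model.layers[:15]
trainable = layer_names[15:]

print(len(layer_names))        # total number of layers
print(frozen[-1])              # last frozen layer
print(trainable)               # layers that remain trainable
```

Under this layout the slice freezes everything through `block4_pool` and leaves the `block5` layers trainable.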

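## Defining the network structure

- Create the shallow network (shallow_model)

The kernels were originally (32,(4,4)). Odd-sized filters are reportedly preferable, but the shallow model below still uses (4,4), while `create_embNet()` uses (3,3).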

In [12]:
inputs = Input(shape=(imwidth, imheight, channels))
conv1 = Conv2D(32, (4,4) , padding='same', activation='relu')(inputs)
pool1 = MaxPooling2D(pool_size=(2,2), strides=None, padding='valid')(conv1)
conv2 = Conv2D(32, (4,4) , padding='same', activation='relu')(pool1)
pool2 = MaxPooling2D(pool_size=(2,2), strides=None, padding='valid')(conv2)
flatten = Flatten()(pool2) 
dense_layer = Dense(dense_num, activation='relu')(flatten)
norm_layer = Lambda(lambda  x: K.l2_normalize(x, axis=1), name='norm_layer1')(dense_layer)
shallow_model=Model(inputs=inputs,outputs=norm_layer) 

- A function that combines shallow_model with the deep model built on VGG to create the full model

In [13]:
def create_embNet():
    shallow_inputs = Input(shape=(imwidth, imheight, channels))
    x = base_model.output
    conv1 = Conv2D(filters=32, kernel_size=(3,3) , padding='same', activation='relu')(x)
    conv2 = Conv2D(filters=32, kernel_size=(3,3) , padding='same', activation='relu')(conv1)
    flatten = Flatten()(conv2) 
    dense_layer = Dense(dense_num, activation='relu')(flatten)
    norm_layer = Lambda(lambda  x: K.l2_normalize(x, axis=1), name='norm_layer')(dense_layer)
    # Also prepare the shallow_model output for the input
    x1 = norm_layer
    x2 = shallow_model(shallow_inputs)
    out = concatenate([x1, x2])
    return Model(inputs=[base_model.input,shallow_inputs],outputs=out)
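
Each branch ends in an L2 normalization and the two branches are concatenated, so the combined embedding has norm sqrt(2) rather than 1. The numpy sketch below illustrates this with random vectors standing in for the two `dense_num`-dimensional branch outputs. The random inputs and the helper `l2_normalize` are assumptions for illustration, not part of the notebook.

```python
import numpy as np

def l2_normalize(x, axis=1, eps=1e-12):
    """Scale each row to unit L2 norm (a numpy stand-in for K.l2_normalize)."""
    sq_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq_sum, eps))

rng = np.random.RandomState(0)
dense_num = 150
deep = l2_normalize(rng.randn(4, dense_num))      # stand-in for the VGG branch
shallow = l2_normalize(rng.randn(4, dense_num))   # stand-in for shallow_model
emb = np.concatenate([deep, shallow], axis=1)     # shape (4, 300)

norms = np.linalg.norm(emb, axis=1)
print(emb.shape)
print(norms)                                      # each value is close to sqrt(2)

# Squared distance between two such embeddings is at most (2*sqrt(2))**2 = 8
d = np.sum(np.square(emb[0] - emb[1]))
print(d)
```

Because the squared distances D therefore lie in [0, 8], this gives a sense of scale when choosing `ALPHA` and `BETA`.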

- Define the inputs
- If the Inputs were defined inside `create_embNet()`, the three inputs could not be kept explicitly separate

In [14]:
# define three Inputs
a_in = Input(shape = (imheight, imwidth, channels), name='anchor_input')
p_in = Input(shape = (imheight, imwidth, channels), name='positive_input')
n_in = Input(shape = (imheight, imwidth, channels), name='negative_input')

### The purpose of this input is unclear

In [15]:
sa_in = Input(shape = (imheight, imwidth, channels), name='sanchor_input')

- Define the part up to vectorization separately, **because it is later used as the embedding model**

In [16]:
con_embNet = create_embNet()
shop_embNet = create_embNet()

In [17]:
shallow_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 128, 128, 32)      1568      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 32)        16416     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 32, 32, 32)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 32768)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 150)               4915350   
__________
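
The parameter counts in the `shallow_model.summary()` output can be reproduced by hand. The sketch below is plain arithmetic. The helper names `conv2d_params` and `dense_params` are introduced here for illustration only.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_units, out_units):
    """Weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units

conv1 = conv2d_params(4, 4, 3, 32)    # 4x4 kernel on 3 input channels
conv2 = conv2d_params(4, 4, 32, 32)   # 4x4 kernel on 32 input channels
flat = 32 * 32 * 32                   # 128 -> 64 -> 32 after two 2x2 poolings, 32 channels
dense = dense_params(flat, 150)       # dense_num = 150

print(conv1, conv2, flat, dense)      # 1568 16416 32768 4915350
```

Almost all of the shallow model's parameters sit in the Dense layer that follows `Flatten`.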

- Prepare the embedding vectors in advance as well

In [18]:
con_embNet.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 128, 128, 64) 1792        input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 128, 128, 64) 36928       block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_pool (MaxPooling2D)      (None, 64, 64, 64)   0           block1_conv2[0][0]               
__________________________________________________________________________________________________
block2_con

In [19]:
a_emb = shop_embNet([a_in,a_in])
p_emb = con_embNet([p_in,p_in])
n_emb = con_embNet([n_in,n_in])

## Triplet Loss
- Use this section when training with the standard Triplet Loss.
- `Loss=max[D(a,p)-D(a,n)+margin,0] where D(A,B)=||A-B||_2^2`

In [20]:
class TripletLossLayer(Layer):
    def __init__(self, alpha, **kwargs):
        self.alpha = alpha
        super(TripletLossLayer, self).__init__(**kwargs)
    
    def triplet_loss(self, inputs):
        a, p, n = inputs
        p_dist = K.sum(K.square(a-p), axis=-1)
        n_dist = K.sum(K.square(a-n), axis=-1)
        return K.sum(K.maximum(p_dist - n_dist + self.alpha, 0), axis=0)
    
    def call(self, inputs):
        loss = self.triplet_loss(inputs)
        self.add_loss(loss)
        return loss
    
    def get_config(self):
        config = {'alpha': self.alpha}
        base_config = super(TripletLossLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
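
To sanity-check the loss outside Keras, the same computation can be written in numpy. This is a minimal sketch that mirrors `triplet_loss` above. The function name `triplet_loss_np` and the toy vectors are made up for illustration.

```python
import numpy as np

def triplet_loss_np(a, p, n, alpha):
    """Sum over the batch of max(D(a,p) - D(a,n) + alpha, 0), with D = squared L2."""
    p_dist = np.sum(np.square(a - p), axis=-1)
    n_dist = np.sum(np.square(a - n), axis=-1)
    return np.sum(np.maximum(p_dist - n_dist + alpha, 0.0), axis=0)

# Row 0: positive coincides with the anchor, negative is far away -> contributes 0
# Row 1: negative coincides with the anchor, positive is far away -> contributes 2 + alpha
a = np.array([[1.0, 0.0], [0.0, 1.0]])
p = np.array([[1.0, 0.0], [1.0, 0.0]])
n = np.array([[0.0, 1.0], [0.0, 1.0]])

loss = triplet_loss_np(a, p, n, alpha=0.1)
print(loss)   # approximately 2.1
```

Note that the loss is summed over the batch rather than averaged, so its magnitude scales with `batch_size`.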

### Defining the loss layer and compiling the model

In [21]:
# Layer that computes the triplet loss from anchor, positive and negative embedding vectors
triplet_loss_layer = TripletLossLayer(alpha=ALPHA, name='triplet_loss_layer')([a_emb, p_emb, n_emb])

# Model that can be trained with anchor, positive negative images
tripletNet = Model([a_in, p_in, n_in], triplet_loss_layer)
tripletNet.compile(loss=None, optimizer='adam')

In [22]:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

# SVG(model_to_dot(con_embNet).create(prog='dot', format='svg'))

## When using Improved Triplet Loss

### Improved Triplet Loss
- `Loss=max[D(a,p)-D(a,n)+ALPHA,0]+max[D(a,p)-BETA,0]`
- The second term pushes the anchor-positive distance to become shorter

https://qiita.com/tancoro/items/35d0925de74f21bfff14#improved-triplet-loss

<img src="./readme_imgs/improved.PNG" width=30% align=left><br>

- Modify the layer for Improved Triplet Loss

In [23]:
class TripletLossLayer(Layer):
    def __init__(self, alpha, beta, **kwargs):
        self.alpha = alpha
        self.beta = beta
        super(TripletLossLayer, self).__init__(**kwargs)

    def triplet_loss(self, inputs):
        a, p, n = inputs
        p_dist = K.sum(K.square(a-p), axis=-1)
        n_dist = K.sum(K.square(a-n), axis=-1)
        pn_dist = K.sum(K.square(p-n), axis=-1)
        return K.sum(K.maximum((p_dist - n_dist + self.alpha), 0) + K.maximum((p_dist - self.beta), 0), axis=0)
    
    def call(self, inputs):
        loss = self.triplet_loss(inputs)
        self.add_loss(loss)
        return loss
    
    def get_config(self):
        config = {'alpha': self.alpha, 'beta': self.beta}
        base_config = super(TripletLossLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
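
The effect of the extra `beta` term can be seen on a toy triplet. The numpy sketch below mirrors the layer's `triplet_loss`. The function name `improved_triplet_loss_np` and the vectors are illustrative assumptions.

```python
import numpy as np

def improved_triplet_loss_np(a, p, n, alpha, beta):
    """Relative hinge term plus an absolute hinge term on D(a,p)."""
    p_dist = np.sum(np.square(a - p), axis=-1)
    n_dist = np.sum(np.square(a - n), axis=-1)
    relative = np.maximum(p_dist - n_dist + alpha, 0.0)   # same as the standard loss
    absolute = np.maximum(p_dist - beta, 0.0)             # penalizes D(a,p) > beta
    return np.sum(relative + absolute, axis=0)

# The negative is already far enough away, but the positive is farther than beta
a = np.array([[0.0, 0.0]])
p = np.array([[0.5, 0.0]])   # D(a,p) = 0.25
n = np.array([[2.0, 0.0]])   # D(a,n) = 4.0

loss = improved_triplet_loss_np(a, p, n, alpha=0.1, beta=0.05)
print(loss)   # approximately 0.2, all of it from the absolute term
```

For this triplet the standard Triplet Loss would be 0, so only the improved variant keeps pulling the positive toward the anchor.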

### Defining and compiling the model
- Note that BETA has been added as a hyperparameter

In [24]:
# Layer that computes the triplet loss from anchor, positive and negative embedding vectors
triplet_loss_layer = TripletLossLayer(alpha=ALPHA, beta=BETA, name='triplet_loss_layer')([a_emb, p_emb, n_emb])

# Model that can be trained with anchor, positive negative images
tripletNet = Model([a_in, p_in, n_in], triplet_loss_layer)
tripletNet.compile(loss=None, optimizer='adam')

## Preparing the data

- `T_Shirt_all/` (a directory in which cropped images are stored per product id)

In [15]:
BASE_PATH = './dataset/crop_img/img/TOPS/Coat/'
category = 'Coat'
ids = sorted([x for x in os.listdir(BASE_PATH)])

In [26]:
import shutil
def rmdir(ids,PATH):
    # Remove product directories that lack either consumer or shop images
    for id_ in ids:
        files = sorted([PATH+id_+'/'+x for x in os.listdir(PATH+id_)])
        # 'comsumer' is the spelling matched against the file names
        con = sorted([x for x in files if 'comsumer' in x])
        shop = sorted([x for x in files if 'shop' in x ])
        if len(con)==0 or len(shop)==0:
            shutil.rmtree(PATH+id_)

### Function that returns the triplet PATHs
- Input: `ids=list of product ids`, `BASE_PATH=PATH to the product directory`
- Returns entries of the form `[consumer_anc path, shop_pos path, shop_neg path]`

In [4]:
import itertools
import random

def get_triplets(ids,BASE_PATH):
    triplets=[]
    con_length = 0
    shop_length = 0
    for id_ in tqdm(ids):
        files = sorted([BASE_PATH+id_+'/'+x for x in os.listdir(BASE_PATH+id_)])
        con = sorted([x for x in files if 'comsumer' in x])
        shop = sorted([x for x in files if 'shop' in x ])
        combs = list(itertools.product(tuple(con),tuple(shop)))
        con_length += len(con)
        shop_length += len(shop)
        for comb in combs:
            comb = list(comb)
            neg_id = random.choice([x for x in ids if x != id_])
            neg_file = random.choice([BASE_PATH+neg_id+'/'+x for x in os.listdir(BASE_PATH+neg_id) if 'shop' in x])
            comb.append(neg_file)
            triplets.append(comb)
    print("con_length = {}".format(con_length))
    print("shop_length = {}".format(shop_length))
    return triplets
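
The pairing logic of `get_triplets` can be tried without touching the file system. The sketch below applies the same steps to an in-memory dictionary. The function `build_triplets`, the ids, and the file names are hypothetical. Only the `'comsumer'` and `'shop'` substring filters come from the notebook.

```python
import itertools
import random

def build_triplets(files_by_id, seed=0):
    """files_by_id maps a product id to its file names (toy stand-in for os.listdir)."""
    rng = random.Random(seed)
    ids = sorted(files_by_id)
    triplets = []
    for id_ in ids:
        con = sorted(f for f in files_by_id[id_] if 'comsumer' in f)
        shop = sorted(f for f in files_by_id[id_] if 'shop' in f)
        # Every (consumer, shop) pair of the same product is an (anchor, positive)
        for anc, pos in itertools.product(con, shop):
            # The negative is a random shop image of a different product
            neg_id = rng.choice([x for x in ids if x != id_])
            neg = rng.choice([f for f in files_by_id[neg_id] if 'shop' in f])
            triplets.append([anc, pos, neg])
    return triplets

toy = {
    'id_1': ['id_1/comsumer_01.jpg', 'id_1/comsumer_02.jpg', 'id_1/shop_01.jpg'],
    'id_2': ['id_2/comsumer_01.jpg', 'id_2/shop_01.jpg', 'id_2/shop_02.jpg'],
}
triplets = build_triplets(toy)
print(len(triplets))   # 2*1 + 1*2 = 4
```

The number of triplets is the sum over products of (consumer images) x (shop images). If the chosen negative product has no shop image, `rng.choice` raises `IndexError` on the empty list, so such products need to be removed beforehand.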

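### Split Train/Test by product id
- The seed (random_state) is fixed
- Repeated `random.choice()` calls sample with replacement (bootstrap sampling), so `train_test_split()` is used instead
- The `ids` object is not used afterwards, so it is deleted here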

In [16]:
from sklearn.model_selection import train_test_split
train_ids,test_ids=train_test_split(ids,test_size=0.33,random_state=0)
del ids
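
The idea of an id-level split without replacement can be sketched with the standard library alone. This is not how `train_test_split` is implemented. The function `split_ids`, the fake ids, and the ceil rounding are assumptions, chosen so that the counts can be compared with the `len(train_ids)` and `len(test_ids)` outputs in this notebook (1629 + 803 = 2432 ids).

```python
import math
import random

def split_ids(ids, test_size=0.33, seed=0):
    """Shuffle once, then cut: every id lands in exactly one of the two lists."""
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    n_test = int(math.ceil(test_size * len(ids)))
    return ids[n_test:], ids[:n_test]

all_ids = ['id_{:04d}'.format(i) for i in range(2432)]   # hypothetical ids
train_ids_demo, test_ids_demo = split_ids(all_ids)

print(len(train_ids_demo), len(test_ids_demo))           # 1629 803
print(set(train_ids_demo) & set(test_ids_demo))          # empty set: no overlap
```

Splitting by product id rather than by image keeps all images of one product on the same side of the split.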

In [17]:
triplets_train_PATHs = get_triplets(train_ids,BASE_PATH)

100%|██████████| 1629/1629 [00:10<00:00, 153.64it/s]

con_length = 9511
shop_length = 2347





In [18]:
print(len(triplets_train_PATHs))

17151


In [20]:
len(train_ids)

1629

In [19]:
len(test_ids)

803

In [30]:
length = 0
for id_ in train_ids:
    length += len(os.listdir(BASE_PATH+id_))
length

12270

In [31]:
# train_ids

- Save the Test data information with pickle so that it can be referenced at test time

In [32]:
# Use a context manager so the file is flushed and closed after writing
with open('./pickle/{}/test_ids.pickle'.format(category), 'wb') as f:
    pickle.dump(test_ids, f)
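
A minimal round-trip sketch of saving and reloading the id list follows. It writes to a temporary directory as a stand-in for `./pickle/<category>/`, and the ids are placeholders.

```python
import os
import pickle
import tempfile

test_ids_demo = ['id_0001', 'id_0002', 'id_0003']   # placeholder ids

out_dir = tempfile.mkdtemp()                         # stand-in for ./pickle/<category>/
path = os.path.join(out_dir, 'test_ids.pickle')

with open(path, 'wb') as f:                          # closed automatically on exit
    pickle.dump(test_ids_demo, f)

with open(path, 'rb') as f:                          # what the test notebook would do
    restored = pickle.load(f)

print(restored == test_ids_demo)
```

Opening the file in a `with` block ensures the data is flushed to disk before another process tries to read it.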

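### Function for building triplets
- The triplet combinations are re-drawn at random during training, so they cannot be saved once with pickle
- It is unclear whether re-drawing them during training is common practice
- **Pre-generating around 200 triplet sets and pickling them might make future runs easier.**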

In [33]:
from PIL import Image
def get_np_triplets(triplet_PATHs):
    triplets = []
    for triplet in tqdm(triplet_PATHs):
        anc_img = Image.open(triplet[0]).convert('RGB')
        pos_img = Image.open(triplet[1]).convert('RGB')
        neg_img = Image.open(triplet[2]).convert('RGB')

        anc_img = np.array(anc_img.resize((128,128)))/255. #resize to (128,128,3)
        pos_img = np.array(pos_img.resize((128,128)))/255.    
        neg_img = np.array(neg_img.resize((128,128)))/255.    

        tri = [anc_img,pos_img,neg_img]
        triplets.append(np.array(tri))

    triplets = np.array(triplets)
    return triplets
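
Because `get_np_triplets` returns one dense array, its memory footprint is worth estimating. The sketch below builds a single dummy triplet the same way (division by `255.` turns the image into float64) and scales by the 17151 triplets reported by `len(triplets_train_PATHs)` above. The decimal-GB conversion is only a rough guide.

```python
import numpy as np

# One dummy triplet built as in get_np_triplets: uint8 image / 255. -> float64
img = np.zeros((128, 128, 3), dtype=np.uint8) / 255.
tri = np.array([img, img, img])

bytes_per_triplet = tri.nbytes           # 3 * 128 * 128 * 3 values * 8 bytes
n_triplets = 17151                       # len(triplets_train_PATHs) printed above
total_gb = bytes_per_triplet * n_triplets / 1e9

print(tri.shape, tri.dtype)
print(bytes_per_triplet)                 # 1179648
print(round(total_gb, 1))                # roughly 20.2 GB as float64

# Casting to float32 halves the footprint
print(tri.astype(np.float32).nbytes)     # 589824
```

The returned array has shape `(N, 3, 128, 128, 3)`, which is why the training loop slices it as `triplets[:,0]`, `triplets[:,1]`, `triplets[:,2]`.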

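## Training

- We want to compute `N-top acc` on the test data at each epoch
- Train and test are split outside the epoch loop, so the test ids are always the same
- The negatives in `train_triplet` are re-drawn at random each time, which avoids bias -> **using pre-pickled triplets would make training more efficient**
- `model.fit()` is called with `epochs=1`
- Appending each epoch's metrics (loss/accuracy, etc.) to `model_history=[]` makes it possible to check their progression afterwards.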
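
Since `fit()` is called with `epochs=1`, each appended `hist.history` holds one value per metric. The sketch below flattens such a list into a loss curve. The numbers are placeholders rather than results from this notebook, and it assumes `loss` is the only recorded key.

```python
# Placeholder per-epoch histories, shaped like hist.history from fit(epochs=1)
model_history_demo = [{'loss': [12.0]}, {'loss': [9.5]}, {'loss': [7.25]}]

# One scalar per outer-loop epoch
loss_curve = [h['loss'][0] for h in model_history_demo]

# Index of the epoch with the lowest training loss
best_epoch = min(range(len(loss_curve)), key=loss_curve.__getitem__)

print(loss_curve)
print(best_epoch)
```

The same list comprehension applies to the `model_history` object pickled by the training loop, after loading it with `pickle.load`.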

In [34]:
# model_dir = './model/{}/improved_tripletloss/a{}b{}'.format(category,ALPHA,BETA)
# model_dir = './model/{}/Dual_normal/a{}'.format(category,ALPHA)
model_dir = './model/{}/Dual_improved/a{}b{}'.format(category,ALPHA,BETA)
os.listdir(model_dir)

['300', '.ipynb_checkpoints']

In [35]:
epochs = 100
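
The training loop in this section regenerates the triplets when `epoch % 5 == 0` and saves the embedding models when `(epoch+1) % 5 == 0`. The short sketch below simply enumerates which epochs satisfy each condition.

```python
epochs = 100

# Epochs at which new triplets are drawn and loaded into memory
regen_epochs = [e for e in range(epochs) if e % 5 == 0]

# Epochs after which shop_embNet / con_embNet are saved
save_epochs = [e for e in range(epochs) if (e + 1) % 5 == 0]

print(len(regen_epochs), regen_epochs[:3], regen_epochs[-1])   # 20 [0, 5, 10] 95
print(len(save_epochs), save_epochs[:3], save_epochs[-1])      # 20 [4, 9, 14] 99
```

Each saved checkpoint therefore corresponds to the end of a 5-epoch block trained on one fixed set of triplets.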

In [36]:
model_history = []
for epoch in range(epochs):
    print('epoch %s'% epoch)
    if epoch % 5 == 0:
        if epoch != 0: del triplets
        triplets_train_PATHs = get_triplets(train_ids,BASE_PATH)
        triplets = get_np_triplets(triplets_train_PATHs)
        del triplets_train_PATHs
    # fit
    hist = tripletNet.fit([triplets[:,0],triplets[:,1],triplets[:,2]], epochs=1, batch_size=50)
    model_history.append(hist.history)
    with open(model_dir+'/{}/history{}.txt'.format(vec_length,epoch),'wb') as f:
        pickle.dump(model_history, f)
    # Done with the triplets, so delete them (disabled: they are reused until the next regeneration)
##    del triplets
    if (epoch+1) % 5 == 0:
        shop_embNet.save(model_dir+'/{}/shop_e{}.h5'.format(vec_length,epoch))
        con_embNet.save(model_dir+'/{}/con_e{}.h5'.format(vec_length,epoch))
        
# Save the training history
with open(model_dir+'/{}/history.txt'.format(vec_length),'wb') as f:
    pickle.dump(model_history, f)

  2%|▏         | 34/1815 [00:00<00:05, 337.12it/s]

epoch 0


100%|██████████| 1815/1815 [00:07<00:00, 254.71it/s]
100%|██████████| 16583/16583 [01:31<00:00, 181.18it/s]


Instructions for updating:
Use tf.cast instead.
Epoch 1/1
epoch 1
Epoch 1/1
epoch 2
Epoch 1/1
epoch 3
Epoch 1/1
epoch 4
Epoch 1/1


(Output for epochs 5 to 99 omitted: it repeats the pattern above, with the two progress bars (1815 ids, 16583 triplets) reappearing every 5 epochs and `Epoch 1/1` printed for each epoch.)