In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os

In [2]:
tf.__version__

'2.0.0'

In [3]:
data_dir = 'PokemonData/'
train_dir = data_dir + 'train/'
test_dir = data_dir + 'test/'
validation_dir = data_dir + 'validation/'

In [4]:
train_data = os.listdir(train_dir)
test_data = os.listdir(test_dir)
validation_data = os.listdir(validation_dir)

In [5]:
# train_data.remove('.DS_Store')
validation_data

['Zapdos',
 'Kadabra',
 'Omanyte',
 'Shellder',
 'Bellsprout',
 'Eevee',
 'Jolteon',
 'Hypno',
 'Seel',
 'Zubat',
 'Graveler',
 'Magneton',
 'Abra',
 'Kingler',
 'Alakazam',
 'Clefable',
 'Gyarados',
 'Poliwag',
 'Rapidash',
 'Machamp',
 'Pinsir',
 'Muk',
 'Seaking',
 'Magikarp',
 'Goldeen',
 'Venusaur',
 'Flareon',
 'Mr.Mime',
 'Jigglypuff',
 'Doduo',
 'Weedle',
 'Vileplume',
 'Arcanine',
 'Tentacruel',
 'Gloom',
 'Charmeleon',
 'Articuno',
 'Sandshrew',
 'Spearow',
 'Marowak',
 'Clefairy',
 'Snorlax',
 'Scyther',
 'Primeape',
 'Diglett',
 'Onix',
 'Mankey',
 'Rattata',
 'Voltorb',
 'Gengar',
 'Nidoran-f',
 'Gastly',
 'Cloyster',
 'Weepinbell',
 'Dragonair',
 'Squirtle',
 'Pikachu',
 'Victreebel',
 'Charmander',
 'Staryu',
 'Venonat',
 'Vaporeon',
 'Ivysaur',
 'Krabby',
 'Drowzee',
 'Sandslash',
 'Kangaskhan',
 'Chansey',
 'Butterfree',
 'Starmie',
 'Magmar',
 'Beedrill',
 'Ninetales',
 'Magnemite',
 'Metapod',
 'Electrode',
 'Raichu',
 'Fearow',
 'Mewtwo',
 'Kabuto',
 'Pidgeotto',
 '

In [15]:
n_classes = len(train_data)

In [7]:
def get_data_cnt(data, data_dir):
    cnt = 0

    for pokemon in data:
        cnt += len(os.listdir(data_dir + pokemon))

    return cnt

In [8]:
print('train 이미지 개수', get_data_cnt(train_data, train_dir))
print('test 이미지 개수', get_data_cnt(test_data, test_dir))
print('validation 이미지 개수', get_data_cnt(validation_data, validation_dir))
print('포켓몬 종류', len(train_data))

train 이미지 개수 9691
test 이미지 개수 1127
validation 이미지 개수 1127
포켓몬 종류 151


In [9]:
train_gen = ImageDataGenerator(rescale=1./255)
val_gen = ImageDataGenerator(rescale=1./255)

train_generator = train_gen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=64,
    class_mode='categorical')
val_generator = val_gen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical')

Found 9691 images belonging to 151 classes.
Found 1127 images belonging to 151 classes.


In [13]:
resnet = tf.keras.applications.ResNet50(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3))

In [14]:
resnet.summary()

Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 
___________________________________________________________________________________________

In [16]:
model = tf.keras.Sequential()
model.add(resnet)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(n_classes, activation='softmax'))
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50 (Model)             (None, 7, 7, 2048)        23587712  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 151)               309399    
Total params: 23,897,111
Trainable params: 23,843,991
Non-trainable params: 53,120
_________________________________________________________________


In [18]:
model.layers[0].trainable = False

In [19]:
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

In [None]:
history = model.fit(
    train_generator,
    steps_per_epoch=100,  
    epochs=30,
    validation_data=val_generator,
    validation_steps=50
)

2021-12-20 22:53:01.528556: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:143] Filling up shuffle buffer (this may take a while): 25 of 152
2021-12-20 22:53:11.496646: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:143] Filling up shuffle buffer (this may take a while): 49 of 152
2021-12-20 22:53:21.445606: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:143] Filling up shuffle buffer (this may take a while): 73 of 152
2021-12-20 22:53:31.318220: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:143] Filling up shuffle buffer (this may take a while): 97 of 152
2021-12-20 22:53:41.646795: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:143] Filling up shuffle buffer (this may take a while): 118 of 152
