# EC5320 Week 4a code: CNN regression (age prediction) - FOR TEACHING

2022.3.23.<br>

Author: Hyunjoo Yang (hyang@sogang.ac.kr)<br><br>

This notebook uses a convolutional neural network (CNN) to predict a person's age from a face image.<br><br>

Data source for face images: <br>
https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/ <br><br>
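The dataset consists of face images of public figures scraped from three sources, roughly 500,000 images in total. (공인의 데이터를 3개의 소스로부터 긁어왔다. 50만장)<br><br>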

Codes are based on: <br>
https://medium.com/analytics-vidhya/fastai-image-regression-age-prediction-based-on-image-68294d34f2ed <br><br>

For image augmentation, refer to: <br>
https://github.com/fastai/fastbook/blob/master/02_production.ipynb <br><br>

For AI and ethics, refer to: <br>
https://github.com/fastai/fastbook/blob/master/03_ethics.ipynb <br><br>

# 1. Install and import libraries

In [None]:
# upgrade fastai to the most recent version (v. 2.5.3)

%%capture
!pip install fastai --upgrade

In [None]:
# fastai's vision API. The star import is the fastai convention and provides
# DataBlock, ColReader, cnn_learner, PILImage, the callbacks, Path, etc. used below.
import fastai
from fastai.vision.all import *

from matplotlib.pyplot import imshow

# Report the installed fastai version
print(f"{fastai.__version__}")

In [None]:
import numpy as np
import pandas as pd
from matplotlib.pyplot import imshow
from google.colab import files

# 2. Download file

In [None]:
!wget -O imdb_crop_sample.zip 'https://www.dropbox.com/s/cwhhvl5trf4gtvh/imdb_crop_sample.zip?dl=0'

In [None]:
%%capture

!unzip imdb_crop_sample.zip -d faces

# 3. Prepare image file path + label dataframe

## 3.1 grab image file paths

In [None]:
# Grab image file paths.
# The paths are sorted because glob's order depends on the file system.
# Sorting makes row positions, and therefore the seeded train/valid split,
# reproducible across runs and machines.
import glob
import os

img_full_path = pd.Series(sorted(glob.glob('faces/*.jpg')), name='my_file_path', dtype=object)
# File name without its directory, e.g. 'faces/abc.jpg' -> 'abc.jpg'
img_nm = img_full_path.map(os.path.basename).rename('file_nm')

df_imdb_sample = pd.concat([img_full_path, img_nm], axis=1)
df_imdb_sample.head()

In [None]:
# The year the photo was taken is the 4 characters just before the ".jpg"
# extension (characters -8 to -4 of the file name).
photo_year = df_imdb_sample['file_nm'].str.slice(start=-8, stop=-4)
df_imdb_sample["photo_taken"] = photo_year
df_imdb_sample.head()

In [None]:
# First step toward the date of birth: split the file name on "_" and store the
# resulting list of fields in the date_of_birth column (refined in the next cells).
name_fields = df_imdb_sample['file_nm'].str.split(pat="_")
df_imdb_sample["date_of_birth"] = name_fields
df_imdb_sample.head()

In [None]:
# The date of birth is the third "_"-separated field of the file name (index 2).
# Deriving it straight from `file_nm` with the vectorized `.str` accessor:
#  - replaces the row-by-row chained assignment `df[col][i] = ...`, which raises
#    SettingWithCopyWarning and stops writing back under pandas Copy-on-Write;
#  - makes the cell safe to re-run (it no longer re-indexes its own output);
#  - yields NaN instead of IndexError for names with fewer than three fields.
df_imdb_sample["date_of_birth"] = df_imdb_sample['file_nm'].str.split("_").str[2]
df_imdb_sample.head()

In [None]:
# Keep only the birth year (first 4 characters) so the column is compatible
# with the training code.
birth_year = df_imdb_sample["date_of_birth"].str.slice(stop=4)
df_imdb_sample["date_of_birth"] = birth_year
df_imdb_sample.head()

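## 3.2 grab ground truth dataset (MATLAB data)

*Skipped in this version: the date of birth and the year the photo was taken are parsed directly from the file names above.*

## 3.3 Merge my image file df with ground truth data

*Not needed here, because the labels already come from the file names.*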

In [None]:
# Count missing values per column
df_imdb_sample.isna().sum()

In [None]:
# (rows, columns) before type conversion and outlier removal
(len(df_imdb_sample), len(df_imdb_sample.columns))

## 3.4 Calculate age

In [None]:
# A plain pd.to_numeric(...) on these columns fails, because at least one
# file name does not follow the expected pattern. Rather than leaving a cell
# that raises (which breaks Restart & Run All), coerce unparsable values to NaN
# and display the offending rows for inspection. This cell does not modify
# df_imdb_sample.
bad_photo_taken = pd.to_numeric(df_imdb_sample['photo_taken'], errors='coerce').isna()
bad_date_of_birth = pd.to_numeric(df_imdb_sample['date_of_birth'], errors='coerce').isna()
df_imdb_sample[bad_photo_taken | bad_date_of_birth]

In [None]:
# Drop the rows whose photo year or birth year is not numeric (malformed file
# names). Selecting by content rather than by a hard-coded row position stays
# correct regardless of the order glob returned the files in, and the cell is
# safe to re-run. `.copy()` lets later cells add columns without
# SettingWithCopyWarning.
photo_is_numeric = pd.to_numeric(df_imdb_sample['photo_taken'], errors='coerce').notna()
dob_is_numeric = pd.to_numeric(df_imdb_sample['date_of_birth'], errors='coerce').notna()
df_imdb_sample = df_imdb_sample[photo_is_numeric & dob_is_numeric].copy()

In [None]:
# With the malformed row removed, both year columns convert to numbers cleanly.
for year_col in ('photo_taken', 'date_of_birth'):
    df_imdb_sample[year_col] = pd.to_numeric(df_imdb_sample[year_col])

In [None]:
# Age when the photo was taken, approximated from years only (can be off by one).
df_imdb_sample['age'] = df_imdb_sample['photo_taken'].sub(df_imdb_sample['date_of_birth'])

# Keep plausible ages only:
#  - ages above 100 are dropped (some of these images are paintings);
#  - ages of 0 or below are dropped (photo "taken" before the person was born).
df_imdb_sample = df_imdb_sample.query('0 < age <= 100')

In [None]:
# Shape after the age outliers have been removed
n_rows_clean, n_cols_clean = df_imdb_sample.shape
n_rows_clean, n_cols_clean

In [None]:
# The 'age' column has been added. Preview the first rows (like the other
# preview cells) instead of dumping the whole frame.
df_imdb_sample.head()

In [None]:
# Distribution of ages after outlier removal. The figure is labelled so it
# stands on its own; the trailing ';' suppresses the text repr of ax.set(...).
ax = df_imdb_sample['age'].hist()
ax.set(title='Age distribution (IMDB sample)', xlabel='age (years)', ylabel='number of images');

# 9. Prepare data for CNN

In [None]:
# Data block / training settings
my_random_seed = 42   # seed for the train/valid split and the global RNGs
my_batch_size = 64    # batch size for the DataLoaders

# Seed the global random number generators (fastai's set_seed) so that
# training runs are repeatable, not just the train/valid split.
set_seed(my_random_seed, reproducible=True)

In [None]:
from fastai.vision.data import ImageDataLoaders

In [None]:
# Keep only what the DataBlock needs: the image path (x) and the age label (y).
model_columns = ['my_file_path', 'age']
df_imdb_simple = df_imdb_sample.loc[:, model_columns]
df_imdb_simple

In [None]:
# Regression DataBlock: image read from `my_file_path` -> numeric `age` target.
# Seeded 80/20 random train/valid split; every image is resized to 128 px.
dls = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader('my_file_path'),
    get_y=ColReader('age'),
    splitter=RandomSplitter(valid_pct=0.2, seed=my_random_seed),
    item_tfms=Resize(128)
).dataloaders(df_imdb_simple, bs=my_batch_size)  # honor the configured batch size

In [None]:
# Number of (training, validation) samples
tuple(len(split_ds) for split_ds in (dls.train_ds, dls.valid_ds))

In [None]:
# Show up to 16 training images with their age labels, in 2 rows
dls.show_batch(nrows=2, max_n=16)

# 10. Train CNN model

In [None]:
# Baseline model: ResNet-34 backbone fine-tuned for 7 epochs, reporting RMSE
# (in years, since the target is age in years) on the validation set.
learn = cnn_learner(dls, arch=resnet34, metrics=rmse)
learn.fine_tune(7)

In [None]:
# Show some images alongside their actual and predicted ages.
learn.show_results()

# 11. Test using your own image

In [None]:
from google.colab import files

In [None]:
# Upload an image from the local machine through Colab's file picker.
uploaded = files.upload()
for uploaded_name in uploaded:
  print('User uploaded file: {name}'.format(name=uploaded_name))

In [None]:
# Use the first uploaded file. If nothing was uploaded, fail with a clear
# message instead of a bare IndexError.
if not uploaded:
    raise ValueError('No file was uploaded; re-run the upload cell and choose an image.')
img_name = next(iter(uploaded))

In [None]:
# Load the uploaded file contents as a fastai image and display it
# (the prediction itself is made in the following cell).
raw_upload = uploaded[img_name]
img = PILImage.create(raw_upload)

img.show()

In [None]:
# Run the trained model on the uploaded image; learn.predict returns a tuple
# whose contents include the predicted age.
prediction = learn.predict(img)
prediction

# 12. Image augmentation

## 12.1 Random Resized Crop

In [None]:
# Same regression DataBlock, but each item gets a random resized crop
# (min_scale=0.7) to 128 px instead of a plain resize, as data augmentation.
dls_aug = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader('my_file_path'),
    get_y=ColReader('age'),
    splitter=RandomSplitter(valid_pct=0.2, seed=my_random_seed),
    item_tfms=RandomResizedCrop(128, min_scale=0.7)
).dataloaders(df_imdb_simple, bs=my_batch_size)  # honor the configured batch size

In [None]:
# unique=True repeats one training image to show different random crops of it
dls_aug.train.show_batch(unique=True, nrows=2, max_n=8)

In [None]:
# ResNet-34 regressor on the randomly cropped data, trained in mixed precision
# for 15 epochs. Other backbones to try: resnet18, 50, 101, 152.
learn = (
    cnn_learner(dls_aug, arch=resnet34, metrics=rmse)
    .to_fp16()
)
learn.fine_tune(15)

## 12.2 aug_transforms

In [None]:
# Regression DataBlock with a plain 128 px resize per item plus fastai's
# standard batch augmentations; mult=2 strengthens the augmentation magnitudes.
dls_aug2 = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader('my_file_path'),
    get_y=ColReader('age'),
    splitter=RandomSplitter(valid_pct=0.2, seed=my_random_seed),
    item_tfms=Resize(128),
    batch_tfms=aug_transforms(mult=2)
).dataloaders(df_imdb_simple, bs=my_batch_size)  # honor the configured batch size

In [None]:
# unique=True repeats one training image to show several augmented versions of it
dls_aug2.train.show_batch(unique=True, nrows=2, max_n=8)

In [None]:
# ResNet-34 regressor on the batch-augmented data, trained in mixed precision
# for 15 epochs. Other backbones to try: resnet18, 50, 101, 152.
learn = (
    cnn_learner(dls_aug2, arch=resnet34, metrics=rmse)
    .to_fp16()
)
learn.fine_tune(15)

# 13. Early stopping

In [None]:
# Validation loss and RMSE of the current learner (the augmented model above),
# as a reference point before training with early stopping.
learn.validate()

In [None]:
# Train a fresh learner for up to 50 epochs with two callbacks:
#  - EarlyStoppingCallback stops once valid_loss has not improved for 3 epochs;
#  - SaveModelCallback saves the weights whenever valid_loss improves (fastai
#    reloads the best weights at the end of training).
# Note: fine_tune uses a one-cycle schedule planned for all 50 epochs, so an
# early stop ends the run partway through that schedule.
learn = cnn_learner(dls_aug2, resnet34, metrics=rmse).to_fp16() # resnet 18, 34, 50, 101, 152
# Set the output path on the NEW learner. Setting it before the learner is
# (re)created is lost when `learn` is reassigned. The checkpoint is written
# under this path.
learn.path = Path('./')
learn.fine_tune(50, cbs=[EarlyStoppingCallback(monitor='valid_loss', patience=3),
                         SaveModelCallback(monitor='valid_loss')])

In [None]:
# Validation loss and RMSE of the early-stopped learner, for comparison with
# the pre-early-stopping result above.
learn.validate()