# Data Preparation

In [None]:
!pip install -q neural-structured-learning

In [None]:
import tensorflow as tf
import neural_structured_learning as nsl

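## Prepare lookup tables

Map label names and degree strings to integer ids (the 0-based line number of each entry in its table file).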

In [None]:
# Text files listing one degree string / one label name per line; an entry's line
# number becomes its integer id. The "TABEL" spelling is kept because the
# table-construction cell refers to these exact names.
TABEL_DEGREE_TXT, TABEL_LABEL_TXT = (
    f"/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/table_{kind}.txt"
    for kind in ("degree", "label")
)

In [None]:
N_DEGREE = 4
N_LABEL = 6


def _line_number_table(path, default_value):
    """Build a string -> int64 StaticHashTable from a text file.

    Each whole line of `path` is a key and its 0-based line number is the value.
    Keys not present in the file resolve to `default_value`.
    """
    initializer = tf.lookup.TextFileInitializer(
        path,
        key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
    )
    return tf.lookup.StaticHashTable(initializer, default_value)


# The class count is used as the "missing key" value, i.e. one past the last valid id.
# NOTE(review): tf.one_hot encodes such an out-of-range index as an all-zero vector
# rather than failing, so unknown degrees/labels would pass through silently.
table_degree = _line_number_table(TABEL_DEGREE_TXT, N_DEGREE)
table_label = _line_number_table(TABEL_LABEL_TXT, N_LABEL)

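## Create TFRecord files

Write one TFRecord file per label. Each example holds the raw PNG bytes, the label id and the degree id.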

In [None]:
def _bytes_feature(value):
    """Wrap a list of byte strings in a `tf.train.Feature` (bytes_list)."""
    payload = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=payload)

def _int64_feature(value):
    """Wrap a list of integers in a `tf.train.Feature` (int64_list)."""
    payload = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=payload)

def _float_feature(value):
    """Wrap a list of floats in a `tf.train.Feature` (float_list)."""
    payload = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=payload)

def parse_fn(path, label, degree):
    """Turn one SQL row into (file bytes, label id, degree id).

    Args:
        path: string tensor, path of the image file to read.
        label: string tensor, key looked up in `table_label`.
        degree: string tensor, key looked up in `table_degree`.

    Returns:
        Tuple of the raw file contents and the two int64 ids from the lookup tables.
    """
    label_id = table_label.lookup(label)
    degree_id = table_degree.lookup(degree)
    return tf.io.read_file(path), label_id, degree_id

def create_example(img_byte_str, label, degree):
    """Pack image bytes, degree id and label id into a `tf.train.Example`.

    Each argument is a list (e.g. `[value]`) because the feature helpers wrap lists.
    """
    encoders = (
        ("image", _bytes_feature, img_byte_str),
        ("degree", _int64_feature, degree),
        ("label", _int64_feature, label),
    )
    feature_map = {key: encode(value) for key, encode, value in encoders}
    proto_features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=proto_features)

In [None]:
DB_PATH = "/data/aoi-wzs-p3-dip-prewave-saiap/metadata.db"
tfrecord_path_base = "/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training"

# Class names; each one is written to its own `<label>.tfrecord` file.
labels = [
    'NG-InversePolarity',
    'NG-MoreComp',
    'NG-NoneComp',
    'NG-OutsidePosition',
    'NG-UpsideDown',
    'OK',
]

for label in labels:
    # SqlDataset takes no bind parameters, so the label is interpolated into the
    # query text. This is acceptable only because `labels` is a fixed, trusted list.
    dataset = tf.data.experimental.SqlDataset(
        "sqlite", DB_PATH,
        f"""select path, label, degree from metadata
        where
            label = '{label}' and
            degree >= 0 and
            component like '%Cap' and
            extension = 'png'
        """, (tf.string, tf.string, tf.string))
    # (path, label name, degree string) -> (file bytes, label id, degree id)
    dataset = dataset.map(parse_fn, tf.data.experimental.AUTOTUNE)

    tfrecord_path = f"{tfrecord_path_base}/{label}.tfrecord"
    n_written = 0
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        # Distinct names for the per-record tensors, so the outer `label`
        # (the class-name string) is not overwritten inside this loop.
        for img_bytes, label_id, degree_id in dataset:
            example = create_example([img_bytes.numpy()], [label_id.numpy()], [degree_id.numpy()])
            writer.write(example.SerializeToString())
            n_written += 1
    print(f"{label}: wrote {n_written} examples to {tfrecord_path}")

# Hyperparameters

In [None]:
class HParam:
    """Hyperparameters shared by the input pipeline and both models."""

    def __init__(self):
        # Input geometry: images are decoded with `channels` channels and padded/resized
        # to a `image_size` x `image_size` square.
        side, depth = 32, 3
        self.channels = depth
        self.image_size = side
        self.input_shape = (side, side, depth)

        # Input pipeline.
        self.shuffle_buffer, self.batch_size, self.valid_size = 10000, 1024, 3000

        # Training schedule.
        self.epochs, self.steps_per_epoch = 10, 100

        # Adversarial regularization (passed to nsl.configs.make_adv_reg_config).
        self.adv_multiplier = 0.2
        self.adv_step_size = 0.2


hparam = HParam()

## Prepare training and validation data

In [None]:
# Re-declared so this section can run without re-executing the data-preparation
# cells. These are the one-hot depths for the stored degree and label ids;
# presumably they must equal the line counts of table_degree.txt / table_label.txt.
N_DEGREE = 4
N_LABEL = 6

def parse_image(image):
    """Decode PNG bytes and letterbox them into a square `hparam.image_size` image.

    The decoded image has `hparam.channels` channels. `resize_with_pad` keeps the aspect
    ratio and pads with zeros. Pixel values are not rescaled here.
    """
    side = hparam.image_size
    decoded = tf.io.decode_png(image, channels=hparam.channels)
    return tf.image.resize_with_pad(decoded, side, side)

def parse_single_example(example_proto):
    """Parse one serialized Example into (features, one-hot label).

    `features` holds the decoded "image", the one-hot "degree" (depth N_DEGREE) and the
    one-hot "label" (depth N_LABEL). The same one-hot label tensor is also returned as
    the target.
    """
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "degree": tf.io.FixedLenFeature([], tf.int64),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    raw = tf.io.parse_single_example(example_proto, spec)

    one_hot_label = tf.one_hot(raw["label"], N_LABEL)
    features = {
        "image": parse_image(raw["image"]),
        "degree": tf.one_hot(raw["degree"], N_DEGREE),
        "label": one_hot_label,
    }
    return features, one_hot_label

def convert_for_adv(features, label):
    """Flatten (features, label) into the single dict AdversarialRegularization expects.

    The label and degree are cast to float32. The image is passed through unchanged.
    """
    to_float = lambda tensor: tf.cast(tensor, tf.float32)
    merged = {"label": to_float(label), "degree": to_float(features["degree"])}
    merged["image"] = features["image"]
    return merged

tfrecord_paths = [
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/NG-InversePolarity.tfrecord',
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/NG-MoreComp.tfrecord',
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/NG-NoneComp.tfrecord',
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/NG-OutsidePosition.tfrecord',
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/NG-UpsideDown.tfrecord',
    '/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/OK.tfrecord',
]

# Split each class file into a fixed validation head and a training tail *before*
# repeating, so the two sets are disjoint. Calling take()/skip() on the mixed stream
# does not achieve this: the mixture is re-sampled on every iteration, and each
# source repeats, so validation records would also be seen during training.
valid_per_class = hparam.valid_size // len(tfrecord_paths)
raw_datasets = [tf.data.TFRecordDataset(x) for x in tfrecord_paths]

# NOTE(review): a class file with <= valid_per_class records contributes nothing to
# training. Confirm every class has enough data.
datasets = [ds.skip(valid_per_class).repeat() for ds in raw_datasets]
dataset = tf.data.experimental.sample_from_datasets(datasets)  # infinite training mixture

valid_records = raw_datasets[0].take(valid_per_class)
for ds in raw_datasets[1:]:
    valid_records = valid_records.concatenate(ds.take(valid_per_class))

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = (
    dataset
    .shuffle(hparam.shuffle_buffer)
    .map(parse_single_example, AUTOTUNE)
    .batch(hparam.batch_size)
    .prefetch(AUTOTUNE)
)
# Validation is finite and fixed. Shuffling it would not change the metrics.
valid_ds = (
    valid_records
    .map(parse_single_example, AUTOTUNE)
    .batch(hparam.batch_size)
    .prefetch(AUTOTUNE)
)

# Base Model

In [None]:
import tensorflow.keras.layers as l

def BaseModel():
    """Build the baseline classifier as a functional Keras model.

    Inputs (dict):
        "image": shape `hparam.input_shape`.
        "degree": vector of length `N_DEGREE`.

    The image passes through a conv stack and global average pooling, is concatenated
    with "degree", and is projected to `N_LABEL` softmax probabilities. The softmax layer
    uses dtype="float32" (presumably to stay float32 under mixed precision).
    """
    image_in = tf.keras.Input(hparam.input_shape, name="image")
    degree_in = tf.keras.Input([N_DEGREE], name="degree")

    def conv_bn_relu(tensor, filters):
        # 3x3 "same" convolution -> batch norm -> ReLU.
        tensor = l.Conv2D(filters, 3, padding="same")(tensor)
        tensor = l.BatchNormalization()(tensor)
        return l.ReLU()(tensor)

    # Stem: 7x7 convolution + ReLU, without normalization.
    h = l.Conv2D(64, 7, padding="same")(image_in)
    h = l.ReLU()(h)
    # Three conv/BN/ReLU stages, each followed by max pooling.
    for width in (128, 256, 512):
        h = conv_bn_relu(h, width)
        h = l.MaxPool2D()(h)
    h = conv_bn_relu(h, 1024)
    pooled = l.GlobalAveragePooling2D()(h)

    fused = l.Concatenate()([pooled, degree_in])
    logits = l.Dense(N_LABEL)(fused)
    probs = l.Activation("softmax", dtype="float32")(logits)
    return tf.keras.Model(inputs={"image": image_in, "degree": degree_in}, outputs=probs)

In [None]:
# Per-class recall and precision, listed in class-id order. The ids come from the
# line order of table_label.txt; this list is assumed to follow it (TODO confirm).
#
# top_k=1 counts a class as predicted only when it is the arg-max of the softmax.
# Without it, Keras thresholds each class probability at 0.5. With six classes, a
# correct arg-max prediction below 0.5 would then not be counted.
_metric_class_names = [
    "NG-InversePolarity",
    "NG-MoreComp",
    "NG-NoneComp",
    "NG-OutsidePosition",
    "NG-UpsideDown",
    "OK",
]
metrics = [
    tf.keras.metrics.Recall(name=f"recall/{name}", class_id=class_id, top_k=1)
    for class_id, name in enumerate(_metric_class_names)
] + [
    tf.keras.metrics.Precision(name=f"precision/{name}", class_id=class_id, top_k=1)
    for class_id, name in enumerate(_metric_class_names)
]

In [None]:
logdir = "/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/base-model"
callbacks = [
    tf.keras.callbacks.TensorBoard(logdir, write_graph=False, profile_batch=0)
]

# Keras metric objects carry running state. Give this model its own copies so the
# module-level `metrics` list is not also shared with the adversarial model.
base_metrics = [type(m).from_config(m.get_config()) for m in metrics]

model = BaseModel()
model.compile("adam", "categorical_crossentropy", metrics=base_metrics)

# Progress goes to TensorBoard only (verbose=0). The History is kept for inspection.
base_history = model.fit(train_ds, validation_data=valid_ds,
                         epochs=hparam.epochs, steps_per_epoch=hparam.steps_per_epoch,
                         callbacks=callbacks,
                         verbose=0)

# Adversarial Model

In [None]:
# AdversarialRegularization expects one feature dict that also carries the label.
adv_train_ds, adv_valid_ds = (
    split.map(convert_for_adv, tf.data.experimental.AUTOTUNE)
    for split in (train_ds, valid_ds)
)

In [None]:
logdir = "/data/aoi-wzs-p3-dip-prewave-saiap/experiments/adversarial-training/adv-model"
callbacks = [
    tf.keras.callbacks.TensorBoard(logdir, write_graph=False, profile_batch=0)
]

# NOTE(review): parse_image does not rescale pixels, so images are on the decoded
# 0-255 scale, and adv_step_size is measured against that scale. Confirm the step is
# large enough to matter. "degree" is also a float input here and may be perturbed
# along with the image unless a feature mask is configured. Verify this is intended.
adv_config = nsl.configs.make_adv_reg_config(
    multiplier=hparam.adv_multiplier,
    adv_step_size=hparam.adv_step_size,
)
base_model = BaseModel()
model = nsl.keras.AdversarialRegularization(
    base_model, label_keys=["label"], adv_config=adv_config)

# Fresh metric objects, so running state is not shared with the base-model run.
adv_metrics = [type(m).from_config(m.get_config()) for m in metrics]
model.compile("adam", "categorical_crossentropy", metrics=adv_metrics)

# Progress goes to TensorBoard only (verbose=0). The History is kept for inspection.
adv_history = model.fit(adv_train_ds, validation_data=adv_valid_ds,
                        epochs=hparam.epochs, steps_per_epoch=hparam.steps_per_epoch,
                        callbacks=callbacks,
                        verbose=0)