In [None]:
import tensorflow as tf
import datetime
from sklearn import preprocessing
from tensorflow.keras.utils import to_categorical
from IPython.core.display import Image

from VCWA import Models, AttentionModels, Common, VideoDataGenerator

## Prepare Dataset

In [None]:
# Mount the cloud-storage bucket `gfr-master-data-bucket` at /home/jupyter/bucket/ with gcsfuse.
# If the mount point does not exist yet, uncomment the mkdir line on the first run:
# !mkdir /home/jupyter/bucket
!gcsfuse --implicit-dirs gfr-master-data-bucket /home/jupyter/bucket/

In [None]:
# Batch sizes used by the generators created further down.
# The large training batch is only suitable for single-frame batches.
train_batch_size, test_batch_size = 192, 4

### HMDB51

In [None]:
# Make HMDB51 (51 classes, split 1) the active dataset for the cells below.
datasetname = "hmdb51"
classes = 51
split_no = 1
path = "/home/jupyter/"

# Positional arguments are passed in the same order as before; the three
# directories are all located below `path`.
dataset = Common.get_dataset(
    f"{path}processed_datasets/hmdb51_vid25",
    f"{path}datasets/hmdb51_org_splits",
    f"{path}processed_datasets/hmdb51_optflowl10_npz25",
    split_no,
    datasetname,
)
dataset

### UCF-101

In [None]:
# Make UCF-101 (101 classes, split 1) the active dataset for the cells below.
# Running this cell replaces the `dataset`, `datasetname` and `classes` values
# set by the HMDB51 cell.
datasetname = "ucf101"
classes = 101
split_no = 1
path = "/home/jupyter/"

# Positional arguments are passed in the same order as before; the three
# directories are all located below `path`.
dataset = Common.get_dataset(
    f"{path}processed_datasets/ucf101_vid25",
    f"{path}datasets/ucfTrainTestlist",
    f"{path}processed_datasets/ucf101_optflowl10_npz25",
    split_no,
    datasetname,
)
dataset

## TODO: 2D-CNN

## TwoStream-Network

### Pre-Training individual Networks

#### Video Model

In [None]:
# Keyword arguments common to the RGB training and test generators.
video_gen_kwargs = dict(
    target_size=(224, 224),
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
    shape_format="images",
)

# Training generator: single-frame batches with rotation/shear/zoom/flip options.
video_train_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=train_batch_size,
    single_frame=True,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
    **video_gen_kwargs,
)

# Test generator: small batches, no augmentation options.
video_test_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=test_batch_size,
    **video_gen_kwargs,
)

In [None]:
# Option A: resume from a previously saved model
# video_model = tf.keras.models.load_model("models/twostream_25_L10/ResNet50v2/video")

# Option B: build a fresh model (alternative architectures kept for reference)
# video_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), classes=classes, weights=None)
# video_model = AttentionModels.create_ResidualAttention_MobileNetV2(input_shape=(224, 224, 3), classes=classes)
video_model = AttentionModels.create_CBAM_MobileNetV2(classes=classes, input_shape=(224, 224, 3))

# SGD with momentum; tracks top-1 and top-5 accuracy.
video_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=0.0001)
video_model.compile(
    optimizer=video_optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)],
)

#### OptFlow Model

In [None]:
# Build a dataset view for the optical-flow stream: drop the RGB "path" column
# and expose "optflow_path" under the name "path". `dataset` itself is untouched.
optflow_dataset = (
    dataset
    .drop(columns=["path"])
    .rename(columns={"optflow_path": "path"})
)

# Keyword arguments common to the optical-flow training and test generators.
optflow_gen_kwargs = dict(
    target_size=(224, 224),
    preprocessing_function=None,
    shape_format="images",
)

# Training generator: single-frame batches with rotation/shear/zoom/flip options.
optflow_train_gen = VideoDataGenerator.VideoDataGenerator(
    optflow_dataset,
    batch_size=train_batch_size,
    single_frame=True,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
    **optflow_gen_kwargs,
)

# Test generator: small batches, no augmentation options.
optflow_test_gen = VideoDataGenerator.VideoDataGenerator(
    optflow_dataset,
    batch_size=test_batch_size,
    **optflow_gen_kwargs,
)

In [None]:
# Option A: resume from a previously saved model
# optflow_model = tf.keras.models.load_model("models/twostream_25_L10/ResNet50v2/optflow")

# Option B: build a fresh model (alternative architectures kept for reference)
# optflow_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 20), classes=classes, weights=None)
# optflow_model = AttentionModels.create_ResidualAttention_MobileNetV2(input_shape=(224, 224, 20), classes=classes)
optflow_model = AttentionModels.create_CBAM_MobileNetV2(classes=classes, input_shape=(224, 224, 20))

# SGD with momentum; tracks top-1 and top-5 accuracy.
optflow_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=0.0001)
optflow_model.compile(
    optimizer=optflow_optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)],
)

#### TwoStream Model

In [None]:
# Keyword arguments common to both two-stream generators (optflow=True enabled).
twostream_gen_kwargs = dict(
    target_size=(224, 224),
    optflow=True,
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
)

# Training generator: single-frame batches with rotation/shear/zoom/flip options.
twostream_train_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=train_batch_size,
    single_frame=True,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
    **twostream_gen_kwargs,
)

# Test generator: small batches, no augmentation options.
twostream_test_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=test_batch_size,
    **twostream_gen_kwargs,
)

## Training

### Combined Training

In [None]:
# Train both stream models together via the project helper; TensorBoard logs
# are grouped under the video model's name.
combined_log_basedir = f"logs/fit_twostream_25_L10/{video_model.name}/"

Models.train_optflow_model(
    video_model, optflow_model,
    video_train_gen, video_test_gen,
    optflow_train_gen, optflow_test_gen,
    twostream_test_gen,
    iterations=10,
    classes=classes,
    log_basedir=combined_log_basedir,
    model_basedir="models/twostream_25_L10/",
)

### Individual Training

#### Video Model training

In [None]:
# One TensorBoard run directory per execution, keyed by model name and timestamp.
video_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
vid_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/fit_twostream_25_L10/video/{video_model.name}_{video_run_stamp}",
    histogram_freq=1,
)

# Train the RGB stream for one epoch, validating on the test generator.
video_model.fit(
    video_train_gen,
    validation_data=video_test_gen,
    epochs=1,
    callbacks=[vid_tensorboard_callback],
)

In [None]:
# Persist the RGB stream model under its own name.
video_model.save(f"models/twostream_25_L10/video/{video_model.name}")

#### OptFlow Model training

In [None]:
# One TensorBoard run directory per execution, keyed by model name and timestamp.
optflow_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
optflow_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/fit_twostream_25_L10/optflow/{optflow_model.name}_{optflow_run_stamp}",
    histogram_freq=1,
)

# Train the optical-flow stream for one epoch, validating on the test generator.
optflow_model.fit(
    optflow_train_gen,
    validation_data=optflow_test_gen,
    epochs=1,
    callbacks=[optflow_tensorboard_callback],
)

In [None]:
# Persist the optical-flow stream model under its own name.
optflow_model.save(f"models/twostream_25_L10/optflow/{optflow_model.name}")

#### TwoStream Model training

In [None]:
# Combine the RGB and optical-flow models into one two-stream model using
# average fusion and a recreated top.
# The class count comes from `classes` (set in the dataset cell) instead of a
# hard-coded 51, so the model also matches UCF-101 (101 classes).
# NOTE(review): the training cell calls twostream.fit() without a visible
# compile() — confirm that assemble_TwoStreamModel returns a compiled model.
twostream = Models.assemble_TwoStreamModel(video_model, optflow_model, classes, fusion="average", recreate_top=True)

In [None]:
# One TensorBoard run directory per execution, keyed by model name and timestamp.
# The "/" after "twostream" was missing, which put runs next to (not inside)
# the twostream folder, unlike the video/ and optflow/ log layouts.
twostream_tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir="logs/fit_twostream_25_L10/twostream/" + twostream.name + "_" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), 
    histogram_freq=1)

# Train the fused model for one epoch, validating on the two-stream test generator.
twostream.fit(
    twostream_train_gen,
    epochs=1,
    validation_data=twostream_test_gen,
    callbacks=[twostream_tensorboard_callback]
)

In [None]:
# Persist the fused model. The previous line referenced `twostream_resnet50v2`,
# which is never defined in this notebook (NameError on a fresh kernel), and
# wrote to a "twostream_25_L1/ResNet50v2" path that matched neither the model
# nor the other save cells. Save layout now mirrors the video/optflow cells.
twostream.save("models/twostream_25_L10/twostream/" + twostream.name)

## Display attention

In [None]:
# Take batch index 4 from the two-stream test generator and split its inputs
# into the RGB part and the optical-flow part.
x, y = twostream_test_gen[4]
x_video, x_optflow = x[0], x[1]

In [None]:
# Compute the attention visualisation for the first video of the batch, write
# it out as a GIF and display it.
# Uses `twostream` (the model assembled above); the former name
# `twostream_resnet50v2` is not defined anywhere in this notebook.
attention = Models.get_twostream_attention(x_video[0], twostream)
Models.video_to_gif(attention, "./attention.gif")

Image(filename="./attention.gif")

In [None]:
# Compute a Grad-CAM visualisation for the first video of the batch, write it
# out as a GIF and display it.
# Uses `twostream` (the model assembled above); the former name
# `twostream_resnet50v2` is not defined anywhere in this notebook.
# NOTE(review): "conv5_block3_3_conv" looks like a ResNet50V2 layer name —
# confirm that a layer with this name exists in the assembled model, or
# substitute the last convolutional layer of the backbone actually used.
gradcam_attention = Models.get_twostream_gradcam(x_video[0], twostream, "conv5_block3_3_conv")
Models.video_to_gif(gradcam_attention, "./gradcam_attention.gif")

Image(filename="./gradcam_attention.gif")

## LSTM

In [None]:
# The LSTM generators use a batch size of 1. Note that this rebinds the
# notebook-wide batch-size variables set near the top.
train_batch_size = test_batch_size = 1

# Keyword arguments common to both LSTM generators.
lstm_gen_kwargs = dict(
    target_size=(224, 224),
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
)

# Training generator with rotation/shear/zoom/flip options.
lstm_train_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=train_batch_size,
    rotation_range=20.0,
    shear_range=20.0,
    zoom_range=0.2,
    horizontal_flip=True,
    **lstm_gen_kwargs,
)

# Test generator without augmentation options.
lstm_test_gen = VideoDataGenerator.VideoDataGenerator(
    dataset,
    batch_size=test_batch_size,
    **lstm_gen_kwargs,
)

### Backbone

In [None]:
# Backbone network handed to Models.assemble_lstm in the next cell.
# NOTE(review): called with default arguments — per the Keras docs this
# presumably builds the ImageNet-pretrained model including its classifier
# top; confirm that this is what assemble_lstm expects.
backbone = tf.keras.applications.ResNet50V2()

### LSTM

In [None]:
# Build the LSTM classifier on top of the backbone. The class count comes from
# `classes` (set in the dataset cell) instead of a hard-coded 51, so the model
# also matches UCF-101 (101 classes).
lstm = Models.assemble_lstm(backbone, classes=classes)#, recreate_top=True)

lstm.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.SGD(learning_rate=25 * 10**-5, momentum=0.9, decay=0.0005),
    # "accuracy" lets Keras pick categorical accuracy for this loss.
    # tf.keras.metrics.Accuracy() checks predictions and labels for exact
    # equality, which is not meaningful for probability outputs compared with
    # one-hot targets. This also matches the metrics of the other models.
    metrics=["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(5)])

In [None]:
# One TensorBoard run directory per execution, grouped by dataset and keyed by
# model name and timestamp.
lstm_run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
lstm_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"logs/lstm/{datasetname}/{lstm.name}_{lstm_run_stamp}",
    histogram_freq=1,
)

# Train for five epochs; validation during training is currently disabled.
lstm.fit(
    lstm_train_gen,
    #validation_data=lstm_test_gen,
    callbacks=[lstm_callback],
    epochs=5,
)

In [None]:
# Evaluate the trained LSTM on the test generator; the cell displays the
# returned loss/metric values (validation is disabled during fit above).
lstm.evaluate(lstm_test_gen)

In [None]:
# Persist the LSTM model in a per-dataset folder, under its own name.
lstm.save(f"models/{datasetname}/{lstm.name}")

## TODO: 3D-CNN

## TODO: (2+1)D-CNN