In [9]:
import sys

# Make the parent directory importable (presumably where the local `utils`
# and `loss` modules live -- confirm against the repo layout). The membership
# check keeps the cell idempotent: re-running it no longer appends a duplicate
# ".." entry to sys.path each time.
if ".." not in sys.path:
    sys.path.append("..")

In [16]:
# Import Libraries
from utils import rotate_preserve_size
from loss import angular_loss_mae
import glob
import os
import numpy as np
import cv2
import random

from tensorflow.keras.models import Model
from tensorflow.keras import layers as L
import tensorflow as tf
import os
import pandas as pd
from tensorflow.keras.applications import Xception
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from loguru import logger
from tensorflow.keras.utils import Sequence

In [18]:
# Define conv base
# Xception feature extractor with ImageNet weights and no classification head.
conv_base = Xception(
    include_top=False,
    weights="imagenet",
    input_shape=(299, 299, 3),
)
# Freeze every layer of the base so only layers added on top of it are trained.
for base_layer in conv_base.layers:
    base_layer.trainable = False

In [8]:
# Define model
# Single-output regression head stacked on the convolutional base: the base's
# feature map is flattened, passed through three Dense(relu) + BatchNorm
# stages of decreasing width, and reduced to one linear output.
img_input = L.Input(shape=(299, 299, 3))
x = conv_base(img_input)
x = L.Flatten()(x)
for units in (512, 256, 64):
    x = L.Dense(units, activation="relu")(x)
    x = L.BatchNormalization()(x)
y = L.Dense(1, activation="linear")(x)
model = Model(inputs=img_input, outputs=y)

model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 299, 299, 3)]     0         
_________________________________________________________________
xception (Functional)        (None, 10, 10, 2048)      20861480  
_________________________________________________________________
flatten_1 (Flatten)          (None, 204800)            0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               104858112 
_________________________________________________________________
batch_normalization_11 (Batc (None, 512)               2048      
_________________________________________________________________
dense_5 (Dense)              (None, 256)               131328    
_________________________________________________________________
batch_normalization_12 (Batc (None, 256)               1024  

In [25]:
# Batch Generator
class RotGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of randomly rotated images.

    Each time a batch is requested, every image in it is rotated by a freshly
    drawn integer angle in ``[0, 360)`` and that angle becomes the image's
    regression target.

    Parameters
    ----------
    image_dir : str
        Directory searched (non-recursively) for ``*.jpg`` files.
    batch_size : int
        Maximum number of images per batch; must be positive. The last batch
        is smaller when the file count is not a multiple of ``batch_size``.
    dim : int
        Side length; ``(dim, dim)`` is passed to ``rotate_preserve_size`` and
        used as the height/width of the batch array.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """

    def __init__(self, image_dir, batch_size, dim):
        # Initialise the base class; this is a no-op for plain-object bases
        # and avoids warnings on Keras versions whose base defines __init__.
        super().__init__()
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        # glob returns matches in arbitrary order; sorting makes the initial
        # (pre-shuffle) file order reproducible across runs and machines.
        self.files = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
        self.batch_size = batch_size
        self.dim = dim

    def __len__(self):
        """Return the number of batches per epoch: ceil(n_files / batch_size)."""
        return (len(self.files) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch number ``idx``.

        Returns
        -------
        tuple of numpy.ndarray
            ``X`` of shape ``(n, dim, dim, 3)`` holding the rotated images and
            ``y`` of shape ``(n,)`` holding the angle applied to each image,
            where ``n`` is the number of files in this batch.
        """
        batch_slice = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_files = self.files[batch_slice]

        X = np.zeros(shape=(len(batch_files), self.dim, self.dim, 3))
        y = np.zeros(shape=(len(batch_files), ))

        for i, f in enumerate(batch_files):
            # Uniform integer angle in [0, 360); randint avoids materialising
            # a 360-element array for every sample.
            angle = float(np.random.randint(0, 360))
            # rotate_preserve_size is given the file path, so it is expected
            # to load the image itself -- TODO confirm against utils.
            X[i] = rotate_preserve_size(f, angle, (self.dim, self.dim))
            y[i] = angle

        return X, y

    def on_epoch_end(self):
        """Shuffle the file list in place so the next epoch uses a new order."""
        random.shuffle(self.files)
            

In [31]:
# train
# Training configuration. The image directory can be overridden with the
# ROT_TRAIN_DIR environment variable so the notebook is not tied to a single
# machine; when it is unset the original location is used.
TRAIN_IMAGE_DIR = os.environ.get(
    "ROT_TRAIN_DIR",
    "/Users/pidahbus/Documents/PhD/CS776A/Project/data/test/subhajyoti/",
)
BATCH_SIZE = 4
IMAGE_SIZE = 299  # side length; matches the model's (299, 299, 3) input

model.compile(loss=angular_loss_mae, optimizer="adam")

train_gen = RotGenerator(TRAIN_IMAGE_DIR, BATCH_SIZE, IMAGE_SIZE)
# Fail early with a clear message instead of an obscure error from fit()
# when the directory is missing or contains no matching images.
if len(train_gen) == 0:
    raise FileNotFoundError(f"No *.jpg images found in {TRAIN_IMAGE_DIR!r}")
model.fit(train_gen)

 36/313 [==>...........................] - ETA: 7:52 - loss: 87.9421

KeyboardInterrupt: 