In [1]:
import sys
# Add the parent directory to the module search path so the project-local
# modules imported further down (utils, loss, generator) can be resolved.
# NOTE(review): presumably those modules live one level above this notebook - confirm.
sys.path.append("..")

In [2]:
import os
# Restrict CUDA to the device with index 1. This is set before TensorFlow is
# imported in a later cell, so the restriction is in place when the GPU is
# first initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [23]:
# Import Libraries
from transformers import TFAutoModel, ViTForImageClassification
from utils import rotate_preserve_size
from loss import angular_loss_mae
import glob
import os
import numpy as np
import cv2
import random

from tensorflow.keras.models import Model
from tensorflow.keras import layers as L
import tensorflow as tf
import os
import pandas as pd
from tensorflow.keras.applications import Xception, EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from loguru import logger
from tensorflow.keras.utils import Sequence
from tensorflow.keras.optimizers import Adadelta
from generator import RotGenerator, ValidationTestGenerator

In [4]:
# Parameters
# Side length, in pixels, of the square images produced by the generators and
# declared as the model input; it matches the "224" in the
# "google/vit-base-patch16-224" checkpoint name used by this notebook.
IMAGE_SIZE = 224

In [61]:
# Load the pretrained ViT backbone as a TensorFlow model.
# Per the load-time messages, the checkpoint's 'classifier' layer is not used
# and the pooler dense kernel/bias are newly initialised, so they only become
# meaningful once the model is trained on the downstream task.
vit_base = TFAutoModel.from_pretrained("google/vit-base-patch16-224")
# Alternative loader kept for reference (classification-head variant):
# vit_base = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

Some layers from the model checkpoint at google/vit-base-patch16-224 were not used when initializing TFViTModel: ['classifier']
- This IS expected if you are initializing TFViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit/pooler/dense/bias:0', 'vit/pooler/dense/kernel:0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [62]:
# Define model
# Regression network: the ViT backbone followed by one linear unit, giving a
# single scalar per image (trained later against the rotation angle).
# The input is declared channels-first: (3, IMAGE_SIZE, IMAGE_SIZE).
img_input = L.Input(shape=(3,IMAGE_SIZE, IMAGE_SIZE))
x = vit_base(img_input)
# x[-1] selects the last element of the backbone's output structure. The Dense
# layer's 769 parameters in the summary (768 weights + 1 bias) imply that this
# is a 768-dim vector per image - presumably the pooled output (TODO confirm).
y = L.Dense(1, activation="linear")(x[-1])

model = Model(img_input, y)
model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_10 (InputLayer)        [(None, 3, 224, 224)]     0         
_________________________________________________________________
tf_vi_t_model_1 (TFViTModel) TFBaseModelOutputWithPool 86389248  
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 769       
Total params: 86,390,017
Trainable params: 86,390,017
Non-trainable params: 0
_________________________________________________________________


In [74]:
from transformers import ViTFeatureExtractor
# Image preprocessor loaded from the same checkpoint as the backbone. It is a
# module-level object: the generator classes defined in this notebook call it
# (when is_vit=True) to turn a list of images into a "pixel_values" batch.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

class RotGenerator(Sequence):
    """Keras ``Sequence`` that yields randomly rotated images and their angles.

    Every ``*.jpg`` file directly inside ``image_dir`` is one sample. Each time
    a batch is requested, every image in it is rotated by a freshly drawn
    integer angle in ``[0, 360)`` using ``rotate_preserve_size`` and that angle
    becomes the regression target, so targets differ between epochs.

    Args:
        image_dir: Directory that is globbed for ``*.jpg`` files.
        batch_size: Maximum number of images per batch (the last batch may be
            smaller).
        dim: Side length handed to ``rotate_preserve_size`` as ``(dim, dim)``.
        channels_first: When true and ``is_vit`` is false, each image is
            transposed with axes ``(2, 0, 1)`` before batching.
        is_vit: When true, the images are passed through the module-level
            ``feature_extractor`` and its ``pixel_values`` are returned;
            ``channels_first`` is ignored in that mode.
    """

    def __init__(self, image_dir, batch_size, dim, channels_first=False, is_vit=False):
        self.files = glob.glob(os.path.join(image_dir, "*.jpg"))
        self.batch_size = batch_size
        self.dim = dim
        self.channels_first = channels_first
        self.is_vit = is_vit

    def __len__(self):
        """Return the number of batches per epoch (ceiling division)."""
        return (len(self.files) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            ``(X, y)`` where ``X`` is the image batch and ``y`` is a 1-D array
            with the rotation angle (in degrees, as float) applied to each
            image. Images that fail to load or rotate are logged and skipped,
            so the batch can be smaller than ``batch_size``.

        Raises:
            ValueError: If no image of the batch could be processed.
        """
        start = idx * self.batch_size
        batch_files = self.files[start:start + self.batch_size]

        images = []
        angles = []
        for path in batch_files:
            # Integer angle in [0, 360), stored as float for the regression target.
            angle = float(np.random.randint(0, 360))
            try:
                img = np.array(rotate_preserve_size(path, angle, (self.dim, self.dim)))
                if not self.is_vit and self.channels_first:
                    img = img.transpose(2, 0, 1)
            except Exception as exc:
                # Best-effort: one unreadable file must not abort training, but
                # the skip is reported instead of being silently swallowed.
                # Catching Exception (not a bare except) lets Ctrl-C through.
                logger.warning("Skipping {} in batch {}: {!r}", path, idx, exc)
                continue
            # Image and label are appended together so X and y stay aligned.
            images.append(img)
            angles.append(angle)

        if not images:
            raise ValueError(
                "RotGenerator: no usable images in batch {} ({} file(s) attempted)".format(
                    idx, len(batch_files)
                )
            )

        if self.is_vit:
            # Request NumPy output directly; "pt" would require PyTorch just to
            # convert straight back to a NumPy array.
            X = np.asarray(feature_extractor(images=images, return_tensors="np")["pixel_values"])
        else:
            X = np.stack(images, axis=0)
        y = np.array(angles)

        return X, y

    def on_epoch_end(self):
        """Shuffle the file list in place so the next epoch sees a new order."""
        random.shuffle(self.files)

In [83]:
class ValidationTestGenerator(Sequence):
    """Keras ``Sequence`` yielding images rotated by fixed, pre-assigned angles.

    The label CSV is read once and filtered to the rows whose ``mode`` column
    equals ``mode``. For each remaining row, the file named in the ``image``
    column (relative to ``image_dir``) is rotated by the value in the ``angle``
    column, so the evaluation set is identical on every pass.

    Args:
        image_dir: Directory the ``image`` column entries are joined onto.
        df_label_path: Path of the CSV with ``mode``, ``image`` and ``angle``
            columns.
        batch_size: Maximum number of images per batch (the last batch may be
            smaller).
        dim: Side length handed to ``rotate_preserve_size`` as ``(dim, dim)``.
        mode: Value of the ``mode`` column selecting the rows to use.
        channels_first: When true and ``is_vit`` is false, each image is
            transposed with axes ``(2, 0, 1)`` before batching.
        is_vit: When true, the images are passed through the module-level
            ``feature_extractor`` and its ``pixel_values`` are returned;
            ``channels_first`` is ignored in that mode.
    """

    def __init__(self, image_dir, df_label_path, batch_size, dim, mode, channels_first=False, is_vit=False):
        self.image_dir = image_dir
        self.batch_size = batch_size
        self.dim = dim
        self.mode = mode
        self.channels_first = channels_first
        self.is_vit = is_vit

        df_label = pd.read_csv(df_label_path)
        self.df = df_label[df_label["mode"] == self.mode].reset_index(drop=True)

    def __len__(self):
        """Return the number of batches (ceiling division over selected rows)."""
        return (self.df.shape[0] + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            ``(X, y)`` where ``X`` is the image batch and ``y`` is a 1-D array
            of the angles taken from the CSV. Images that fail to load or
            rotate are logged and skipped, so the batch can be smaller than
            ``batch_size``.

        Raises:
            KeyError: If the CSV lacks the ``angle`` or ``image`` column.
            ValueError: If no image of the batch could be processed.
        """
        start = idx * self.batch_size
        # Positional row slice; .iloc makes that explicit.
        df_batch = self.df.iloc[start:start + self.batch_size]

        # Column lookups are done outside the try block so that a mis-named
        # column surfaces immediately instead of silently emptying every batch.
        batch_angles = df_batch["angle"]
        batch_names = df_batch["image"]

        images = []
        angles = []
        for angle, name in zip(batch_angles, batch_names):
            try:
                path = os.path.join(self.image_dir, name)
                img = np.array(rotate_preserve_size(path, angle, (self.dim, self.dim)))
                if not self.is_vit and self.channels_first:
                    img = img.transpose(2, 0, 1)
            except Exception as exc:
                # Best-effort: report and skip the bad row rather than silently
                # dropping it; catching Exception lets Ctrl-C through.
                logger.warning("Skipping {!r} in batch {}: {!r}", name, idx, exc)
                continue
            # Image and label are appended together so X and y stay aligned.
            images.append(img)
            angles.append(angle)

        if not images:
            raise ValueError(
                "ValidationTestGenerator: no usable images in batch {} ({} row(s) attempted)".format(
                    idx, len(df_batch)
                )
            )

        if self.is_vit:
            # Request NumPy output directly; "pt" would require PyTorch just to
            # convert straight back to a NumPy array.
            X = np.asarray(feature_extractor(images=images, return_tensors="np")["pixel_values"])
        else:
            X = np.stack(images, axis=0)
        y = np.array(angles)

        return X, y

(array([[[[-0.29411763, -0.31764704, -0.30196077, ..., -0.7411765 ,
           -0.7411765 , -0.7411765 ],
          [-0.3333333 , -0.32549018, -0.27843136, ..., -0.7411765 ,
           -0.7411765 , -0.7490196 ],
          [-0.3098039 , -0.29411763, -0.25490195, ..., -0.7490196 ,
           -0.7411765 , -0.7411765 ],
          ...,
          [-0.56078434,  0.15294123,  0.2941177 , ...,  0.05098045,
           -0.09019607,  0.13725495],
          [-0.15294117, -0.04313725,  0.24705887, ..., -0.17647058,
           -0.27058822, -0.08235294],
          [-0.19999999, -0.0745098 ,  0.19215691, ...,  0.07450986,
           -0.06666666,  0.11372554]],
 
         [[-0.26274508, -0.27843136, -0.27058822, ..., -0.60784316,
           -0.60784316, -0.6       ],
          [-0.3098039 , -0.30196077, -0.25490195, ..., -0.60784316,
           -0.6156863 , -0.60784316],
          [-0.29411763, -0.26274508, -0.2235294 , ..., -0.5921569 ,
           -0.5921569 , -0.5921569 ],
          ...,
          [-0

In [75]:
# Smoke test: build the ViT-mode training generator and materialise its third
# batch (index 2) so the preprocessed pixel values and angles can be inspected.
train_gen = RotGenerator(
    image_dir="/data/chandanp/train2017/",
    batch_size=16,
    dim=IMAGE_SIZE,
    is_vit=True,
)
train_gen[2]

(array([[[[-0.9372549 , -0.8666667 , -0.7647059 , ..., -0.64705884,
           -0.64705884, -0.64705884],
          [-0.62352943, -0.6156863 , -0.62352943, ..., -0.64705884,
           -0.62352943, -0.6392157 ],
          [-0.6       , -0.6156863 , -0.6       , ..., -0.62352943,
           -0.5764706 , -0.6       ],
          ...,
          [ 0.11372554,  0.06666672,  0.05098045, ...,  0.73333335,
            0.6156863 ,  0.5294118 ],
          [ 0.16078436,  0.23921573,  0.30196083, ...,  0.6       ,
            0.6862745 ,  0.6862745 ],
          [ 0.17647064,  0.36470592,  0.17647064, ...,  0.654902  ,
            0.62352943,  0.6156863 ]],
 
         [[-0.99215686, -0.9843137 , -0.99215686, ..., -0.7490196 ,
           -0.7411765 , -0.75686276],
          [-0.96862745, -0.96862745, -0.96862745, ..., -0.75686276,
           -0.75686276, -0.75686276],
          [-0.96862745, -0.96862745, -0.96862745, ..., -0.75686276,
           -0.78039217, -0.8039216 ],
          ...,
          [-0

In [88]:
# train
# Compile with the project's angular MAE loss and Adadelta.
model.compile(loss=angular_loss_mae, optimizer=Adadelta(learning_rate=0.1))

# Training data: random rotations drawn per batch. Validation data: fixed
# angles read from the label CSV (rows with mode == "valid").
train_gen = RotGenerator("/data/chandanp/train2017/", 16, IMAGE_SIZE, is_vit=True)
val_gen = ValidationTestGenerator(image_dir="/data/subhadip/data/",
                                  df_label_path="/data/subhadip/data/validation-test.csv",
                                  batch_size=32, dim=IMAGE_SIZE, mode="valid", is_vit=True)

# All three callbacks watch val_loss. The checkpoint previously watched the
# training loss, which is computed on targets that are re-randomised every
# batch and disagreed with the criterion used for LR decay and early stopping.
cp = ModelCheckpoint("/data/subhadip/weights/model-vit-ang-loss.h5", save_weights_only=False,
                     save_best_only=True, monitor="val_loss")
reduce_lr = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-5)
es = EarlyStopping(monitor="val_loss", patience=5)

# epochs is only an upper bound; EarlyStopping is what ends the run in practice.
model.fit(train_gen, validation_data=val_gen, epochs=10000, callbacks=[cp, es, reduce_lr])

Epoch 1/10000

KeyboardInterrupt: 