In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0

# Restrict TensorFlow to the first physical GPU and enable memory growth on it,
# so GPU memory is allocated on demand rather than reserved up front.
# `[:1]` (instead of `[0]`) yields an empty list on CPU-only machines, so this
# cell no longer raises IndexError when no GPU is present.
devices = tf.config.list_physical_devices('GPU')[:1]
if devices:
    tf.config.set_visible_devices(devices, 'GPU')
for device in devices:
    tf.config.experimental.set_memory_growth(device, True)

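## GPU memory used
Input resolution (as in the lists below) = GPU memory used; units were not recorded (presumably MiB — confirm):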
- 256 x 256 = 1125
- 512 x 512 = 1893
- 1024 x 1024 = 4965
- 1024 x 2048 = 7817

In [2]:
class custom_effnet(tf.keras.Model):
    """EfficientNetB0 backbone followed by global average pooling and a dense head.

    The backbone is created without its classification top and with randomly
    initialised weights (``weights=None``). The dense head has no activation,
    so the model returns raw logits rather than probabilities.

    Args:
        n_classes: Number of units in the final dense layer.
        input_shape: Per-sample input shape passed to ``EfficientNetB0``.
    """

    def __init__(self, n_classes=4, input_shape=(256, 256, 3)):
        super().__init__()
        self.shape = input_shape
        self.base_model = EfficientNetB0(include_top=False, weights=None, input_shape=input_shape)
        # Build the pooling layer once here instead of instantiating a new
        # layer object on every forward pass inside `call`.
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.dense = tf.keras.layers.Dense(n_classes)

    def call(self, x, training=None):
        """Run the forward pass and return logits.

        Args:
            x: Batch of inputs accepted by the EfficientNetB0 backbone.
            training: Optional training-mode flag, forwarded to the backbone so
                its mode-dependent layers (e.g. batch normalisation) behave
                accordingly. ``None`` leaves the decision to Keras.

        Returns:
            The output of the dense head applied to the pooled backbone features.
        """
        x = self.base_model(x, training=training)
        x = self.pool(x)
        x = self.dense(x)
        return x

def train(model, inputs, outputs, optimizer, loss_fn):
    """Perform a single gradient-descent step on one batch.

    Args:
        model: Callable Keras model exposing ``trainable_variables``.
        inputs: Batch of model inputs.
        outputs: Target values (ground truth) for the batch.
        optimizer: Optimizer providing ``apply_gradients``.
        loss_fn: Loss callable with the Keras signature ``loss_fn(y_true, y_pred)``.

    Returns:
        The loss value computed for this step.
    """
    with tf.GradientTape() as tape:
        # Forward pass in training mode so mode-dependent layers act as they
        # would during real training.
        predictions = model(inputs, training=True)
        # Keras losses expect (y_true, y_pred); the arguments were previously
        # passed in the opposite order.
        current_loss = loss_fn(outputs, predictions)
    variables = model.trainable_variables
    grads = tape.gradient(current_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return current_loss

def train_loop(model, loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
               input_shape=(256, 256, 3), n_classes=4, n_steps=3):
    """Run a few training steps on one dummy sample.

    A single all-zero input and a one-hot target (class 0) are fed through
    ``train`` repeatedly with a fresh Adam optimizer.

    Args:
        model: Model to train; its output width must equal ``n_classes``.
        loss_fn: Loss callable taking ``(y_true, y_pred)``. The default uses
            ``from_logits=True`` because the model's dense head has no
            activation and therefore emits logits.
        input_shape: Per-sample input shape; a batch axis of size 1 is prepended.
        n_classes: Width of the one-hot target vector (previously hard-coded to 4).
        n_steps: Number of training steps to run (previously hard-coded to 3).
    """
    # float32 avoids allocating a float64 array twice the size that the model
    # would have to cast down anyway.
    inputs = np.zeros(input_shape, dtype=np.float32)[None, ...]
    outputs = np.zeros((1, n_classes), dtype=np.float32)
    outputs[0, 0] = 1
    optimizer = tf.keras.optimizers.Adam()
    for _ in range(n_steps):
        train(model, inputs, outputs, optimizer, loss_fn)

In [3]:
# Benchmark configuration: number of output classes and per-sample input shape.
n_classes = 4
input_shape = (1024, 2048, 3)

# Instantiate the network and build it for a batch of unspecified size.
model = custom_effnet(n_classes=n_classes, input_shape=input_shape)
model.build(input_shape=(None, *input_shape))

# Run the short dummy training loop at the configured input shape.
train_loop(model, input_shape=input_shape)