In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras
from keras import layers

In [2]:
import logging

# Attach a default stderr handler to the root logger (a no-op if one is
# already installed) ...
logging.basicConfig(level=logging.INFO)
# ... then raise the root threshold to ERROR. This must run AFTER basicConfig:
# basicConfig sets the root level itself, so a setLevel placed before it is
# silently overridden whenever basicConfig actually configures the logger.
logging.getLogger().setLevel(logging.ERROR)

In [3]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler, random_split
from torch.utils.data.dataloader import default_collate
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

In [4]:
def get_weights(train_dataset):
    """Compute per-sample inverse-class-frequency weights for a dataset subset.

    Args:
        train_dataset: a subset-like object exposing ``.indices`` (positions into
            the parent dataset) and ``.dataset.targets`` (integer class label of
            every parent sample), e.g. a ``random_split`` piece of an ImageFolder.

    Returns:
        (weights, class_counts): ``weights[i]`` is ``1 / count(label of sample i)``
        as a float32 tensor; ``class_counts[c]`` is the number of subset samples
        labelled ``c`` (length is max label + 1, from ``bincount``).
    """
    parent_targets = train_dataset.dataset.targets
    subset_labels = torch.tensor([parent_targets[idx] for idx in train_dataset.indices])

    counts_per_class = subset_labels.bincount()
    # Rarer classes get proportionally larger weights. Classes absent from the
    # subset become inf here, but no sample ever indexes them.
    inverse_frequency = 1. / counts_per_class.float()
    return inverse_frequency[subset_labels], counts_per_class

In [5]:
SEED = 42  # seeds the train/validation split so it is identical on every re-run
BATCH_SIZE = 16
IMAGE_SIZE = (256, 256)


def to_channels_last(image):
    """Permute a (C, H, W) image tensor to (H, W, C) for Keras' channels_last layers."""
    return image.permute(1, 2, 0)


# Training pipeline: random geometric augmentation, then tensor conversion.
transform = v2.Compose([
    v2.Resize(IMAGE_SIZE),
    v2.RandomRotation(10),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),
    # Scale uint8 [0, 255] -> float32 [0, 1] while the tensor is still a
    # tv_tensors.Image, so ToDtype unambiguously treats it as an image.
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(to_channels_last),
])

# Evaluation pipeline: the same deterministic preprocessing without random
# augmentation, so validation/test metrics do not vary between passes.
eval_transform = v2.Compose([
    v2.Resize(IMAGE_SIZE),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(to_channels_last),
])

full_train_dataset = ImageFolder('../train', transform=transform)
# A second, non-augmenting view of the same folder. Both views enumerate
# '../train' the same way, so split indices refer to the same files in each.
full_eval_dataset = ImageFolder('../train', transform=eval_transform)
test_dataset = ImageFolder('../test', transform=eval_transform)
assert test_dataset.classes == full_train_dataset.classes, "train/test class folders differ"

train_dataset, _validation_split = random_split(
    full_train_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED)
)
# Read validation samples through the eval view instead of the augmenting one.
validation_dataset = Subset(full_eval_dataset, _validation_split.indices)

# get_weights returns per-class sample counts, not the number of classes.
weights, class_counts = get_weights(train_dataset)
num_classes = len(full_train_dataset.classes)


def custom_collate_fn(batch):
    """Collate (image, label) pairs and reshape labels from (B,) to (B, 1)."""
    images, labels = default_collate(batch)
    return images, labels.view(-1, 1)


train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # Inverse-frequency weights make each class roughly equally likely per draw.
    sampler=WeightedRandomSampler(weights, len(weights)),
    collate_fn=custom_collate_fn,
)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)

In [6]:
def input_block(filters, p):
    """Build the network stem: 3x3 conv -> batch norm -> ELU -> 2x max-pool -> dropout.

    Args:
        filters: number of output channels of the 3x3 convolution.
        p: dropout rate applied after pooling.

    Returns:
        A callable that applies the stem to a tensor and returns the result.
    """
    stem_layers = [
        # No conv bias: the BatchNormalization that follows supplies its own offset.
        layers.Conv2D(filters, kernel_size=(3, 3), padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ELU(),
        layers.MaxPooling2D(2),
        layers.Dropout(p),
    ]
    stem = keras.Sequential(stem_layers)

    def _apply(tensor):
        return stem(tensor)

    return _apply

def output_block(num_classes=7, hidden_units=144):
    """Build the classification head: global pool -> hidden Dense -> class probabilities.

    Args:
        num_classes: number of output classes (default 7, as in the original head).
        hidden_units: width of the hidden fully-connected layer (default 144).

    Returns:
        A callable mapping a feature map to a (batch, num_classes) softmax output.
        Assumes a channels_last (batch, H, W, C) input, as GlobalAveragePooling2D
        uses by default.
    """
    block = keras.Sequential([
        # GlobalAveragePooling2D already yields (batch, channels); no Flatten needed.
        layers.GlobalAveragePooling2D(),
        # Without a non-linearity this layer and the next compose into a single
        # linear map, so the hidden layer would add parameters but no capacity.
        layers.Dense(hidden_units, activation="relu"),
        layers.Dense(num_classes),
    ])

    def compute(inputs):
        return keras.ops.softmax(block(inputs))

    return compute

def _residual_branch(filters, out_filters, kernel_size, p, strides=1):
    """Main (non-shortcut) path shared by conv_block and identity_block.

    Four Conv-BN-activation-Dropout stages: the first uses ELU, the second a
    dilation rate of 2, and the last widens to ``out_filters`` channels and
    applies ``strides``. Convs omit bias because BatchNormalization follows.
    """
    return keras.Sequential([
        layers.Conv2D(filters, kernel_size, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ELU(),
        layers.Dropout(p),

        layers.Conv2D(filters, kernel_size, dilation_rate=2, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Dropout(p),

        layers.Conv2D(filters, kernel_size, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Dropout(p),

        layers.Conv2D(out_filters, kernel_size, padding="same", strides=strides, use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Dropout(p),
    ])


def conv_block(filters, dim, kernel_size, p):
    """Downsampling residual block with a projection shortcut.

    Args:
        filters: channel width of the inner convolutions.
        dim: extra channels added on output; the block emits ``filters + dim`` channels.
        kernel_size: kernel size of every convolution (main path and shortcut).
        p: dropout rate after each main-path stage.

    Returns:
        A callable computing ``relu(main(x) + shortcut(x))``, with both paths
        applying stride 2.
    """
    main_branch = _residual_branch(filters, filters + dim, kernel_size, p, strides=2)
    # Projection shortcut: a strided conv + BN so the skip path matches the
    # main branch's reduced resolution and new channel count.
    shortcut = keras.Sequential([
        layers.Conv2D(filters + dim, kernel_size, padding="same", strides=2, use_bias=False),
        layers.BatchNormalization(),
    ])

    def compute(inputs):
        return keras.ops.relu(main_branch(inputs) + shortcut(inputs))

    return compute


def identity_block(filters, dim, kernel_size, p):
    """Residual block with an identity shortcut; resolution and channels are preserved.

    Args:
        filters: channel width of the inner convolutions.
        dim: extra channels; the input must already have ``filters + dim`` channels.
        kernel_size: kernel size of every convolution.
        p: dropout rate after each main-path stage.

    Returns:
        A callable computing ``relu(main(x) + x)``.

    Raises:
        ValueError: if the input's last-axis size is known and differs from
            ``filters + dim``.
    """
    main_branch = _residual_branch(filters, filters + dim, kernel_size, p)

    def compute(inputs):
        # The input is added unchanged, so its channel count must equal the
        # main branch's output (e.g. the output of a matching conv_block).
        in_channels = inputs.shape[-1]
        if in_channels is not None and in_channels != filters + dim:
            raise ValueError(
                f"identity_block expects {filters + dim} input channels, got {in_channels}"
            )
        return keras.ops.relu(main_branch(inputs) + inputs)

    return compute

In [7]:
# Every stage outputs start + dim * multiplier channels.
dim = 24
start = 44

# (width multiplier, kernel size, dropout rate) for each downsampling stage;
# each stage is one conv_block followed by one identity_block.
STAGES = [
    (1, 5, 0.4),
    (2, 5, 0.3),
    (3, 3, 0.2),
    (4, 3, 0.2),
    (5, 2, 0.1),
]

# Named `inputs` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
inputs = keras.Input(shape=(256, 256, 3))
x = input_block(start, 0.5)(inputs)

for multiplier, kernel_size, dropout_rate in STAGES:
    x = conv_block(start, dim * multiplier, kernel_size, dropout_rate)(x)
    x = identity_block(start, dim * multiplier, kernel_size, dropout_rate)(x)

output = output_block()(x)

model = keras.Model(inputs, output)
model.summary()

In [None]:
name = 'lemon_1_test'
EPOCHS = 100

# Persist the epoch with the best validation accuracy to disk.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=f'{name}.keras',
    monitor='val_sparse_categorical_accuracy',
    mode='max',
    save_best_only=True)

# Stop after 7 epochs without val_loss improvement. restore_best_weights rolls
# the in-memory model back to its best-val_loss weights when stopping; without
# it the model left in memory is the last, over-trained epoch.
# NOTE(review): this monitors val_loss while the checkpoint monitors validation
# accuracy, so the in-memory model and the saved file may come from different epochs.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=7, restore_best_weights=True)

# The model already ends in softmax, so the loss takes probabilities
# (from_logits defaults to False).
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(),
    ],
)

history = model.fit(
    train_loader,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
    validation_data=validation_loader,
    epochs=EPOCHS,
)