In [2]:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import layers

In [3]:
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split, SubsetRandomSampler
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import v2
import numpy as np
import pandas as pd

In [4]:
def get_weights(train_dataset):
    """Compute inverse-frequency sampling weights for a dataset subset.

    Args:
        train_dataset: Object exposing ``.indices`` (positions into the
            underlying dataset) and ``.dataset.targets`` (integer class label
            per underlying sample) -- e.g. what ``random_split`` returns for
            an ``ImageFolder``.

    Returns:
        A ``(weights, class_counts)`` tuple:
        ``weights`` holds one float per subset sample, equal to
        ``1 / count`` of that sample's class within the subset;
        ``class_counts`` is the ``torch.bincount`` of the subset labels.
    """
    all_targets = train_dataset.dataset.targets
    subset_targets = torch.tensor(
        [all_targets[position] for position in train_dataset.indices]
    )

    class_counts = torch.bincount(subset_targets)
    # Rare classes get proportionally larger weights.
    inverse_frequency = 1. / class_counts.float()

    # Fancy-index by label to expand per-class weights to per-sample weights.
    return inverse_frequency[subset_targets], class_counts

In [5]:
# Fixed seed so the train/validation membership is identical on every
# "Restart & Run All". Without it, a checkpoint saved by an earlier run would
# later be validated on images it had been trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# Training pipeline: resize, light random augmentation, then convert to a
# float32 tensor permuted from (C, H, W) to channels-last (H, W, C).
transform = v2.Compose([
    v2.Resize((256, 256)),
    v2.RandomRotation(10),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),
    v2.Lambda(lambd=lambda x : torch.permute(x, [1, 2, 0])),
    v2.ToDtype(torch.float32, scale=True),
])

# Evaluation pipeline: identical preprocessing WITHOUT the random rotation and
# flip, so validation/test metrics are deterministic and comparable between
# epochs (the checkpoint callback picks the "best" epoch from them).
eval_transform = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),
    v2.Lambda(lambd=lambda x : torch.permute(x, [1, 2, 0])),
    v2.ToDtype(torch.float32, scale=True),
])


full_train_dataset = ImageFolder('../train', transform=transform)
test_dataset = ImageFolder('../test', transform=eval_transform)

# Split once on the augmented dataset, then re-wrap the validation indices
# around an un-augmented view of the same folder. ImageFolder enumerates files
# in sorted order, so the indices refer to the same images in both views.
train_dataset, validation_split = random_split(
    full_train_dataset, [0.8, 0.2], generator=split_generator
)
validation_dataset = torch.utils.data.Subset(
    ImageFolder('../train', transform=eval_transform),
    validation_split.indices,
)

# NOTE: despite its name, `num_classes` receives the per-class sample counts
# tensor returned by get_weights, not the number of classes.
weights, num_classes = get_weights(train_dataset)

subset_size_train = int(0.35 * len(train_dataset))
subset_size_valid = int(0.35 * len(validation_dataset))
# Random 35% of the validation split.
# NOTE(review): the visible validation DataLoader does not pass this sampler --
# confirm whether it is meant to be used.
indices = torch.randperm(len(validation_dataset), generator=split_generator)[:subset_size_valid]
validation_sampler = SubsetRandomSampler(indices)

def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples and reshape labels to a column.

    Args:
        batch: List of ``(image, label)`` pairs as produced by the dataset.

    Returns:
        ``(images, labels)`` where ``images`` is the default-collated image
        batch and ``labels`` has shape ``(batch_size, 1)``.
    """
    stacked_images, flat_labels = default_collate(batch)
    # view(-1, 1) turns the 1-D label vector into an explicit column.
    return stacked_images, flat_labels.view(-1, 1)

# Options shared by every loader: same batch size and label-reshaping collate.
loader_options = dict(batch_size=32, collate_fn=custom_collate_fn)

# Training batches are drawn with the per-sample weights (as many draws per
# epoch as there are weights).
train_sampler = WeightedRandomSampler(weights, num_samples=len(weights))
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)

# Evaluation loaders iterate in fixed dataset order.
validation_loader = DataLoader(validation_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [5]:
def stem_block(dim):
    """Build the network stem: a 4x4, stride-4 convolution plus layer norm.

    Args:
        dim: Number of output channels of the convolution.

    Returns:
        A callable that applies the stem layers (created once, here) to a
        tensor and returns the result.
    """
    patchify = layers.Conv2D(filters=dim, kernel_size=4, strides=4)
    norm = layers.LayerNormalization(epsilon=1e-6)
    stem_layers = keras.Sequential([patchify, norm])

    def apply(inputs):
        return stem_layers(inputs)

    return apply

def conv_next_block(dim, filters=5):
    """Build a residual block: depthwise conv, norm, 4x expand, GELU, project.

    Args:
        dim: Channel count of the final 1x1 projection; the hidden 1x1
            convolution uses ``dim * 4`` channels. It is expected to match the
            channel count of the incoming tensor so the residual add lines up.
        filters: Passed to ``DepthwiseConv2D`` as its kernel size (despite the
            name, it is not a number of filters).

    Returns:
        A callable mapping a tensor ``t`` to ``block(t) + t``.
    """
    residual_branch = keras.Sequential([
        layers.DepthwiseConv2D(kernel_size=filters, padding='same'),
        layers.LayerNormalization(epsilon=1e-6),
        layers.Conv2D(filters=dim * 4, kernel_size=1),
        layers.Activation(keras.activations.gelu),
        layers.Conv2D(filters=dim, kernel_size=1),
    ])

    def apply(inputs):
        # Skip connection around the whole branch.
        return residual_branch(inputs) + inputs

    return apply

def downsample_block(dim):
    """Build a downsampling stage: layer norm, then a 2x2, stride-2 convolution.

    Args:
        dim: Number of output channels of the convolution.

    Returns:
        A callable that applies the two layers (created once, here) to a
        tensor and returns the result.
    """
    norm = layers.LayerNormalization(epsilon=1e-6)
    reduce = layers.Conv2D(filters=dim, kernel_size=2, strides=2)
    downsample_layers = keras.Sequential([norm, reduce])

    def apply(inputs):
        return downsample_layers(inputs)

    return apply

dim = 72
input = keras.Input(shape=(256, 256, 3))

# One entry per stage: (channel width, depthwise kernel size, dropout rate
# after each residual block). The number of rates is the number of blocks.
stage_specs = [
    (dim,     5, [0.4]),
    (dim * 2, 5, [0.3]),
    (dim * 3, 3, [0.2, 0.2, 0.1]),
    (dim * 4, 2, [0.1]),
]

# The first stage is entered through the stem; every later stage through a
# downsampling block that also changes the channel width.
x = stem_block(dim)(input)
for stage_index, (width, kernel, dropout_rates) in enumerate(stage_specs):
    if stage_index > 0:
        x = downsample_block(width)(x)
    for rate in dropout_rates:
        x = conv_next_block(width, kernel)(x)
        x = layers.Dropout(rate)(x)

# Classification head: pool over the spatial axes, project to 7 outputs, and
# turn them into probabilities.
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(7)(x)

output = layers.Softmax()(x)


model = keras.Model(input, output)
model.summary()

In [None]:
name = 'lemon_4v2'

# The model emits softmax probabilities and labels are integer class ids, so
# the sparse cross-entropy loss and sparse accuracy metric are used as-is.
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Keep only the weights with the highest validation accuracy seen so far.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=f'{name}.keras',
    monitor='val_sparse_categorical_accuracy',
    mode='max',
    save_best_only=True,
)

# Stop once validation loss has not improved for 7 consecutive epochs.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=7,
)

training_callbacks = [model_checkpoint_callback, early_stopping_callback]

history = model.fit(
    train_loader,
    validation_data=validation_loader,
    epochs=300,
    callbacks=training_callbacks,
)