In [None]:
import os

import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt

import sklearn.model_selection as skms
import sklearn.preprocessing as skp
import sklearn.utils as sku
import sklearn.decomposition as skd

In [None]:
# Locations of the metadata index and of the word images it lists.
METADATA_PATH = "./herbier/data_public/ascii/words.txt"
WORD_DATA_PATH = "./herbier/data_public/words/"

# Side length (in pixels) of the square every image is resized to.
IMAGE_WIDTH = 128
IMAGE_HEIGHT = IMAGE_WIDTH

# Writer IDs kept for the classification task.
CLASSES = ['a01-000u', 'a01-003u']
N_CLASSES = len(CLASSES)

# When True, cells print shapes/statistics and show preview plots.
DEBUG = True

## TODO
- voir s'il est possible d'ignorer le fichier METADATA
- limiter le nombre d'échantillons (samples) à sélectionner

In [None]:
def _load_words_data(data_path, metadata_path):
    data = []

    with open(metadata_path, 'r') as file:
        for line in file:
            if not line.startswith("#"):
                components = line.strip().split(' ')
                word_id = components[0]
                
                parts = word_id.split('-')
                writer_id = '-'.join(parts[:2])  # e.g., 'a01-000u'
                image_subfolder = parts[0]       # e.g., 'a01'
                image_filename = f"{word_id}.png"
                image_path = os.path.join(data_path, image_subfolder, writer_id, image_filename)
                
                if os.path.exists(image_path):
                    try:
                        img = tf.io.read_file(image_path)
                        img = tf.image.decode_png(img)
                        data.append({
                            'image_path': image_path,
                            'writer_id': writer_id,  # Label = writer ID
                            'image_array': img
                        })
                    except tf.errors.InvalidArgumentError:
                        print(f"Image not found for word ID: {word_id} at {image_path}")
                else:
                    print(f"Image not found for word ID: {word_id} at {image_path}")

    return data

def load_words_data(data_path, metadata_path, selected_writers=None):
    """Load the word images that belong to the selected writers.

    Each data line of ``metadata_path`` starts with a word ID such as
    ``a01-000u-00-00``. The first two dash-separated parts form the writer ID
    (``a01-000u``); only lines whose writer ID is in ``selected_writers`` are
    loaded, from ``<data_path>/<first part>/<writer ID>/<word ID>.png``.

    Args:
        data_path: Root directory that contains the word images.
        metadata_path: Text file with one word per line; blank lines and
            lines starting with ``#`` are ignored.
        selected_writers: Non-empty iterable of writer IDs to keep. A single
            ID given as a plain string is accepted as well.

    Returns:
        A list of dicts with keys ``'image_path'``, ``'writer_id'`` and
        ``'image_array'`` (the decoded PNG tensor), in metadata order. Entries
        whose image is missing or cannot be decoded are reported on stdout
        and skipped.

    Raises:
        ValueError: If ``selected_writers`` is missing or empty.
    """
    # None replaces the former mutable ``[]`` default; any empty iterable
    # (list, tuple, set, ...) is rejected, not only the literal ``[]``.
    if selected_writers is None:
        wanted_writers = frozenset()
    elif isinstance(selected_writers, str):
        wanted_writers = frozenset([selected_writers])
    else:
        # A set gives constant-time membership tests for every metadata line.
        wanted_writers = frozenset(selected_writers)

    if not wanted_writers:
        raise ValueError("selected_writers must be a non-empty list of writer IDs")

    data = []

    with open(metadata_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue

            word_id = line.split(' ')[0]

            parts = word_id.split('-')
            writer_id = '-'.join(parts[:2])
            if writer_id not in wanted_writers:
                continue

            image_subfolder = parts[0]
            image_path = os.path.join(data_path, image_subfolder, writer_id, f"{word_id}.png")

            if not os.path.exists(image_path):
                print(f"Image not found for word ID: {word_id} at {image_path}")
                continue

            try:
                img = tf.image.decode_png(tf.io.read_file(image_path))
            except tf.errors.InvalidArgumentError:
                # The file exists but could not be decoded as a PNG; say so
                # instead of claiming that it is missing.
                print(f"Could not decode image for word ID: {word_id} at {image_path}")
                continue

            data.append({
                'image_path': image_path,
                'writer_id': writer_id,
                'image_array': img
            })

    return data
 
# Load only the images written by the writers listed in CLASSES.
words_data = load_words_data(WORD_DATA_PATH, METADATA_PATH, selected_writers=CLASSES)

if DEBUG:
    print(f"Loaded {len(words_data)} words.")
    for entry in words_data[:5]:
        print(f"  Writer ID: {entry['writer_id']}; image shape: {entry['image_array'].shape}")

    print("number of writers: ", len({entry['writer_id'] for entry in words_data}))

    # Preview at most 25 samples on a 5x5 grid. Slicing (rather than indexing
    # over range(25)) avoids an IndexError when fewer images were loaded.
    if words_data:
        plt.figure(figsize=(10, 10))
        for i, entry in enumerate(words_data[:25]):
            plt.subplot(5, 5, i + 1)
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(entry['image_array'], cmap=plt.cm.binary)
            plt.xlabel(entry['writer_id'])
        plt.show()


In [None]:
# preprocessing data

def preprocess_data(data):
    """Turn loaded entries into an image array and a label array.

    Every entry's ``'image_array'`` is resized to
    ``IMAGE_HEIGHT x IMAGE_WIDTH`` with ``tf.image.resize``, converted to
    float32 and divided by 255; its ``'writer_id'`` becomes the label.

    Args:
        data: Iterable of dicts holding ``'image_array'`` and ``'writer_id'``.

    Returns:
        A ``(images, labels)`` tuple of NumPy arrays, in the order of ``data``.
    """
    def _resize_and_scale(raw_image):
        resized = tf.image.resize(raw_image, [IMAGE_HEIGHT, IMAGE_WIDTH])
        # Cast first, then map pixel values onto [0, 1] (assumes 8-bit input).
        return resized.numpy().astype('float32') / 255.0

    # Single pass over the input: image first, then its label.
    pairs = [
        (_resize_and_scale(sample['image_array']), sample['writer_id'])
        for sample in data
    ]

    scaled_images = [image for image, _ in pairs]
    writer_labels = [writer for _, writer in pairs]
    return np.array(scaled_images), np.array(writer_labels)


# Resize/normalise every loaded image and collect the writer IDs as labels.
images, labels = preprocess_data(words_data)

# Hold out 20% of the samples; the fixed random_state keeps the split
# reproducible across runs.
# NOTE(review): the label-encoding cell further down repeats this split on
# one-hot labels and rebinds the same four names.
X_train, X_test, y_train, y_test = skms.train_test_split(
    images,
    labels,
    random_state=42,
    test_size=0.2,
)

if DEBUG:
    print(f"X_train: {X_train.shape}; y_train: {y_train.shape}")
    print(f"X_test: {X_test.shape}; y_test: {y_test.shape}")

In [None]:
# Encode the writer IDs as one-hot vectors.
# The encoder is fitted on the configured CLASSES (not on the labels that
# happen to be present) and the one-hot width is pinned to N_CLASSES, so the
# targets always match the model's N_CLASSES-unit output, even if one of the
# configured writers has no loaded image.
label_encoder = skp.LabelEncoder()
label_encoder.fit(CLASSES)
integer_encoded_labels = label_encoder.transform(labels)
one_hot_encoded_labels = keras.utils.to_categorical(integer_encoded_labels, num_classes=N_CLASSES)

# Same 80/20 split as before, now carrying the one-hot targets.
X_train, X_test, y_train, y_test = skms.train_test_split(images, one_hot_encoded_labels, test_size=0.2, random_state=42)

if DEBUG:
    print(f"X_train: {X_train.shape}; y_train: {y_train.shape}")
    print(f"X_test: {X_test.shape}; y_test: {y_test.shape}")

In [None]:
# CNN classifier whose second-to-last layer is a 2-unit bottleneck, so every
# image can be mapped to a 2-D point once the network is trained.
model = keras.Sequential([
    # Declare the input explicitly: passing `input_shape=` to the first layer
    # is deprecated in recent Keras releases.
    keras.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1)),

    # Feature extraction layers
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(128, (3, 3), activation='relu'),
    keras.layers.BatchNormalization(),

    # Flattening the convolutional layer
    keras.layers.Flatten(),

    # Dense layers for further processing
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dropout(0.5),

    # Bottleneck layer for clustering; named so it can be retrieved with
    # model.get_layer('bottleneck').
    # NOTE(review): ReLU keeps both coordinates non-negative - confirm this is
    # the intended embedding space.
    keras.layers.Dense(2, activation='relu', name='bottleneck'),

    # Final classification layer (if needed for initial training)
    keras.layers.Dense(N_CLASSES, activation='softmax'),
])

In [None]:
# Configure training: Adam optimizer, categorical cross-entropy on the one-hot
# targets, and accuracy reported as the metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Training hyper-parameters.
EPOCHS = 20
BATCH_SIZE = 32

# Train on the training split; the held-out split is passed as validation data
# so its loss/accuracy are recorded in `history` after every epoch.
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
)

In [None]:
# Read the embedding from the bottleneck layer (second to last) rather than
# from the softmax output. Softmax probabilities sum to 1, so for two classes
# the two columns are perfectly anti-correlated and would collapse onto a
# single principal component.
feature_extractor = keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)
features = feature_extractor.predict(X_train)

if DEBUG:
    print("Features shape:", features.shape)
    print("Features range per dimension:", features.min(axis=0), "to", features.max(axis=0))

# Standardise each feature before PCA so that no dimension dominates by scale.
scaler = skp.StandardScaler()
features_standardized = scaler.fit_transform(features)

# Two components are what the 2-D scatter plot below needs; this number is
# unrelated to the number of classes.
pca = skd.PCA(n_components=2)
projected = pca.fit_transform(features_standardized)

if DEBUG:
    print("Explained variance by component:", pca.explained_variance_ratio_)

# Recover integer class indices from the one-hot training targets.
integer_class_labels = np.argmax(y_train, axis=1)

# Plotting
# plt.get_cmap replaces plt.cm.get_cmap, which recent Matplotlib releases no
# longer provide; one discrete colour per class.
plt.scatter(projected[:, 0], projected[:, 1],
            c=integer_class_labels, edgecolor='none', alpha=0.5,
            cmap=plt.get_cmap('Accent', N_CLASSES))
plt.xlabel('component 1')
plt.ylabel('component 2')
plt.colorbar()
plt.show()