In [None]:
# Fetch the NomOCR repository and make it the working directory, so the
# relative paths used by the following cells resolve inside the checkout.
# NOTE(review): re-running this cell without restarting will clone again
# inside the already-cloned folder (the %cd persists) - run it once per session.
!git clone https://github.com/phonhay103/NomOCR.git
%cd NomOCR

In [None]:
# Unpack all_images.zip into ./all_images: -j drops the archive's directory
# structure, -o overwrites without prompting, -qq silences output.
# NOTE(review): in Info-ZIP unzip, -f means "freshen" (only entries that
# already exist on disk and are newer in the archive get extracted), so on a
# fresh checkout this may extract nothing - confirm -f is really intended.
!unzip -f -j -o -qq all_images.zip -d all_images

In [None]:
# Download the CGGAN-generated image archive from the Kaggle dataset
# nhay103/NomOCR2, unpack it next to the real images in ./all_images, then
# delete the archive. Requires Kaggle API credentials to be configured.
# -P 1 supplies the archive password inline; -j flattens paths; -qq is quiet.
# NOTE(review): -f is Info-ZIP's "freshen" mode (extract only files that
# already exist on disk), and there is no -o here, so unzip may prompt on
# conflicts - confirm these flags do what is intended on a clean run.
!kaggle datasets download -f NOM_CGGAN.zip nhay103/NomOCR2 --force
!unzip -f -P 1 -j -qq NOM_CGGAN.zip -d all_images
!rm NOM_CGGAN.zip

In [None]:
import os
import wandb
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16
from tensorflow.keras.models import Model
import numpy as np
import pandas as pd

In [None]:
# Load the per-image SSIM scores: one "<filename>\t<ssim>" record per line.
with open('ssim_16_16.txt') as f:
    ssim = f.read().splitlines()

ssim_rows = [record.split('\t') for record in ssim]
df = pd.DataFrame(ssim_rows, columns=['filename', 'ssim']).astype({'ssim': np.float32})

# Keep only the images whose score is at or above the first quartile,
# i.e. drop the lowest-scoring 25 %.
ssim_first_quartile = df['ssim'].quantile(0.25)
df = df[df['ssim'] >= ssim_first_quartile]

# Plain Python list of the filenames that survived the filter.
gan_df = df['filename'].tolist()

In [None]:
# Annotation files: one "<image filename>\t<label>" record per line.
train_data_file = 'NomOCR_train.txt'
test_data_file = 'NomOCR_test.txt'
gan_data_file = 'NOM_CGGAN_train.txt'

# Folder that contains every image referenced by the annotation files.
image_folder = 'all_images'

# Membership is tested once per GAN annotation line, so look names up in a
# set (constant time) instead of scanning the gan_df list for every line.
gan_keep = set(gan_df)

# Keep only the GAN annotations whose image passed the SSIM filter.
with open(gan_data_file) as f:
    gan_data = f.read().splitlines()
    gan_data = [line for line in gan_data if line.split('\t')[0] in gan_keep]
    print('gan_data', len(gan_data))

with open(train_data_file) as f:
    train_data = f.read().splitlines()
    print('train_data', len(train_data))

# From here on train_data_file points at the merged (real + GAN) annotation
# file; the real records come first, followed by the GAN records.
train_data_file = 'temp_train_file.txt'
with open(train_data_file, 'w') as f:
    train_data = train_data + gan_data
    print('train_data_final', len(train_data))
    f.write('\n'.join(train_data).rstrip())

In [None]:
def load_data(data_file, image_dir=None):
    """Read a tab-separated annotation file into image paths and labels.

    Every non-blank line must contain at least two tab-separated fields:
    the image filename followed by its label. Blank lines are skipped, and
    fields after the second are ignored.

    Args:
        data_file: Path of the annotation file to read.
        image_dir: Directory joined in front of every filename. When omitted,
            the module-level ``image_folder`` is used.

    Returns:
        A ``(filenames, labels)`` pair: a list of image paths and a tuple of
        label strings, both in file order. Both are empty when the file holds
        no records.

    Raises:
        ValueError: If a non-blank line has no tab-separated label.
    """
    if image_dir is None:
        image_dir = image_folder

    filenames, labels = [], []
    with open(data_file, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line (e.g. trailing newline at end of file): not a record.
                continue
            if len(fields) < 2:
                raise ValueError(
                    f'{data_file}:{line_number}: expected "<filename>\\t<label>", got {line!r}'
                )
            filenames.append(os.path.join(image_dir, fields[0]))
            labels.append(fields[1])
    # Labels stay a tuple so callers that concatenate label tuples keep working.
    return filenames, tuple(labels)

# Training split - read from whatever file train_data_file currently names
train_filenames, train_labels = load_data(data_file=train_data_file)

# Held-out split, used as validation data during training
test_filenames, test_labels = load_data(data_file=test_data_file)

In [None]:
# Build the label -> class-index lookup from every label seen in either split.
all_labels = set(train_labels + test_labels)

# Enumerate the labels in sorted order so the mapping is deterministic.
# Iterating a bare set of strings follows hash order, which changes between
# interpreter runs (string hash randomization), so class indices would not be
# comparable across kernel restarts or with a previously saved model.
label_to_index = {sample: index for index, sample in enumerate(sorted(all_labels))}

# Replace the string labels with their integer class indices.
train_labels = [label_to_index[label] for label in train_labels]
test_labels = [label_to_index[label] for label in test_labels]

# Get number of classes
num_classes = len(label_to_index)

In [None]:
# Training hyperparameters
lr = 1e-4               # learning rate handed to the optimizer
batch_size = 128        # samples per batch in both datasets
image_size = (64, 64)   # size every image is resized to before entering the network

In [None]:
def preprocess(filename, label):
    """Turn one (image path, label) pair into a (image tensor, label) pair.

    The file is read, decoded as a 3-channel JPEG, resized to the global
    ``image_size`` and divided by 255. The label is passed through unchanged.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, image_size)
    return resized / 255.0, label

def create_train_ds():
    """Build the shuffled, batched training dataset.

    Uses the global ``train_filenames`` / ``train_labels`` lists, decodes each
    image with ``preprocess`` and groups the results into batches of the
    global ``batch_size``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    train_ds = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
    # Shuffle the lightweight (path, label) pairs *before* decoding, with a
    # buffer that covers the whole dataset. The annotation list is the real
    # samples followed by the GAN samples, so a small sliding buffer applied
    # after decoding would leave batches poorly mixed and would also hold
    # decoded images in memory. max(..., 1) keeps the buffer size valid for an
    # empty list.
    train_ds = train_ds.shuffle(buffer_size=max(len(train_filenames), 1))
    train_ds = train_ds.map(preprocess)
    train_ds = train_ds.batch(batch_size)
    return train_ds

def create_test_ds():
    """Build the batched evaluation dataset (original order, no shuffling).

    Uses the global ``test_filenames`` / ``test_labels`` lists, decodes each
    image with ``preprocess`` and batches by the global ``batch_size``.
    """
    return (
        tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
        .map(preprocess)
        .batch(batch_size)
    )

def create_model(model):
    """Assemble a classifier on top of a headless ImageNet-pretrained backbone.

    Args:
        model: A Keras application constructor (the notebook imports
            ``VGG16`` and ``ResNet50``) accepting ``weights``,
            ``include_top`` and ``input_shape``.

    Returns:
        A Keras ``Model`` mapping the backbone input to a softmax over the
        global ``num_classes`` classes.
    """
    input_shape = (image_size[0], image_size[1], 3)
    backbone = model(weights='imagenet', include_top=False, input_shape=input_shape)

    # Classification head: global average pool -> 1024-unit ReLU -> softmax.
    layers = tf.keras.layers
    pooled = layers.GlobalAveragePooling2D()(backbone.output)
    hidden = layers.Dense(1024, activation='relu')(pooled)
    predictions = layers.Dense(num_classes, activation='softmax')(hidden)

    return Model(inputs=backbone.input, outputs=predictions)

In [None]:
# os.environ['WANDB_API_KEY'] = '<WANDB_API_KEY>'
train_ds = create_train_ds()
test_ds = create_test_ds()

# wandb.init(project='nom_script', name='VGG16 | CGGAN SSIM Q1')
# wandb_callback = wandb.keras.WandbCallback(
#     monitor='val_sparse_categorical_accuracy',
#     save_model=False
# )

# Labels are integer class indices, hence the sparse loss and sparse metrics.
model = create_model(VGG16)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss='sparse_categorical_crossentropy',
    metrics=[
        'sparse_categorical_accuracy',
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='sparse_top_3_categorical_accuracy'),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='sparse_top_5_categorical_accuracy'),
    ]
)

# train_ds / test_ds are already batched by the dataset pipeline, so no
# batch_size is passed here: Keras documents that it must not be specified
# for dataset inputs (it is ignored or rejected depending on the version).
history = model.fit(
    train_ds,
    epochs=100,
    validation_data=test_ds,
    verbose=1,
    # callbacks=[wandb_callback]
)

# wandb.finish()