In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np

from collections import defaultdict, Counter

from dataset import load_dataset, load_dataset_info
from experiments import CORALExperiment

## Training Data

In [None]:
# Experiment configurations, one tuple per run, with fields in this order:
#   training_dataset : directory name under ../data used for training/validation
#   test_dataset     : directory name under ../data used for evaluation; when it
#                      differs from training_dataset, evaluation is cross-dataset
#   is_technique     : True -> classify photographic technique, False -> year
DATASETS = [
    # technique classification, trained and evaluated on LOC
    ("LOC", "LOC", True),
    # year estimation, trained on LOC and evaluated on Antwerp
    ("LOC", "Antwerp", False),
    # year estimation, trained and evaluated on LOCPortrait
    ("LOCPortrait", "LOCPortrait", False),
]

In [None]:
# Photographic technique label -> integer class id, used as the classification
# target when training on technique. Class ids follow list order (0..4).
map_technique = {
    technique_name: class_id
    for class_id, technique_name in enumerate(
        [
            "ambrotypes",
            "cyanotypes",
            "dry+plate+negatives",
            "gelatin+silver+prints",
            "acetate+negatives",
        ]
    )
}

In [None]:
def get_dataset(
    dataset,
    technique=False,
    test=False,
    year_range=(1850, 1930),
    samples_per_technique=841,
    random_state=None,
):
    """Load a dataset's metadata, attach an integer ``target`` column and build generators.

    Parameters
    ----------
    dataset : str
        Name of a dataset directory under ``../data``.
    technique : bool
        If True, the target is the photographic technique encoded through
        ``map_technique``; otherwise it is the year offset from ``year_range[0]``.
    test : bool
        If True, every row is marked ``set = "train"`` and a 1% validation split is
        requested instead of 20%. NOTE(review): presumably this puts (almost) the whole
        dataset in the first returned generator for cross-dataset evaluation — confirm
        against ``load_dataset``.
    year_range : tuple[int, int]
        Half-open ``[start, end)`` interval of years to keep. ``start`` is also the
        fixed target offset, so the same year gets the same label in every dataset.
    samples_per_technique : int
        Rows sampled (without replacement) per technique when ``technique`` is True,
        to balance classes. Every technique must have at least this many rows.
    random_state : int | None
        Seed for the per-technique sampling; ``None`` leaves it unseeded.

    Returns
    -------
    tuple
        ``(load_dataset(...) result, n_classes)``. For year targets
        ``n_classes == end - start`` (one class per year), so every target lies in
        ``[0, n_classes)`` even if some years are absent from this dataset.
    """
    df = load_dataset_info(f"../data/{dataset}")
    if test:
        df["set"] = "train"

    start_year, end_year = year_range
    df = df.loc[(df["year"] >= start_year) & (df["year"] < end_year)]

    if technique:
        df = df.loc[df["technique"].isin(list(map_technique))]
        # Show per-technique counts before sampling; sampling without replacement
        # fails if any technique has fewer than `samples_per_technique` rows.
        print(Counter(df["technique"]))
        # .assign returns a new frame, avoiding chained-assignment warnings on slices.
        df = (
            df.groupby("technique")
            .sample(samples_per_technique, random_state=random_state)
            .assign(target=lambda frame: frame["technique"].map(map_technique))
        )
        n_classes = len(map_technique)
    else:
        # Offset by the fixed lower bound rather than this dataset's own minimum,
        # so train and test datasets share the same year -> label mapping.
        df = df.assign(target=df["year"] - start_year)
        n_classes = end_year - start_year

    preprocess_config = {
        "preprocessing_function": tf.keras.applications.vgg16.preprocess_input
    }

    return load_dataset(
        f"../data/{dataset}",
        df=df,
        y_col="target",
        class_mode="raw",
        validation_split=0.01 if test else 0.2,
        train_preprocess_config=preprocess_config,
        test_preprocess_config=preprocess_config,
    ), n_classes

In [None]:
# Train one CORAL experiment per DATASETS entry. When the test dataset differs
# from the training dataset, the evaluation generator is built from the test dataset.
PRETRAIN_EPOCHS = 50
FINETUNE_EPOCHS = 50

for train_dataset_name, test_dataset_name, is_technique in DATASETS:
    has_test_dataset = train_dataset_name != test_dataset_name

    (train_generator, val_generator, test_generator), n_classes = get_dataset(
        train_dataset_name, technique=is_technique
    )

    if has_test_dataset:
        # Load the *test* dataset (previously the training dataset was reloaded here,
        # so cross-dataset runs silently evaluated on the training data).
        # get_dataset(test=True) marks all rows "train" with a 1% validation split;
        # presumably the first generator then holds ~all rows — confirm in load_dataset.
        (test_generator, _, _), _ = get_dataset(
            test_dataset_name, test=True, technique=is_technique
        )

    experiment = CORALExperiment(
        # NOTE: non-technique names end with a trailing "_"; kept unchanged so
        # existing experiment/result names stay stable.
        name=f"{train_dataset_name}_{test_dataset_name}_CORAL_Classification_{'technique' if is_technique else ''}",
        n_classes=n_classes,
    )

    experiment.run(
        train_generator, val_generator, test_generator,
        pretrain_epochs=PRETRAIN_EPOCHS,
        finetune_epochs=FINETUNE_EPOCHS,
    )