In [None]:
"""
The canonical example of a function that can't be
learned with a simple linear model is XOR
"""
import json

import jax.numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm

from colin_net.train import Experiment
from colin_net.metrics import accuracy

# XOR truth table: one row for each combination of the two binary inputs.
inputs = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])

# One-hot labels aligned row-by-row with `inputs`:
# column 0 is set when XOR is 0, column 1 is set when XOR is 1.
targets = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])

# Experiment settings; unpacked as keyword arguments into Experiment(**config).
config = dict(
    experiment_name="xor_runs",
    # Architecture of the network to build.
    model_config=dict(
        output_dim=2,
        input_dim=2,
        hidden_dim=2,
        num_hidden=2,
        activation="tanh",
        dropout_keep=None,
    ),
    random_seed=42,
    # Optimisation settings.
    loss="mean_squared_error",
    regularization=None,
    optimizer="sgd",
    learning_rate=0.001,
    batch_size=4,
    # Run length and logging cadence (read back later as
    # experiment.global_step and experiment.log_every).
    global_step=5000,
    log_every=50,
)



In [None]:
# Render the schema reported by Experiment as 4-space-indented JSON so the
# available settings can be read at a glance.
schema_text = json.dumps(Experiment.schema(), indent=4)
print(schema_text)

In [None]:
# Build the experiment from the settings dict; every top-level key in `config`
# is passed as a keyword argument.
experiment = Experiment(**config)
# Bare last expression so the notebook displays the result of .dict()
# (presumably the experiment's resolved settings — confirm against Experiment).
experiment.dict()

In [None]:
# Ask the experiment to construct the model.
model = experiment.create_model()

# Bare expression: the notebook shows the model's repr as the cell output.
model

In [None]:
# Display the value returned by total_trainable_params()
# (presumably the count of learnable parameters in the model — confirm).
model.total_trainable_params()

In [None]:
# Wrap the result of trainable_params_by_layer() in a pandas Series so the
# notebook renders it as a labelled table. `pd` comes from the import cell at
# the top of the notebook, keeping every dependency declared in one place.
pd.Series(model.trainable_params_by_layer())

In [None]:
# Start training. experiment.train returns an iterator of update states (it is
# advanced with next() below and by a for-loop in a later cell).
# The same inputs/targets are passed twice — presumably once as the training
# set and once as the evaluation set; confirm against Experiment.train.
update_generator = experiment.train(
    inputs, targets, inputs, targets, iterator_type="batch_iterator"
)

# Pull the first update purely to inspect it. This advances the generator, so
# any later loop over update_generator resumes from the following update.
update_state = next(update_generator)
update_state

In [None]:
# Progress bar sized to the configured number of training steps.
# NOTE(review): one update was already consumed via next() in an earlier cell
# and the bar is never closed explicitly, so it may finish short of `total`.
bar = tqdm(total=experiment.global_step)

for update_state in update_generator:
    is_logging_step = update_state.step % experiment.log_every == 0

    if is_logging_step:
        # Score the current model on the four XOR rows using its eval-mode variant.
        model = update_state.model.to_eval()

        predicted = model.predict_proba(inputs)
        acc_metric = float(accuracy(targets, predicted))

        description = f"acc:{acc_metric*100}%, loss:{update_state.loss:.4f}"
        bar.set_description(description)

        # Stop early once every example is classified correctly.
        if acc_metric >= 0.99:
            print("Achieved Perfect Prediction!")
            break

        # Switch back to the training-mode variant before continuing.
        model = model.to_train()

    # Advance the bar by one for every update yielded by the generator.
    bar.update()

In [None]:
# Take the model from the last update seen by the training loop, in eval mode.
final_model = update_state.model.to_eval()

# Display Predictions: one line per example with the gold one-hot label,
# the predicted probabilities, and the arg-max class index.
probabilties = final_model.predict_proba(inputs)
predicted_classes = np.argmax(probabilties, axis=1)

for gold, prob, pred in zip(targets, probabilties, predicted_classes):
    print(gold, prob, pred)

# Overall accuracy on the four XOR rows.
accuracy_score = float(accuracy(targets, probabilties))
print("Accuracy: ", accuracy_score)