## Summary

In this notebook, we evaluate every saved model checkpoint on the validation data and select the checkpoint with the highest average accuracy.

----

## Imports

In [None]:
from pathlib import Path

## Parameters

In [None]:
# Name of this notebook; it also names the notebook's own output directory.
NOTEBOOK_NAME = "05_select_best_model"

# Absolute path of the output directory (resolved against the current working
# directory). It is created on first run and reused afterwards.
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)

# Display the resolved path as the cell output.
NOTEBOOK_PATH

In [None]:
# Identifier of the training run whose checkpoints are evaluated below.
# It is used as the sub-directory name under DATA_DIR.
#
# Other runs that can be selected instead:
#   "0007604c"  - 5-layer graph-conv with attention, batch_size=1
#   "91fc9ab9"  - 4-layer graph-conv with attention, batch_size=4
UNIQUE_ID = "191f05de"  # No attention

In [None]:
# Directory holding the training outputs: a sibling of this notebook's own
# output directory, named "protein_train".
DATA_DIR = NOTEBOOK_PATH.parent / "protein_train"

In [None]:
# Collect every "*.state" checkpoint of the selected run and order them
# numerically by the third dash-separated field of the file stem (with any
# leading/trailing "d" characters stripped). A plain lexicographic sort would
# misorder these numbers (e.g. "1000" before "200").
run_dir = DATA_DIR / UNIQUE_ID
state_files = list(run_dir.glob("*.state"))
state_files.sort(key=lambda state_path: int(state_path.stem.split("-")[2].strip("d")))

## Workflow

In [None]:
# Evaluate every checkpoint on the validation set.
#
# For each checkpoint, the network is given an input in which every position
# is replaced by the placeholder value 20, and its per-position predictions are
# compared against the true values. `avg_accuracies` collects one tuple per
# checkpoint: (checkpoint index, number parsed from the file name, mean of the
# per-sample fraction of correct predictions).
avg_accuracies = []

# The validation loader does not depend on the checkpoint, so build it once
# instead of once per checkpoint.
valid_loader = DataLoader(datasets["protein_valid"], shuffle=False, num_workers=1, batch_size=1, drop_last=False)

for state_file_idx, state_file in enumerate(state_files):
    net = Net(
        x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
    )
    # map_location lets checkpoints saved on a GPU be loaded on whatever
    # `device` this notebook is running on (including CPU-only machines).
    net.load_state_dict(torch.load(state_file, map_location=device))
    net.eval()
    net = net.to(device)

    results = []
    # Pure inference: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for data in tqdm.tqdm_notebook(valid_loader, leave=False, desc=f"{state_file_idx}"):
            data = data.to(device)
            data.y = data.x
            # Fill the whole input with the placeholder value 20 while keeping
            # the dtype of the labels. (`ones_like(y) * 20.0` silently turns an
            # integer tensor into a float tensor under PyTorch type promotion.)
            # NOTE(review): 20 looks like the "unknown" token index, given
            # x_input_size = num_features + 1 - confirm against the training code.
            x_in = torch.full_like(data.y, 20)
            # Every position is treated as missing; keep the mask on the same
            # device as the tensors it indexes.
            is_missing = torch.ones(data.y.size(0), dtype=torch.bool, device=data.y.device)
            output = net(x_in, data.edge_index, data.edge_attr)
            output = torch.softmax(output, dim=1)
            _, predicted = output.max(dim=1)
            num_correct = float((predicted[is_missing] == data.y[is_missing]).sum())
            num_total = float(is_missing.sum())
            results.append(
                {"fraction_correct": num_correct / num_total, "num_correct": num_correct, "num_total": num_total}
            )

    oneshot_results_df = pd.DataFrame(results)

    # Same file-name parsing as the sort key used to order `state_files`.
    datapoint = int(state_file.stem.split("-")[2].strip("d"))
    avg_accuracies.append((state_file_idx, datapoint, oneshot_results_df["fraction_correct"].mean()))
    print(avg_accuracies[-1])

In [None]:
# Turn the list of (index, datapoint, accuracy) tuples into one array per
# column. The checkpoint index column is not needed afterwards.
_, datapoints, accuracies = np.asarray(avg_accuracies).transpose()

In [None]:
# Plot the mean validation accuracy of each checkpoint against the number
# parsed from its file name.
fg, ax = plt.subplots()
# Optional reference line (the meaning of 0.24 is not documented here):
# ax.axhline(0.24, color='k', linestyle='--')
ax.plot(datapoints, accuracies, label="valid")
ax.legend()
ax.set_xlabel("Number of training data points")
# The curve is computed on the validation split, so label it as such
# (the previous label said "test").
ax.set_ylabel("Average validation accuracy\nwith no starting residues");

In [None]:
# Position of the checkpoint with the highest mean validation accuracy.
# Accuracies were recorded in the same order as `state_files`, so this index
# also identifies the corresponding checkpoint file.
best_model_index = np.asarray(accuracies).argmax()

best_model_index