The following patch allows training without a test set, i.e., for cases where the test set is kept separately:
```diff
diff --git a/alignn/data.py b/alignn/data.py
index 175b915..e70bebc 100644
--- a/alignn/data.py
+++ b/alignn/data.py
@@ -171,8 +171,9 @@ def get_id_train_val_test(
     # full train/val test split
     # ids = ids[::-1]
     id_train = ids[:n_train]
-    id_val = ids[-(n_val + n_test) : -n_test]  # noqa:E203
-    id_test = ids[-n_test:]
+    id_val = ids[-(n_val + n_test) : -n_test] if n_test > 0 else ids[-(n_val + n_test) :]  # noqa:E203
+    id_test = ids[-n_test:] if n_test > 0 else []
+
     return id_train, id_val, id_test
 
 
@@ -508,7 +509,7 @@ def get_train_val_loaders(
         classification=classification_threshold is not None,
         output_dir=output_dir,
         tmp_name="test_data",
-    )
+    ) if len(dataset_test) > 0 else None
 
     collate_fn = train_data.collate
     # print("line_graph,line_dih_graph", line_graph, line_dih_graph)
@@ -528,7 +529,7 @@ def get_train_val_loaders(
 
     val_loader = DataLoader(
         val_data,
-        batch_size=batch_size,
+        batch_size=1,
         shuffle=False,
         collate_fn=collate_fn,
         drop_last=True,
```
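The guard is needed because in Python `-0 == 0`: with `n_test = 0`, the unpatched slices become `ids[-n_val:0]` (always empty) for validation and `ids[0:]` (the entire list) for testing. Below is a minimal sketch of the guarded split logic, written as a hypothetical standalone helper `split_ids` rather than alignn's actual `get_id_train_val_test` (its real signature, shuffling and file handling are omitted). The extra `n_val > 0` check is an assumption added here, not part of the patch, so that `n_val = 0` does not hit the same `-0` pitfall.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_ids(
    ids: Sequence[T], n_train: int, n_val: int, n_test: int
) -> Tuple[List[T], List[T], List[T]]:
    """Hypothetical helper mirroring the patched slicing.

    Train ids are taken from the front of the list; val and test ids
    from the back. With n_test == 0 the test split is empty and val
    takes the last n_val ids.
    """
    ids = list(ids)
    id_train = ids[:n_train]
    if n_test > 0:
        # Val is the block immediately before the last n_test ids.
        id_val = ids[-(n_val + n_test) : -n_test]
        id_test = ids[-n_test:]
    else:
        # Avoid ids[-0:], which would return the whole list.
        id_val = ids[-n_val:] if n_val > 0 else []
        id_test = []
    return id_train, id_val, id_test
```

For example, `split_ids(list(range(10)), 8, 2, 0)` returns `([0, 1, 2, 3, 4, 5, 6, 7], [8, 9], [])`, whereas the unpatched slicing would have produced an empty validation set and a test set containing every id.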