In [1]:
from train_classifier import train_cross_val
import pandas as pd
import os

# Suppress warnings raised by sklearn metrics.
# Early in training, metrics such as the F1 score are ill-defined
# because the model may not predict any positive classes.
# Note: the filter below silences all warnings, not only sklearn's.
import warnings
warnings.filterwarnings("ignore")

# About This Notebook

* This notebook provides bare-bones functionality to reproduce the results of the solubility classification models (Table X).
* It also provides the ability to tweak some hyperparameters for the training process.

In [5]:
def run(config, seeds=(0, 1, 2), results_dir="./results"):
    """Train one model over several seeds and summarise its test metrics.

    For every seed, ``train_cross_val`` is called with a copy of ``config``
    whose ``'seed'`` entry is replaced by that seed. The test metrics of all
    runs are concatenated, their per-column mean and standard deviation are
    printed, and the raw per-run metrics are written to
    ``<results_dir>/<config['model']>_results.csv``.

    Args:
        config: Training configuration passed to ``train_cross_val``. Must
            contain a ``'model'`` key, which names the output file. The
            dict itself is not modified.
        seeds: Seeds to train with; one training run is made per seed.
        results_dir: Directory the CSV is written to; created if missing.

    Returns:
        The printed summary: a DataFrame indexed by the columns of the
        test-metrics frame, with ``'mean'`` and ``'std'`` columns.
    """
    # Create the output directory up front so a missing or unwritable
    # folder fails before training instead of after every run has finished.
    os.makedirs(results_dir, exist_ok=True)
    output_path = os.path.join(results_dir, config['model'] + '_results.csv')

    per_seed_metrics = []
    for seed in seeds:
        # Work on a copy so the caller's config keeps its own seed value.
        seed_config = {**config, 'seed': seed}
        # Only the test metrics are aggregated; validation metrics are unused.
        _, test_metrics = train_cross_val(seed_config, save_path=None)
        per_seed_metrics.append(test_metrics)

    result_df = pd.concat(per_seed_metrics, ignore_index=True)

    summary_df = pd.DataFrame({'mean': result_df.mean(), 'std': result_df.std()})
    summary_df = summary_df.sort_index()
    print(summary_df)

    result_df.to_csv(output_path, index=False)
    return summary_df

In [None]:
# Available models: resnet18, efficientnet, convnext, small_cnn

config = dict(
    data_dir="./data_wr",
    model="small_cnn",  # replaced for each model in the loop below
    lr=0.0005,
    batch_size=128,
    weight_decay=0.001,
    center_crop=(1080, 1080),  # center crop before resizing
    resize=(112, 112),  # resizing to 112x112
    degrees=0,  # random rotation by 0 degrees
    translate=(0.1, 0.1),  # random translation by 10%
    scale_lower=0.95,  # random scaling to a minimum 95% of original size
    scale_upper=1.4,  # random scaling to a maximum 140% of original size
    num_epochs=30,
    seed=0,  # run() sets the seed for each repetition
    device="cuda",
)


# To train every architecture, iterate over
# ['resnet18', 'efficientnet', 'convnext', 'small_cnn'] instead.
for model in ["small_cnn"]:
    config["model"] = model
    run(config)

Random seed set as 0
Epoch 0/29
----------
Train Loss: 1.3526
Test Metrics: {'accuracy@4': 0.1379, 'precision@4': 0.0345, 'recall@4': 0.25, 'F1@4': 0.0606, 'accuracy@3': 0.3842, 'precision@3': 0.1281, 'recall@3': 0.3333, 'F1@3': 0.1851, 'accuracy@2': 0.3842, 'precision@2': 0.1921, 'recall@2': 0.5, 'F1@2': 0.2776}
Epoch 1/29
----------
Train Loss: 1.2329
Test Metrics: {'accuracy@4': 0.133, 'precision@4': 0.0897, 'recall@4': 0.2621, 'F1@4': 0.0741, 'accuracy@3': 0.133, 'precision@3': 0.1197, 'recall@3': 0.3495, 'F1@3': 0.0988, 'accuracy@2': 0.6158, 'precision@2': 0.3079, 'recall@2': 0.5, 'F1@2': 0.3811}
Epoch 2/29
----------
Train Loss: 1.1056
Test Metrics: {'accuracy@4': 0.1084, 'precision@4': 0.0271, 'recall@4': 0.25, 'F1@4': 0.0489, 'accuracy@3': 0.1084, 'precision@3': 0.0361, 'recall@3': 0.3333, 'F1@3': 0.0652, 'accuracy@2': 0.6158, 'precision@2': 0.3079, 'recall@2': 0.5, 'F1@2': 0.3811}
Epoch 3/29
----------
Train Loss: 0.9744
Test Metrics: {'accuracy@4': 0.1084, 'precision@4': 0.02