In [1]:
from train import train_cross_val
import pandas as pd
import os

# Suppress warnings raised by the sklearn metrics.
# Early in training, metrics such as the F1 score are poorly defined
# because the model may not predict any positive classes.
# NOTE: filterwarnings("ignore") with no further arguments silences every
# warning in this kernel, not only the sklearn ones.
import warnings
warnings.filterwarnings("ignore")


# About This Notebook

* This notebook provides bare-bones functionality to reproduce the results for the solubility classification models (Table X).
* It also provides the ability to tweak some hyperparameters for the training process.

In [None]:
def run(config):
    """Train ``config['model']`` with seeds 0, 1 and 2 and summarise the test metrics.

    For every seed, ``train_cross_val`` is called with a shallow copy of
    ``config`` whose ``'seed'`` entry is set to that seed, and with
    ``save_path`` set to ``Trained_Models/<model>/seed_<seed>``. The test
    metrics returned for each seed are concatenated, their column-wise mean
    and standard deviation are printed, and the concatenated table is written
    to ``./results/<model>_results.csv``.

    Args:
        config: Dict of training settings forwarded to ``train_cross_val``.
            Must contain ``'model'``, a string used to build the output
            paths. The dict passed in is not modified by this function.

    Returns:
        None. Results are printed and written to disk.
    """
    model_name = config['model']
    per_seed_metrics = []
    for seed in range(3):
        # Work on a copy so the caller's config keeps its own 'seed' value.
        seeded_config = {**config, 'seed': seed}
        save_path = os.path.join("Trained_Models", f"{model_name}", f"seed_{seed}")
        # Validation metrics are returned too, but only test metrics are aggregated.
        # Each test_metrics is presumably a pandas DataFrame (it is fed to pd.concat).
        _, test_metrics = train_cross_val(seeded_config, save_path=save_path)
        per_seed_metrics.append(test_metrics)

    result_df = pd.concat(per_seed_metrics, ignore_index=True)

    summary_df = pd.DataFrame(
        {'mean': result_df.mean(), 'std': result_df.std()}
    ).sort_index()
    print(summary_df)

    # Create the output directory first: to_csv does not create missing parent
    # directories, and failing here would discard a full training run.
    results_dir = "./results"
    os.makedirs(results_dir, exist_ok=True)
    result_df.to_csv(os.path.join(results_dir, model_name + '_results.csv'), index=False)

In [None]:
# Available models: resnet18, efficientnet, convnext (each one is trained below).

config = {
    "data_dir": "./Solubility-Data",  # directory holding the dataset
    "model": None,  # placeholder; filled in per model by the loop below
    "lr": 0.0005,
    "batch_size": 128,
    "weight_decay": 0.001,
    "center_crop": (1080, 1080),  # center crop applied before resizing
    "resize": (224, 224),  # images are resized to 224x224
    "degrees": 0,  # random rotation by 0 degrees
    "translate": (0.1, 0.1),  # random translation by 10%
    "scale_lower": 0.95,  # random scaling, lower bound: 95% of original size
    "scale_upper": 1.4,  # random scaling, upper bound: 140% of original size
    "num_epochs": 30,
    "seed": 0,
    "device": "cuda",
}


# Train every model in turn, reusing the same settings dict.
for model in ("resnet18", "efficientnet", "convnext"):
    config.update(model=model)
    run(config)

Random seed set as 0
Epoch 0/29
----------
Train Loss: 1.4130
Test Metrics: {'accuracy@4': 0.5177, 'precision@4': 0.1897, 'recall@4': 0.3144, 'F1@4': 0.2366, 'accuracy@3': 0.6028, 'precision@3': 0.2009, 'recall@3': 0.3333, 'F1@3': 0.2507, 'accuracy@2': 0.6028, 'precision@2': 0.3014, 'recall@2': 0.5, 'F1@2': 0.3761}
Epoch 1/29
----------
Train Loss: 0.7179
Test Metrics: {'accuracy@4': 0.5745, 'precision@4': 0.4175, 'recall@4': 0.52, 'F1@4': 0.4456, 'accuracy@3': 0.6454, 'precision@3': 0.4714, 'recall@3': 0.5415, 'F1@3': 0.4914, 'accuracy@2': 0.6525, 'precision@2': 0.7626, 'recall@2': 0.5655, 'F1@2': 0.5102}
Epoch 2/29
----------
Train Loss: 0.4223
Test Metrics: {'accuracy@4': 0.5532, 'precision@4': 0.4659, 'recall@4': 0.3627, 'F1@4': 0.2742, 'accuracy@3': 0.6099, 'precision@3': 0.5357, 'recall@3': 0.3407, 'F1@3': 0.2663, 'accuracy@2': 0.6099, 'precision@2': 0.8036, 'recall@2': 0.5089, 'F1@2': 0.3953}
Epoch 3/29
----------
Train Loss: 0.2831
Test Metrics: {'accuracy@4': 0.8298, 'precisio