# Initialize dataset object

In [1]:
import numpy as np, pandas as pd, os, time
import torch, torchvision

data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'
tf = 'MAX'
itime = time.time()
train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\t')
print(time.time() - itime)
val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\t')
print(time.time() - itime)

50.2965240479
58.1326179504


In [2]:
train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']
val_celltype = ['A549']
test_celltype = ['GM12878']
all_celltypes = train_celltypes + val_celltype + test_celltype

metadata_map = {}
metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']
metadata_map['celltype'] = all_celltypes

_split_dict = {
    'train': 0,
    'val-id': 1,
    'test': 2,
    'val-ood': 3
}
_split_names = {
    'train': 'Train',
    'val-id': 'Validation (ID)',
    'test': 'Test',
    'val-ood': 'Validation (OOD)',
}
_split_scheme = 'standard'

In [3]:
itime = time.time()
sequence_filename = os.path.join(data_dir, 'sequence.npz')
seq_arr = np.load(sequence_filename)
print(time.time() - itime)

itime = time.time()
_seq_bp = {}
for chrom in seq_arr:
    _seq_bp[chrom] = seq_arr[chrom]
    print(chrom, time.time() - itime)
print("Sequence read. Time: {}".format(time.time() - itime))

itime = time.time()
_dnase_allcelltypes = {}
for ct in all_celltypes:
    dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))
    dnase_npz_file = np.load(dnase_filename)
    _dnase_allcelltypes[ct] = {}
    for chrom in _seq_bp:
        _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]
    print(ct, time.time() - itime)
print("DNase read for all celltypes. Time: {}".format(time.time() - itime))

1.40137600899
('chr1', 4.365410089492798)
('chr2', 8.54686713218689)
('chr3', 11.915641069412231)
('chr4', 15.147382020950317)
('chr5', 18.221237182617188)
('chr6', 21.16081714630127)
('chr7', 23.87936806678772)
('chr8', 26.382845163345337)
('chr9', 28.802964210510254)
('chr10', 31.10539698600769)
('chr11', 33.392733097076416)
('chr12', 35.6597261428833)
('chr13', 37.56297421455383)
('chr14', 39.363978147506714)
('chr15', 41.089357137680054)
('chr16', 42.6117000579834)
('chr17', 43.9806342124939)
('chr18', 45.29493808746338)
('chr19', 46.26894497871399)
('chr20', 47.31300115585327)
('chr21', 48.139018058776855)
('chr22', 48.97876214981079)
('chrX', 51.61549210548401)
('H1-hESC', 24.14024806022644)
('HCT116', 47.97159004211426)
('HeLa-S3', 72.82926392555237)
('HepG2', 97.18733406066895)
('K562', 121.94148206710815)
('A549', 147.29550194740295)
('GM12878', 171.71312499046326)


In [10]:
# len(_dnase_allcelltypes)
all_df

NameError: name 'all_df' is not defined

In [4]:
tr_chrs = ['chr2', 'chr9', 'chr11']
te_chrs = ['chr1', 'chr8', 'chr21']
training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]
val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]
all_df = pd.concat([training_df, val_df])

#filter_msk = all_df['start'] >= 0
filter_msk = all_df['start']%1000 == 0
all_df = all_df[filter_msk]

AttributeError: 'module' object has no attribute 'isin'

In [None]:
itime = time.time()
pd_list = []
for ct in all_celltypes:
    tc_chr = all_df[['chr', 'start', 'stop', ct]]
    tc_chr.columns = ['chr', 'start', 'stop', 'y']
    tc_chr['celltype'] = ct
    pd_list.append(tc_chr)
metadata_df = pd.concat(pd_list)
print(time.time() - itime)

In [None]:
itime = time.time()
y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values
non_ambig_mask = (y_array != -1)
metadata_df['y'] = y_array
_metadata_df = metadata_df[non_ambig_mask]
_y_array = torch.LongTensor(y_array[non_ambig_mask])
print(time.time() - itime)

In [None]:
itime = time.time()
chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values
celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values
print(time.time() - itime)

In [None]:
train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)
val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)
train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)
val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)
test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)

split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)
split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']
split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']
split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']
split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']
_metadata_df['split'] = split_array
_split_array = split_array

In [13]:
from torch.utils.data import DataLoader
from data import dataset_attributes

ImportError: No module named data

In [15]:
from PIL import Image
import argparse
class ParseKwargs(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, dict())
        for value in values:
            key, value_str = value.split('=')
            if value_str.replace('-','').isnumeric():
                processed_val = int(value_str)
            elif value_str.replace('-','').replace('.','').isnumeric():
                processed_val = float(value_str)
            elif value_str in ['True', 'true']:
                processed_val = True
            elif value_str in ['False', 'false']:
                processed_val = False
            else:
                processed_val = value_str
            getattr(namespace, self.dest)[key] = processed_val

In [17]:
ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'
args_kw = "-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}".format(
    ROOTDIR).split()

parser = argparse.ArgumentParser()

# Dataset
parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)
parser.add_argument('--split_scheme', default='standard',
                    help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--root_dir', default=None, required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
parser.add_argument('--download', default=False, action='store_true',
                    help='If true, tries to downloads the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1.0,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')

# Loaders
parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')
parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs
parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs

# Model
parser.add_argument(
    '--model',
    choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],
    default='resnet50')
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
    help='keyword arguments for model initialization passed as key1=value1 key2=value2')
parser.add_argument('--train_from_scratch', action='store_true', default=False)

# Algorithm and objective
parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())
parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--groupby_fields', nargs='+', default=None)
parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default
parser.add_argument('--val_metric', default=None)

# Optimization
parser.add_argument('--n_epochs', type=int, default=4)
parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())
parser.add_argument('--lr', type=float, required=True)
parser.add_argument('--weight_decay', type=float, required=True)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--evaluate_all_splits', action='store_true', default=False)
parser.add_argument('--additional_eval_splits', nargs='+', default=[])

# Misc
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int, default=None)
parser.add_argument('--save_best', action='store_true', default=False)
parser.add_argument('--save_last', action='store_true', default=False)
parser.add_argument('--save_outputs', action='store_true', default=False)
parser.add_argument('--no_group_logging', action='store_true', default=False)
parser.add_argument('--val_metric_decreasing', action='store_true', default=False)
parser.add_argument('--use_wandb', action='store_true', default=False)
parser.add_argument('--progress_bar', action='store_true', default=False)
parser.add_argument('--resume', default=False, action='store_true')
parser.add_argument('--eval_only', default=False, action='store_true')

args = parser.parse_args(args_kw)

NameError: name 'algorithm_constructors' is not defined

# get_input (idx)

In [21]:
idx = 3
this_metadata = _metadata_df.iloc[idx, :]

itime = time.time()
flank_size = 400
interval_start = this_metadata['start'] - flank_size
interval_end = this_metadata['stop'] + flank_size
dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]
seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]
data = np.column_stack([seq_this, dnase_this])
# print(time.time() - itime)

NameError: name '_metadata_df' is not defined

In [20]:
itime = time.time()
metadata_array = torch.stack(
    (torch.LongTensor(chr_ints), 
     torch.LongTensor(celltype_ints), 
     _y_array),
    dim=1)
print(time.time() - itime)

0.028102874755859375


In [34]:
#data.shape

ModuleNotFoundError: No module named 'torch_scatter'

In [157]:
data.shape
interval_end
# itime = time.time()
# np.save(os.path.join(data_dir, 'stmp.npy'), sa)
# print(time.time() - itime)

4600

# Run training experiment

In [167]:
cmdstr = "python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy"
cmdstr += " "
cmdstr += "--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg"
cmdstr += " "
cmdstr += "--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR"
cmdstr

'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'

NameError: name '_metadata_array' is not defined

In [165]:
import os, csv
import time
import argparse
import IPython
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import sys
from collections import defaultdict
# torch.multiprocessing.set_sharing_strategy('file_system')

# TODO: Replace this once we make wilds into an installed package
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.utils import get_counts

from models.model_attributes import model_attributes
from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load
from train import train, evaluate
from data import dataset_attributes
from optimizer import optimizer_attributes
from scheduler import scheduler_attributes
from loss import losses
from utils import log_group_data
from algorithms.constructors import algorithm_constructors

ModuleNotFoundError: No module named 'torch_scatter'

In [2]:
from examples.models.model_attributes import model_attributes

In [3]:
def initialize_algorithm(args, datasets, train_grouper):
    train_dataset = datasets['train']['dataset']
    train_loader = datasets['train']['loader']

    # Configure the final layer of the networks used
    # The code below are defaults. Edit this if you need special config for your model.
    if (train_dataset.is_classification) and (train_dataset.y_size == 1):
        # For single-task classification, we have one output per class
        d_out = train_dataset.n_classes
    elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):
        # For multi-task binary classification (each output is the logit for each binary class)
        d_out = train_dataset.y_size
    elif (not train_dataset.is_classification):
        # For regression, we have one output per target dimension
        d_out = train_dataset.y_size
    else:
        raise RuntimeError('d_out not defined.')
        

    # Sanity checking input args
    if args.algorithm == 'groupDRO':
        assert args.train_loader_kwargs['uniform_over_groups']
    elif args.algorithm in ['deepCORAL', 'IRM']:
        assert args.train_loader == 'group'
        assert args.train_loader_kwargs['uniform_over_groups']
        assert args.train_loader_kwargs['distinct_groups']

    # Other config
    n_train_steps = len(train_loader) * args.n_epochs
#    prediction_fn = dataset_attributes[args.dataset]['prediction_fn']
    loss = losses[args.loss_function]
    metric = dataset_attributes[args.dataset]['metric']
    train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)
    is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0
    algorithm_constructor = algorithm_constructors[args.algorithm]
    algorithm = algorithm_constructor(
        args=args,
        d_out=d_out,
        grouper=train_grouper,
        loss=loss,
        metric=metric,
        n_train_steps=n_train_steps,
        is_group_in_train=is_group_in_train)
    return algorithm

ModuleNotFoundError: No module named 'utils'

In [None]:
def main():
    parser = argparse.ArgumentParser()

    # Dataset
    parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
    parser.add_argument('--split_scheme', default='standard',
                        help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
    parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--root_dir', default=None, required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    parser.add_argument('--download', default=False, action='store_true',
                        help='If true, tries to downloads the dataset if it does not exist in root_dir.')
    parser.add_argument('--frac', type=float, default=1.0,
                        help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')

    # Loaders
    parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')
    parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs
    parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs

    # Model
    parser.add_argument(
        '--model',
        choices=model_attributes.keys(),
        default='resnet50')
    parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
    parser.add_argument('--train_from_scratch', action='store_true', default=False)

    # Algorithm and objective
    parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())
    parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--groupby_fields', nargs='+', default=None)
    parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default
    parser.add_argument('--val_metric', default=None)

    # Optimization
    parser.add_argument('--n_epochs', type=int, default=4)
    parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--weight_decay', type=float, required=True)
    parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())
    parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})
    parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--evaluate_all_splits', action='store_true', default=False)
    parser.add_argument('--additional_eval_splits', nargs='+', default=[])

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int, default=None)
    parser.add_argument('--save_best', action='store_true', default=False)
    parser.add_argument('--save_last', action='store_true', default=False)
    parser.add_argument('--save_outputs', action='store_true', default=False)
    parser.add_argument('--no_group_logging', action='store_true', default=False)
    parser.add_argument('--val_metric_decreasing', action='store_true', default=False)
    parser.add_argument('--use_wandb', action='store_true', default=False)
    parser.add_argument('--progress_bar', action='store_true', default=False)
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--eval_only', default=False, action='store_true')

    args = parser.parse_args()

    # set device
    args.device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

    # Set defaults
    if args.groupby_fields is None:
        args.no_group_logging = True
    if args.val_metric is None:
        args.val_metric = dataset_attributes[args.dataset]['val_metric']

    ## Initialize logs
    if os.path.exists(args.log_dir) and args.resume:
        resume=True
        mode='a'
    else:
        resume=False
        mode='w'
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)

    # Record args
    log_args(args, logger)

    # Set random seed
    set_seed(args.seed)

    # Data
    full_dataset = dataset_attributes[args.dataset]['constructor'](
        root_dir=args.root_dir,
        download=args.download,
        split_scheme=args.split_scheme,
        **args.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = dataset_attributes[args.dataset]['transform'](args.model)
    if dataset_attributes[args.dataset].get('eval_transform') is None:
        eval_transform = dataset_attributes[args.dataset]['transform'](args.model)
    else:
        eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)

    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=args.groupby_fields)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        if split=='train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split,
            frac=args.frac,
            transform=transform)

        # Get loader
        shared_loader_kwargs = {
            'num_workers': args.num_workers,
            'pin_memory': not args.no_pin_memory,
            'batch_size': args.batch_size,
            'collate_fn': dataset_attributes[args.dataset]['collate']
        }

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=args.train_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                train_loader_kwargs=args.train_loader_kwargs,
                **shared_loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=args.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                **shared_loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose
        # Loggers
        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)

        if args.use_wandb:
            initialize_wandb(args)

    # Logging dataset info
    if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:
        log_grouper = CombinatorialGrouper(
            dataset=full_dataset,
            groupby_fields=['y'])
    elif args.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(args, datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(args, datasets, train_grouper)

    if not args.eval_only:
        ## Load saved results if resuming
        resume_success = False
        if resume:
            save_path = os.path.join(args.log_dir, 'last_model.pth')
            if not os.path.exists(save_path):
                epochs = [
                    int(file.split('_')[0])
                    for file in os.listdir(args.log_dir) if file.endswith('.pth')]
                if len(epochs) > 0:
                    latest_epoch = max(epochs)
                    save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
                epoch_offset = prev_epoch + 1
                logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')
                resume_success = True
            except FileNotFoundError:
                pass

        if resume_success == False:
            epoch_offset=0
            best_val_metric=None


        train(algorithm,
              datasets,
              logger,
              args,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)
    else:
        best_model_path = os.path.join(args.log_dir, 'best_model.pth')
        best_epoch, best_val_metric = load(algorithm, best_model_path)
        evaluate(algorithm, datasets, best_epoch, logger)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()

if __name__=='__main__':
    main()
