# Initialize dataset object

In [1]:
import numpy as np, pandas as pd, os, time, torch, torchvision
data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'
tf = 'MAX'
itime = time.time()
train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\t')
print(time.time() - itime)
val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\t')
print(time.time() - itime)

57.8772239685
66.8270189762


In [2]:
train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']
val_celltype = ['A549']
test_celltype = ['GM12878']
all_celltypes = train_celltypes + val_celltype + test_celltype

metadata_map = {}
metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']
metadata_map['celltype'] = all_celltypes

_split_dict = {
    'train': 0,
    'val-id': 1,
    'test': 2,
    'val-ood': 3
}
_split_names = {
    'train': 'Train',
    'val-id': 'Validation (ID)',
    'test': 'Test',
    'val-ood': 'Validation (OOD)'
}
_split_scheme = 'standard'

In [4]:
itime = time.time()
sequence_filename = os.path.join(data_dir, 'sequence.npz')
seq_arr = np.load(sequence_filename)
print(time.time() - itime)

itime = time.time()
_seq_bp = {}
for chrom in seq_arr:
    _seq_bp[chrom] = seq_arr[chrom]
    print(chrom, time.time() - itime)
itime = time.time()
_dnase_allcelltypes = {}
for ct in all_celltypes:
    dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))
    dnase_npz_file = np.load(dnase_filename)
    _dnase_allcelltypes[ct] = {}
    for chrom in _seq_bp:
        _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]
    print(ct, time.time() - itime)

('H1-hESC', 25.299736976623535)
('HCT116', 49.68733310699463)
('HeLa-S3', 74.65905213356018)
('HepG2', 99.33112812042236)
('K562', 124.1327919960022)
('A549', 149.19999814033508)
('GM12878', 174.0277030467987)


In [None]:
class Beagle2(nn.Module):
    """
    Neural net models over genomic sequence.
    Input:
        - sequence_length: int (default 1000) 
        - Shape: (N, 5, sequence_length, 1) with batch size N.
    
    Output:
        - prediction (Tensor): float torch tensor of shape (N, )
    
    TODO: Finish docstring.
    """
    def __init__(self):
        """
        Parameters
        ----------
        sequence_length : int
        n_genomic_features : int
        """
        super(Beagle2, self).__init__()

        self.dropout = 0.3
        self.num_cell_types = 1
        self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))
        self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))
        self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))
        self.bn1 = nn.BatchNorm2d(300)
        self.bn2 = nn.BatchNorm2d(200)
        self.bn3 = nn.BatchNorm2d(200)
        self.maxpool1 = nn.MaxPool2d((3, 1))
        self.maxpool2 = nn.MaxPool2d((4, 1))
        self.maxpool3 = nn.MaxPool2d((4, 1))

        self.fc1 = nn.Linear(4200, 1000)
        self.bn4 = nn.BatchNorm1d(1000)

        self.fc2 = nn.Linear(1000, 1000)
        self.bn5 = nn.BatchNorm1d(1000)

        self.fc3 = nn.Linear(1000, self.num_cell_types)

    def forward(self, s):
        s = s.permute(0, 2, 1).contiguous()                          # batch_size x 4 x 1000
        s = s.view(-1, 5, 1000, 1)                                   # batch_size x 4 x 1000 x 1 [4 channels]
        s = self.maxpool1(F.relu(self.bn1(self.conv1(s))))           # batch_size x 300 x 333 x 1
        s = self.maxpool2(F.relu(self.bn2(self.conv2(s))))           # batch_size x 200 x 83 x 1
        s = self.maxpool3(F.relu(self.bn3(self.conv3(s))))           # batch_size x 200 x 21 x 1
        s = s.view(-1, 4200)
        conv_out = s

        s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training)  # batch_size x 1000
        #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training)  # batch_size x 1000
        
        
        s = self.fc3(s)

        return s, conv_out


class DanQ(nn.Module):
    def __init__(self, sequence_length, n_genomic_features):
        """
        Parameters
        ----------
        sequence_length : int
            Input sequence length
        n_genomic_features : int
            Total number of features to predict
        """
        super(DanQ, self).__init__()
        self.nnet = nn.Sequential(
            nn.Conv1d(4, 320, kernel_size=26),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(
                kernel_size=13, stride=13),
            nn.Dropout(0.2))

        self.bdlstm = nn.Sequential(
            nn.LSTM(
                320, 320, num_layers=1, batch_first=True, bidirectional=True))

        self._n_channels = math.floor(
            (sequence_length - 25) / 13)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self._n_channels * 640, 925),
            nn.ReLU(inplace=True),
            nn.Linear(925, n_genomic_features),
            nn.Sigmoid())

    def forward(self, x):
        """Forward propagation of a batch.
        """
        out = self.nnet(x)
        reshape_out = out.transpose(0, 1).transpose(0, 2)
        out, _ = self.bdlstm(reshape_out)
        out = out.transpose(0, 1)
        reshape_out = out.contiguous().view(
            out.size(0), 640 * self._n_channels)
        predict = self.classifier(reshape_out)
        return predict


class DeepSEA(nn.Module):
    def __init__(self, sequence_length, n_genomic_features):
        """
        Parameters
        ----------
        sequence_length : int
        n_genomic_features : int
        """
        super(DeepSEA, self).__init__()
        conv_kernel_size = 8
        pool_kernel_size = 4

        self.conv_net = nn.Sequential(
            nn.Conv1d(4, 320, kernel_size=conv_kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(
                kernel_size=pool_kernel_size, stride=pool_kernel_size),
            nn.Dropout(p=0.2),

            nn.Conv1d(320, 480, kernel_size=conv_kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(
                kernel_size=pool_kernel_size, stride=pool_kernel_size),
            nn.Dropout(p=0.2),

            nn.Conv1d(480, 960, kernel_size=conv_kernel_size),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5))

        reduce_by = conv_kernel_size - 1
        pool_kernel_size = float(pool_kernel_size)
        self.n_channels = int(
            np.floor(
                (np.floor(
                    (sequence_length - reduce_by) / pool_kernel_size)
                 - reduce_by) / pool_kernel_size)
            - reduce_by)
        self.classifier = nn.Sequential(
            nn.Linear(960 * self.n_channels, n_genomic_features),
            nn.ReLU(inplace=True),
            nn.Linear(n_genomic_features, n_genomic_features),
            nn.Sigmoid())

    def forward(self, x):
        """Forward propagation of a batch.
        """
        out = self.conv_net(x)
        reshape_out = out.view(out.size(0), 960 * self.n_channels)
        predict = self.classifier(reshape_out)
        return predict

In [78]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Beagle(nn.Module):
    """
    Neural net models over genomic sequence.
    Input:
        - sequence_length: int (default 1000) 
        - Shape: (N, 5, sequence_length, 1) with batch size N.
    
    Output:
        - prediction (Tensor): float torch tensor of shape (N, )
    
    TODO: Finish docstring.
    """
    def __init__(self):
        """
        Parameters
        ----------
        sequence_length : int
        n_genomic_features : int
        """
        super(Beagle, self).__init__()

        self.dropout = 0.3
        self.num_cell_types = 1
        self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))
        self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))
        self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))
        self.bn1 = nn.BatchNorm2d(300)
        self.bn2 = nn.BatchNorm2d(200)
        self.bn3 = nn.BatchNorm2d(200)
        self.maxpool1 = nn.MaxPool2d((3, 1))
        self.maxpool2 = nn.MaxPool2d((4, 1))
        self.maxpool3 = nn.MaxPool2d((4, 1))

        self.fc1 = nn.Linear(4200, 1000)
        self.bn4 = nn.BatchNorm1d(1000)

        self.fc2 = nn.Linear(1000, 1000)
        self.bn5 = nn.BatchNorm1d(1000)

        self.fc3 = nn.Linear(1000, self.num_cell_types)

    def forward(self, s):
        s = s.permute(0, 2, 1).contiguous()                          # batch_size x 5 x 1000
        s = s.view(-1, 5, 1000, 1)                                   # batch_size x 5 x 1000 x 1 [5 channels]
        s = self.maxpool1(F.relu(self.bn1(self.conv1(s))))           # batch_size x 300 x 333 x 1
        s = self.maxpool2(F.relu(self.bn2(self.conv2(s))))           # batch_size x 200 x 83 x 1
        s = self.maxpool3(F.relu(self.bn3(self.conv3(s))))           # batch_size x 200 x 21 x 1
        s = s.view(-1, 4200)
        conv_out = s

        s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training)  # batch_size x 1000
        s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training)  # batch_size x 1000
        
        s = self.fc3(s)

        return s, conv_out

In [86]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = Beagle2()
model = DanQ(50, 5)

lst = [(x[0], x[1].numel()) for x in model.named_parameters()]
#np.sum([x[1] for x in lst])
count_parameters(model)
lst

[('nnet.0.weight', 33280),
 ('nnet.0.bias', 320),
 ('bdlstm.0.weight_ih_l0', 409600),
 ('bdlstm.0.weight_hh_l0', 409600),
 ('bdlstm.0.bias_ih_l0', 1280),
 ('bdlstm.0.bias_hh_l0', 1280),
 ('bdlstm.0.weight_ih_l0_reverse', 409600),
 ('bdlstm.0.weight_hh_l0_reverse', 409600),
 ('bdlstm.0.bias_ih_l0_reverse', 1280),
 ('bdlstm.0.bias_hh_l0_reverse', 1280),
 ('classifier.1.weight', 592000),
 ('classifier.1.bias', 925),
 ('classifier.3.weight', 4625),
 ('classifier.3.bias', 5)]

In [48]:
tr_chrs = ['chr2', 'chr9', 'chr11']
te_chrs = ['chr1', 'chr8', 'chr21']
training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]
val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]
all_df = pd.concat([training_df, val_df])

#filter_msk = all_df['start'] >= 0
filter_msk = all_df['start']%1000 == 0
all_df = all_df[filter_msk]

AttributeError: 'module' object has no attribute 'isin'

In [49]:
print(np.__version__)

1.12.1


In [30]:
itime = time.time()
pd_list = []
for ct in all_celltypes:
    tc_chr = all_df[['chr', 'start', 'stop', ct]]
    tc_chr.columns = ['chr', 'start', 'stop', 'y']
    tc_chr['celltype'] = ct
    pd_list.append(tc_chr)
metadata_df = pd.concat(pd_list)
print(time.time() - itime)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  


1.659163236618042


In [31]:
itime = time.time()
y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values
non_ambig_mask = (y_array != -1)
metadata_df['y'] = y_array
_metadata_df = metadata_df[non_ambig_mask]
_y_array = torch.LongTensor(y_array[non_ambig_mask])
print(time.time() - itime)

3.0391879081726074


In [19]:
itime = time.time()
chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values
celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values
print(time.time() - itime)

12.390011310577393


In [53]:
train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)
val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)
train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)
val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)
test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)

split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)
split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']
split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']
split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']
split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']
_metadata_df['split'] = split_array
_split_array = split_array

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  if sys.path[0] == '':


# get_input (idx)

In [153]:
idx = 3
this_metadata = _metadata_df.iloc[idx, :]

itime = time.time()
flank_size = 400
interval_start = this_metadata['start'] - flank_size
interval_end = this_metadata['stop'] + flank_size
dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]
seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]
data = np.column_stack([seq_this, dnase_this])
# print(time.time() - itime)

In [154]:
data.shape
interval_end
# itime = time.time()
# np.save(os.path.join(data_dir, 'stmp.npy'), sa)
# print(time.time() - itime)

4600

In [78]:
itime = time.time()
metadata_array = torch.stack(
    (torch.LongTensor(chr_ints), 
     torch.LongTensor(celltype_ints), 
     _y_array),
    dim=1)
print(time.time() - itime)

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

In [156]:
_metadata_array

NameError: name '_metadata_array' is not defined

In [2]:
from examples.models.model_attributes import model_attributes

In [3]:
import os, csv
import time
import argparse
import IPython
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import sys
from collections import defaultdict

# TODO: Replace this once we make wilds into an installed package
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.utils import get_counts

from examples.models.model_attributes import model_attributes
from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load
from examples.train import train
from examples.data import dataset_attributes
from examples.optimizer import optimizer_attributes
from examples.scheduler import scheduler_attributes
from examples.loss import losses
from examples.utils import log_group_data
from examples.algorithms.constructors import algorithm_constructors


def initialize_algorithm(args, datasets, train_grouper):
    train_dataset = datasets['train']['dataset']
    train_loader = datasets['train']['loader']

    # Configure the final layer of the networks used
    # The code below are defaults. Edit this if you need special config for your model.
    if (train_dataset.is_classification) and (train_dataset.y_size == 1):
        # For single-task classification, we have one output per class
        d_out = train_dataset.n_classes
    elif (not train_dataset.is_classification):
        # For regression, we have one output per target dimension
        d_out = train_dataset.y_size
    else:
        # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB
        pass

    # Sanity checking input args
    if args.algorithm == 'groupDRO':
        assert args.train_loader_kwargs['uniform_over_groups']
    elif args.algorithm in ['deepCORAL', 'IRM']:
        assert args.train_loader == 'group'
        assert args.train_loader_kwargs['uniform_over_groups']
        assert args.train_loader_kwargs['distinct_groups']

    # Other config
    n_train_steps = len(train_loader) * args.n_epochs
    prediction_fn = dataset_attributes[args.dataset]['prediction_fn']
    loss = losses[args.loss_function]
    metric_constructor = dataset_attributes[args.dataset]['metric']
    train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)
    is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0
    algorithm_constructor = algorithm_constructors[args.algorithm]
    algorithm = algorithm_constructor(
        args=args,
        d_out=d_out,
        grouper=train_grouper,
        prediction_fn=prediction_fn,
        loss=loss,
        metric_constructor=metric_constructor,
        n_train_steps=n_train_steps,
        is_group_in_train=is_group_in_train)
    return algorithm

ModuleNotFoundError: No module named 'utils'

In [None]:
parser = argparse.ArgumentParser()

# Dataset
parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
parser.add_argument('--split_scheme', default='standard',
                    help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
parser.add_argument('--root_dir', default=None, required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
parser.add_argument('--download', default=False, action='store_true',
                    help='If true, tries to downloads the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1.0,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')

# Loaders
parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')
parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')
parser.add_argument('--batch_size', type=int, default=32)

# Model
parser.add_argument(
    '--model',
    choices=model_attributes.keys(),
    default='resnet50')
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
    help='keyword arguments for model initialization passed as key1=value1 key2=value2')
parser.add_argument('--train_from_scratch', action='store_true', default=False)

# Algorithm and objective
parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())
parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--groupby_fields', nargs='+', default=None)
parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default
parser.add_argument('--val_metric', default=None)

# Optimization
parser.add_argument('--n_epochs', type=int, default=4)
parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())
parser.add_argument('--lr', type=float, required=True)
parser.add_argument('--weight_decay', type=float, required=True)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--evaluate_all_splits', action='store_true', default=False)
parser.add_argument('--additional_eval_splits', nargs='+', default=[])

# Misc
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int, default=None)
parser.add_argument('--save_best', action='store_true', default=False)
parser.add_argument('--save_last', action='store_true', default=False)
parser.add_argument('--save_outputs', action='store_true', default=False)
parser.add_argument('--no_group_logging', action='store_true', default=False)

parser.add_argument('--resume', default=False, action='store_true')

args = parser.parse_args()

# Set defaults
if args.groupby_fields is None:
    args.no_group_logging = True
if args.val_metric is None:
    args.val_metric = dataset_attributes[args.dataset]['val_metric']

## Initialize logs
if os.path.exists(args.log_dir) and args.resume:
    resume=True
    mode='a'
else:
    resume=False
    mode='w'
if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)
logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)

# Record args
log_args(args, logger)

# Set random seed
set_seed(args.seed)

# Data
full_dataset = dataset_attributes[args.dataset]['constructor'](
    root_dir=args.root_dir,
    download=args.download,
    split_scheme=args.split_scheme)

# To implement data augmentation (i.e., have different transforms
# at training time vs. test time), modify these two lines:
train_transform = dataset_attributes[args.dataset]['transform'](args.model)
eval_transform = dataset_attributes[args.dataset]['transform'](args.model)

train_grouper = CombinatorialGrouper(
    dataset=full_dataset,
    groupby_fields=args.groupby_fields)

datasets = defaultdict(dict)
for split in full_dataset.split_dict.keys():
    if split=='train':
        transform = train_transform
        verbose = True
    elif split == 'val':
        transform = eval_transform
        verbose = True
    else:
        transform = eval_transform
        verbose = False
    # Get subset
    datasets[split]['dataset'] = full_dataset.get_subset(
        split,
        frac=args.frac,
        transform=transform)

    # Get loader
    shared_loader_kwargs = {
        'num_workers': 4,
        'pin_memory': True,
        'batch_size': args.batch_size,
        'collate_fn': dataset_attributes[args.dataset]['collate']
    }

    if split == 'train':
        datasets[split]['loader'] = get_train_loader(
            loader=args.train_loader,
            dataset=datasets[split]['dataset'],
            grouper=train_grouper,
            train_loader_kwargs=args.train_loader_kwargs,
            **shared_loader_kwargs)
    else:
        datasets[split]['loader'] = get_eval_loader(
            loader=args.eval_loader,
            dataset=datasets[split]['dataset'],
            grouper=train_grouper,
            **shared_loader_kwargs)

    # Set fields
    datasets[split]['split'] = split
    datasets[split]['name'] = full_dataset.split_names[split]
    datasets[split]['verbose'] = verbose
    # Loggers
    datasets[split]['eval_logger'] = CSVBatchLogger(
        os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)
    datasets[split]['algo_logger'] = CSVBatchLogger(
        os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)

# Logging dataset info
if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:
    log_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['y'])
elif args.no_group_logging:
    log_grouper = None
else:
    log_grouper = train_grouper
log_group_data(args, datasets, log_grouper, logger)

## Initialize algorithm
algorithm = initialize_algorithm(args, datasets, train_grouper)

## Load saved results if resuming
if resume:
    save_path = os.path.join(args.log_dir, 'last_model.pth')
    prev_epoch, best_val_metric = load(algorithm, save_path)
    epoch_offset = prev_epoch + 1
else:
    epoch_offset=0
    best_val_metric=None

train(algorithm,
      datasets,
      logger,
      args,
      epoch_offset=epoch_offset,
      best_val_metric=best_val_metric)

logger.close()
for split in datasets:
    datasets[split]['eval_logger'].close()
    datasets[split]['algo_logger'].close()