# run_expt.py contents

## 1) Preamble

In [2]:
# Print this process's resident set size (RSS), converted from bytes to MiB.
import os, psutil; print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)

396.69921875




In [1]:
# Micro-benchmark of a pyBigWig range query (chr1, positions 10000-22800).
# NOTE: `bw` is not bound anywhere in this cell, so running it as-is raises
# NameError; a bigwig handle must be opened and assigned to `bw` beforehand.
# import pyBigWig
# %timeit bw = pyBigWig.open("/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig")
%timeit bw.values('chr1', 10000, 22800, numpy=True)

NameError: name 'bw' is not defined

In [1]:
import os, csv, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import time
import argparse
import numpy as np, pandas as pd
import torch
import torch.nn as nn
import torchvision
import pyBigWig
from collections import defaultdict

from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper

from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool
from train import train, evaluate
from algorithms.initializer import initialize_algorithm
from transforms import initialize_transform
from configs.utils import populate_defaults
import configs.supported as supported



In [2]:
''' set default hyperparams in default_hyperparams.py '''
# Command-line interface for an experiment run. Only --dataset, --algorithm and
# --root_dir are required. Most other options have no default here (they parse
# to None); presumably they are filled in afterwards by `populate_defaults` --
# confirm in configs.utils.
#
# Boolean options use `type=parse_bool, const=True, nargs='?'`: the flag may be
# given bare (argparse then stores `const`, i.e. True) or followed by an
# explicit value that is converted by `parse_bool`.
parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)
parser.add_argument('--algorithm', required=True, choices=supported.algorithms)
parser.add_argument('--root_dir', required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')

# Dataset
parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
# `*_kwargs` options collect zero or more tokens through the ParseKwargs action
# (key1=value1 key2=value2, per the --model_kwargs help text).
parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, tries to downloads the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1.0,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')

# Loaders
parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--train_loader', choices=['standard', 'group'])
parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')
parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')
parser.add_argument('--n_groups_per_batch', type=int)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--eval_loader', choices=['standard'], default='standard')

# Model
parser.add_argument('--model', choices=supported.models)
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
    help='keyword arguments for model initialization passed as key1=value1 key2=value2')

# Transforms
parser.add_argument('--train_transform', choices=supported.transforms)
parser.add_argument('--eval_transform', choices=supported.transforms)
parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.')
parser.add_argument('--resize_scale', type=float)
parser.add_argument('--max_token_length', type=int)

# Objective
parser.add_argument('--loss_function', choices = supported.losses)

# Algorithm
parser.add_argument('--groupby_fields', nargs='+')
parser.add_argument('--group_dro_step_size', type=float)
parser.add_argument('--coral_penalty_weight', type=float)
parser.add_argument('--irm_lambda', type=float)
parser.add_argument('--irm_penalty_anneal_iters', type=int)
parser.add_argument('--algo_log_metric')

# Model selection
parser.add_argument('--val_metric')
parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')

# Optimization
parser.add_argument('--n_epochs', type=int)
parser.add_argument('--optimizer', choices=supported.optimizers)
parser.add_argument('--lr', type=float)
parser.add_argument('--weight_decay', type=float)
parser.add_argument('--max_grad_norm', type=float)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})

# Scheduler
parser.add_argument('--scheduler', choices=supported.schedulers)
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--eval_splits', nargs='+', default=[])
parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--eval_epoch', default=None, type=int)

# Misc
# --device is an integer index here; it is turned into a torch.device later.
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int)
parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')
parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)

_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=<function parse_bool at 0x7fd748398700>, choices=None, help=None, metavar=None)

In [3]:
# Build one config per dataset: parse a minimal CLI-style argument string, then
# pass the resulting namespace through `populate_defaults`.
argstr_camelyon = "--dataset camelyon17 --algorithm ERM --root_dir data"
config_camelyon = parser.parse_args(argstr_camelyon.split())
config_camelyon = populate_defaults(config_camelyon)

argstr_bdd100k = "--dataset bdd100k --algorithm ERM --root_dir data"
config_bdd100k = parser.parse_args(argstr_bdd100k.split())
config_bdd100k = populate_defaults(config_bdd100k)

argstr_encode = "--dataset encode-tfbs --algorithm ERM --root_dir data"
config_encode = parser.parse_args(argstr_encode.split())
config_encode = populate_defaults(config_encode)

# Each assignment overwrites the previous one, so only the last line takes
# effect: the active `config` is the bdd100k config.
config = config_camelyon
config = config_encode
config = config_bdd100k


In [4]:
# Re-parse the camelyon17 and encode-tfbs argument strings WITHOUT calling
# `populate_defaults`, rebinding `config_camelyon` / `config_encode` to the raw
# namespaces (options not given on the command line stay at their parser
# defaults, mostly None). `config` is not reassigned in this cell.
argstr_camelyon = "--dataset camelyon17 --algorithm ERM --root_dir data"
# argstr_camelyon = "--dataset civilcomments --algorithm ERM --root_dir data"
config_camelyon = parser.parse_args(argstr_camelyon.split())

argstr_encode = "--dataset encode-tfbs --algorithm ERM --root_dir data"
config_encode = parser.parse_args(argstr_encode.split())
# Bare expression: displays the raw encode-tfbs namespace as the cell output.
config_encode

Namespace(algo_log_metric=None, algorithm='ERM', batch_size=None, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function=None, lr=None, max_grad_norm=None, max_token_length=None, model=None, model_kwargs={'pretrained': False}, n_epochs=None, n_groups_per_batch=None, no_group_logging=None, optimizer=None, optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme=None, target_resolution=None, train_

In [5]:
# Display the active config; the commented-out suffix would instead reset
# its optimizer_kwargs to an empty dict.
config#.optimizer_kwargs = {}

Namespace(algo_log_metric='multitask_accuracy', algorithm='ERM', batch_size=32, coral_penalty_weight=None, dataset='bdd100k', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform='image_base', evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='resnet50', model_kwargs={'pretrained': False}, n_epochs=10, n_groups_per_batch=4, no_group_logging=True, optimizer='SGD', optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='off

In [6]:
# set device
# `config.device` starts out as an integer index from the parser. The guard
# keeps this cell re-runnable: without it, a second run would try to build
# "cuda:cuda:0" from the already-converted torch.device.
if not isinstance(config.device, torch.device):
    config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu")

## Initialize logs
# Append to existing logs when resuming, or when evaluating into an existing
# log directory; otherwise start the log files from scratch. `resume` is only
# True when there is an existing log directory to resume from.
log_dir_exists = os.path.exists(config.log_dir)
if log_dir_exists and config.resume:
    resume=True
    mode='a'
elif log_dir_exists and config.eval_only:
    resume=False
    mode='a'
else:
    resume=False
    mode='w'

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(config.log_dir, exist_ok=True)
logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

# Record config
log_config(config, logger)

# Set random seed
set_seed(config.seed)

# Data
# Look up the dataset constructor registered under the configured name and
# instantiate it with the configured root directory / split scheme.
full_dataset = supported.datasets[config.dataset](
    root_dir=config.root_dir,
    download=config.download,
    split_scheme=config.split_scheme,
    **config.dataset_kwargs)

# To implement data augmentation (i.e., have different transforms
# at training time vs. test time), modify these two lines:
train_transform = initialize_transform(
    transform_name=config.train_transform,
    config=config,
    dataset=full_dataset)
eval_transform = initialize_transform(
    transform_name=config.eval_transform,
    config=config,
    dataset=full_dataset)

Dataset: bdd100k
Algorithm: ERM
Root dir: data
Split scheme: official
Dataset kwargs: {}
Download: False
Frac: 1.0
Loader kwargs: {'num_workers': 1, 'pin_memory': True}
Train loader: standard
Uniform over groups: False
Distinct groups: None
N groups per batch: 4
Batch size: 32
Eval loader: standard
Model: resnet50
Model kwargs: {'pretrained': False}
Train transform: image_base
Eval transform: image_base
Target resolution: (224, 224)
Resize scale: None
Max token length: None
Loss function: multitask_bce
Groupby fields: None
Group dro step size: None
Coral penalty weight: None
Irm lambda: None
Irm penalty anneal iters: None
Algo log metric: multitask_accuracy
Val metric: acc_all
Val metric decreasing: False
N epochs: 10
Optimizer: SGD
Lr: 0.001
Weight decay: 0.0001
Max grad norm: None
Optimizer kwargs: {'momentum': 0.9}
Scheduler: None
Scheduler kwargs: {}
Scheduler metric split: val
Scheduler metric name: None
Evaluate all splits: True
Eval splits: []
Eval only: False
Eval epoch: None
D

## 2) Initialize dataset object (trial version)

In [7]:
import os, time
import torch
import pandas as pd
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy

# Scratch (module-level) version of the dataset constructor: the same steps as
# EncodeTFBSDataset.__init__, but writing notebook globals instead of
# attributes. `pyBigWig` comes from the import cell at the top of the notebook.
root_dir='data'
download=False
split_scheme='official'

itime = time.time()  # start time for the progress printouts below
_dataset_name = 'encode-tfbs'
_version = '1.0'
_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'
# NOTE: the trailing slash matters -- the metadata path below is built by plain
# string concatenation with this value.
_data_dir = 'data/encode-tfbs_v1.0/'
_y_size = 1
_n_classes = 2

# Chromosomes and cell types assigned to each split.
_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']
_val_chroms = ['chr2', 'chr9', 'chr11']
_test_chroms = ['chr1', 'chr8', 'chr21']
_transcription_factor = 'MAX'
_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']
_val_celltype = ['A549']
_test_celltype = ['GM12878']
_all_chroms = _train_chroms + _val_chroms + _test_chroms
_all_celltypes = _train_celltypes + _val_celltype + _test_celltype

# Position in these lists defines the integer code of each metadata value.
_metadata_map = {}
_metadata_map['chr'] = _all_chroms
_metadata_map['celltype'] = _all_celltypes

# Get the splits
# 'official' is accepted as an alias for 'standard'.
if split_scheme=='official':
    split_scheme = 'standard'

_split_scheme = split_scheme
_split_dict = {
    'train': 0,
    'id_val': 1,
    'test': 2,
    'val': 3
}
_split_names = {
    'train': 'Train',
    'id_val': 'Validation (ID)',
    'test': 'Test',
    'val': 'Validation (OOD)',
}

# Load sequence and DNase features
# Copy each chromosome's array out of the .npz archive, printing the elapsed
# time after each one.
sequence_filename = os.path.join(_data_dir, 'sequence.npz')
seq_arr = np.load(sequence_filename)
_seq_bp = {}
for chrom in _all_chroms:
    _seq_bp[chrom] = seq_arr[chrom]
    print(chrom, time.time() - itime)

# Open one bigwig handle for the 'avg' track and one per cell type.
_dnase_allcelltypes = {}
ct = 'avg'
dnase_avg_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))
_dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)
for ct in _all_celltypes:
    # The string literal below is a no-op expression statement that keeps the
    # earlier .npz-based DNase loading code for reference.
    """
    dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))
    dnase_npz_contents = np.load(dnase_filename)
    self._dnase_allcelltypes[ct] = {}
    for chrom in self._all_chroms: #self._seq_bp:
        self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]
    """
    dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))
    _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)
    print(ct, time.time() - itime)

# Headerless tab-separated file with four columns: chr, start, stop, celltype.
_metadata_df = pd.read_csv(
    _data_dir + 'labels/MAX/metadata_df.bed', sep='\t', header=None, 
    index_col=None, names=['chr', 'start', 'stop', 'celltype']
)

chr3 3.0039219856262207
chr4 5.89985990524292
chr5 8.640583038330078
chr6 11.237342596054077
chr7 13.666043519973755
chr10 15.858035326004028
chr12 17.94972252845764
chr13 19.689449071884155
chr14 21.30842876434326
chr15 22.856398582458496


KeyboardInterrupt: 

In [None]:
# Assign every metadata row to a split from its (chromosome, cell type) pair,
# drop rows that belong to no split, and build the label / metadata tensors.
train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)
val_regions_mask = np.isin(_metadata_df['chr'], _val_chroms)
test_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)
train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)
val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)
test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)

# -1 marks rows that fall into none of the splits.
split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)
split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']
split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = _split_dict['test']
# Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')
split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']
split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']

if _split_scheme=='standard':
    # Append the split ids as a new last column named 'split'.
    _metadata_df.insert(len(_metadata_df.columns), 'split', split_array)
else:
    raise ValueError(f'Split scheme {_split_scheme} not recognized')

# Keep the mask from before filtering so the labels can be filtered the same way.
metadata_mask = (_metadata_df['split'] != -1)
_metadata_df = _metadata_df[_metadata_df['split'] != -1]

# Encode chromosome / cell-type names as their positions in _metadata_map.
chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values
celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values
_split_array = _metadata_df['split'].values

# NOTE(review): assumes metadata_y.npy has one entry per row of the unfiltered
# metadata file, in the same order -- confirm. The tensor is indexed with a
# pandas boolean Series here.
_y_array = torch.Tensor(np.load(_data_dir + 'labels/MAX/metadata_y.npy'))
_y_array = _y_array[metadata_mask]

# Two columns per example: chromosome code, cell-type code.
_metadata_array = torch.stack(
    (torch.LongTensor(chr_ints), 
     torch.LongTensor(celltype_ints)
    ),
    dim=1)
_metadata_fields = ['chr', 'celltype']

In [325]:
def get_random_label_vec(
    metadata_df, seed_chr, seed_celltype, seed_start, output_size=128
):
    """
    Collect the labels of the regions that follow a seed coordinate.

    Selects the rows of `metadata_df` on chromosome `seed_chr` in cell type
    `seed_celltype` whose 'start' is at or after `seed_start` and whose
    'stop' lies strictly before `seed_start + output_size*50`. Each selected
    row's 'y' value is written into slot `(start - seed_start) // 50` of a
    zero-initialised vector of length `output_size`.

    Returns a tuple `(mdf, y_label_vec)`: the selected rows and the label
    vector. The time spent on the row selection is printed as a side effect.
    """
    t_begin = time.time()

    # Slots of the label vector are spaced 50bp apart.
    slot_stride = 50
    window_stop = seed_start + output_size*slot_stride

    same_chr = metadata_df['chr'] == seed_chr
    same_celltype = metadata_df['celltype'] == seed_celltype
    starts_in_window = metadata_df['start'] >= seed_start
    stops_in_window = metadata_df['stop'] < window_stop
    mdf = metadata_df.loc[
        same_chr & same_celltype & starts_in_window & stops_in_window
    ]
    print(time.time() - t_begin)

    # Scatter the labels into their slots; untouched slots stay 0.
    slot_index = (mdf['start'] - seed_start) // slot_stride
    y_label_vec = np.zeros(output_size)
    y_label_vec[slot_index] = mdf['y']
    return mdf, y_label_vec

# Dataset object (long version)

In [24]:
import os, time
import torch
import pandas as pd
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy

class EncodeTFBSDataset(WILDSDataset):
    """
    ENCODE-DREAM-wilds dataset of transcription factor binding sites.
    This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge.

    Input (x):
        A window of genomic positions around a labeled region (see `get_input`).
        For every position, the sequence features are stacked column-wise with the
        DNase signal of the example's cell type and the 'avg' DNase signal.

    Label (y):
        Binding labels for the transcription factor MAX, loaded from labels/MAX/metadata_y.npy.
        NOTE(review): `_y_size` is set to 128 here while the labels are loaded one entry
        per metadata row -- confirm the intended label layout.

    Metadata:
        Each example is annotated with the celltype of origin and the chromosome of origin,
        stored as integer positions in `_metadata_map['celltype']` / `_metadata_map['chr']`.

    Website:
        https://www.synapse.org/#!Synapse:syn6131484
    """

    def __init__(self, root_dir='data', download=False, split_scheme='official'):
        """
        Load sequence arrays, DNase bigwig handles, region metadata and labels,
        and assign every region to a split.

        Args:
            root_dir: directory passed to `initialize_data_dir` to locate the data.
            download: passed to `initialize_data_dir`.
            split_scheme: 'official' (alias of 'standard') or 'standard'.

        Raises:
            ValueError: if `split_scheme` is not recognized.
        """
        itime = time.time()
        self._dataset_name = 'encode-tfbs'
        self._version = '1.0'
        self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'
        self._data_dir = self.initialize_data_dir(root_dir, download)
        self._y_size = 128
        # self._n_classes = 2

        # Chromosomes and cell types assigned to each split.
        self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']
        self._val_chroms = ['chr2', 'chr9', 'chr11']
        self._test_chroms = ['chr1', 'chr8', 'chr21']
        self._transcription_factor = 'MAX'
        self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']
        self._val_celltype = ['A549']
        self._test_celltype = ['GM12878']
        self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms
        self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype

        # Position in these lists defines the integer code of each metadata value.
        self._metadata_map = {}
        self._metadata_map['chr'] = self._all_chroms
        self._metadata_map['celltype'] = self._all_celltypes

        # Get the splits
        if split_scheme=='official':
            split_scheme = 'standard'
        # Fail fast: reject an unknown scheme before the slow data loading below.
        if split_scheme != 'standard':
            raise ValueError(f'Split scheme {split_scheme} not recognized')

        self._split_scheme = split_scheme
        self._split_dict = {
            'train': 0,
            'id_val': 1,
            'test': 2,
            'val': 3
        }
        self._split_names = {
            'train': 'Train',
            'id_val': 'Validation (ID)',
            'test': 'Test',
            'val': 'Validation (OOD)',
        }

        # Load sequence features: one array per chromosome, copied out of the
        # .npz archive (elapsed time is printed after each chromosome).
        sequence_filename = os.path.join(self._data_dir, 'sequence.npz')
        seq_arr = np.load(sequence_filename)
        self._seq_bp = {}
        for chrom in self._all_chroms:
            self._seq_bp[chrom] = seq_arr[chrom]
            print(chrom, time.time() - itime)

        # DNase features: keep one open bigwig handle for the 'avg' track and
        # one per cell type; values are read lazily in get_input.
        self._dnase_allcelltypes = {}
        for ct in ['avg'] + self._all_celltypes:
            dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))
            self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)

        # Headerless tab-separated file with four columns.
        self._metadata_df = pd.read_csv(
            os.path.join(self._data_dir, 'labels/MAX/metadata_df.bed'), sep='\t', header=None,
            index_col=None, names=['chr', 'start', 'stop', 'celltype']
        )

        chrom_col = self._metadata_df['chr']
        celltype_col = self._metadata_df['celltype']
        train_regions_mask = np.isin(chrom_col, self._train_chroms)
        val_regions_mask = np.isin(chrom_col, self._val_chroms)
        test_regions_mask = np.isin(chrom_col, self._test_chroms)
        train_celltype_mask = np.isin(celltype_col, self._train_celltypes)
        val_celltype_mask = np.isin(celltype_col, self._val_celltype)
        test_celltype_mask = np.isin(celltype_col, self._test_celltype)

        # -1 marks rows that fall into none of the splits.
        split_array = np.full(self._metadata_df.shape[0], -1, dtype=int)
        split_array[train_regions_mask & train_celltype_mask] = self._split_dict['train']
        split_array[test_regions_mask & test_celltype_mask] = self._split_dict['test']
        # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')
        split_array[val_regions_mask & val_celltype_mask] = self._split_dict['val']
        split_array[val_regions_mask & train_celltype_mask] = self._split_dict['id_val']

        self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array)

        # Drop rows without a split; the same boolean mask filters the labels
        # so rows and labels stay aligned.
        keep_mask = split_array != -1
        self._metadata_df = self._metadata_df[keep_mask]

        # Encode names as list positions. After filtering, every remaining
        # chromosome / cell type is in the corresponding list, so map() finds
        # a code for every row and yields an integer array directly.
        chr_codes = {name: code for code, name in enumerate(self._metadata_map['chr'])}
        celltype_codes = {name: code for code, name in enumerate(self._metadata_map['celltype'])}
        chr_ints = self._metadata_df['chr'].map(chr_codes).values
        celltype_ints = self._metadata_df['celltype'].map(celltype_codes).values
        self._split_array = self._metadata_df['split'].values

        # NOTE(review): assumes metadata_y.npy has one entry per row of the
        # unfiltered metadata file, in the same order -- confirm.
        y_path = os.path.join(self._data_dir, 'labels/MAX/metadata_y.npy')
        self._y_array = torch.Tensor(np.load(y_path))
        self._y_array = self._y_array[torch.from_numpy(keep_mask)]

        # Two columns per example: chromosome code, cell-type code.
        self._metadata_array = torch.stack(
            (torch.LongTensor(chr_ints),
             torch.LongTensor(celltype_ints)
            ),
            dim=1)
        self._metadata_fields = ['chr', 'celltype']

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=['celltype'])

        self._metric = Accuracy()

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx, window_size=12800):
        """
        Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.
        Computes this from:
        (1) sequence features in self._seq_bp
        (2) DNase bigwig file handles in self._dnase_allcelltypes
        (3) Metadata for the index (location along the genome with 6400bp window width)
        (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3))

        The returned tensor has one row per position in the window; its columns are the
        sequence features followed by the cell type's DNase signal and the 'avg' DNase
        signal, with NaNs replaced by 0.
        """
        this_metadata = self._metadata_df.iloc[idx, :]
        # The chromosome must come from this example's own metadata row.
        chrom = this_metadata['chr']
        # The window begins window_size // 4 positions before the region start.
        # NOTE(review): regions closer than that to the chromosome start give a
        # negative interval_start -- confirm such rows cannot occur.
        interval_start = int(this_metadata['start']) - window_size // 4
        interval_end = interval_start + window_size
        seq_this = self._seq_bp[chrom][interval_start:interval_end]
        dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]
        dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)
        dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True)
        return torch.tensor(np.column_stack(
            [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)]
        ))

    def eval(self, y_pred, y_true, metadata):
        """
        Evaluate predictions with the dataset's metric (Accuracy), grouped by
        cell type, by delegating to `standard_group_eval`.
        """
        return self.standard_group_eval(
            self._metric,
            self._eval_grouper,
            y_pred, y_true, metadata)

In [26]:
# Instantiate the ENCODE dataset class defined above, taking the root directory,
# download flag, split scheme and extra kwargs from the active `config`.
# The constructor prints the elapsed time after loading each chromosome.
full_dataset_encode = EncodeTFBSDataset(
    root_dir=config.root_dir,
    download=config.download,
    split_scheme=config.split_scheme,
    **config.dataset_kwargs)

chr3 3.0425407886505127
chr4 5.967821359634399
chr5 8.747126340866089
chr6 11.370141744613647
chr7 13.802208423614502
chr10 15.875979900360107
chr12 17.929850339889526
chr13 19.67976665496826
chr14 21.306750059127808
chr15 22.866544723510742
chr16 24.241100788116455
chr17 25.480982303619385
chr18 26.677065134048462
chr19 27.579110622406006
chr20 28.545915603637695
chr22 29.323810577392578
chrX 31.698036670684814
chr2 35.40705943107605
chr9 37.5518524646759
chr11 39.61783218383789
chr1 43.411964893341064
chr8 45.64823389053345
chr21 46.377281188964844


# Initialize algorithm

In [7]:
# config = config_encode

# Build, for every split of `full_dataset`, its subset, data loader and CSV
# loggers; then log the group statistics and construct the algorithm.
train_grouper = CombinatorialGrouper(
    dataset=full_dataset,
    groupby_fields=config.groupby_fields)

datasets = defaultdict(dict)
for split in full_dataset.split_dict.keys():
    # Only 'train' uses the training transform; only 'train' and 'val' are
    # marked verbose.
    if split=='train':
        transform = train_transform
        verbose = True
    elif split == 'val':
        transform = eval_transform
        verbose = True
    else:
        transform = eval_transform
        verbose = False
    # Get subset
    datasets[split]['dataset'] = full_dataset.get_subset(
        split,
        frac=config.frac,
        transform=transform)

    if split == 'train':
        datasets[split]['loader'] = get_train_loader(
            loader=config.train_loader,
            dataset=datasets[split]['dataset'],
            batch_size=config.batch_size,
            uniform_over_groups=config.uniform_over_groups,
            grouper=train_grouper,
            distinct_groups=config.distinct_groups,
            n_groups_per_batch=config.n_groups_per_batch,
            **config.loader_kwargs)
    else:
        datasets[split]['loader'] = get_eval_loader(
            loader=config.eval_loader,
            dataset=datasets[split]['dataset'],
            grouper=train_grouper,
            batch_size=config.batch_size,
            **config.loader_kwargs)

    # Set fields
    datasets[split]['split'] = split
    datasets[split]['name'] = full_dataset.split_names[split]
    datasets[split]['verbose'] = verbose
    # Loggers
    # `mode` ('a' or 'w') was chosen in the log-initialisation cell; wandb
    # logging is only enabled for the verbose splits.
    datasets[split]['eval_logger'] = BatchLogger(
        os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))
    datasets[split]['algo_logger'] = BatchLogger(
        os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))

    # NOTE(review): this sits inside the loop, so initialize_wandb runs once
    # per split when wandb is enabled -- confirm whether that is intended.
    if config.use_wandb:
        initialize_wandb(config)

# Logging dataset info
# With group logging disabled, single-output classification datasets are
# grouped by label instead; other datasets get no grouper.
if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:
    log_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['y'])
elif config.no_group_logging:
    log_grouper = None
else:
    log_grouper = train_grouper
log_group_data(datasets, log_grouper, logger)

## Initialize algorithm
algorithm = initialize_algorithm(
    config=config,
    datasets=datasets,
    train_grouper=train_grouper)

Train data...
    n = 64993
Validation data...
    n = 4860
Test data...
    n = 4742
Dout: 9


In [8]:
# Grab only the first batch from the training loader and unpack it; the
# `break` stops the iteration immediately after that batch.
for batch in datasets['train']['loader']:
    x, y_true, metadata = batch
    break
# x = torch.transpose(x, 1, 2)

In [9]:
# Smoke test: run the sampled batch through the algorithm and compute the loss
# on its predictions (`d` is indexed by 'y_pred' and 'y_true').
d = algorithm.process_batch(batch)

a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)
# Bare expression: displays the loss value as the cell output.
a

tensor(0.8208, device='cuda:0', grad_fn=<MeanBackward0>)

In [10]:
#np.unique(full_dataset._metadata_df['split'], return_counts=True)
# Bare expression: displays the repr of the active dataset object.
full_dataset

<wilds.datasets.bdd100k_dataset.BDD100KDataset at 0x7fd694397e80>

In [11]:
# NOTE: with the import below commented out, `importlib` is unbound and this
# cell raises NameError. Also note that `train` was bound by
# `from train import train`, i.e. to the function rather than the module.
#import importlib
importlib.reload(train)

NameError: name 'importlib' is not defined

# Train

In [12]:
# Either train (optionally resuming from a saved checkpoint) or evaluate a
# saved model, then close all loggers.
if not config.eval_only:
    ## Load saved results if resuming
    resume_success = False
    if resume:
        save_path = os.path.join(config.log_dir, 'last_model.pth')
        if not os.path.exists(save_path):
            # Fall back to the most recent per-epoch checkpoint
            # ('{epoch}_model.pth'). Only numeric prefixes are considered, so
            # other checkpoints such as 'best_model.pth' are skipped instead
            # of making int() raise ValueError.
            checkpoint_prefixes = (
                file.split('_')[0]
                for file in os.listdir(config.log_dir) if file.endswith('.pth'))
            epochs = [int(prefix) for prefix in checkpoint_prefixes if prefix.isdigit()]
            if len(epochs) > 0:
                latest_epoch = max(epochs)
                save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')
        try:
            prev_epoch, best_val_metric = load(algorithm, save_path)
            epoch_offset = prev_epoch + 1
            logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')
            resume_success = True
        except FileNotFoundError:
            # No checkpoint to resume from: fall through and train from scratch.
            pass

    if not resume_success:
        epoch_offset=0
        best_val_metric=None

    train(
        algorithm=algorithm,
        datasets=datasets,
        general_logger=logger,
        config=config,
        epoch_offset=epoch_offset,
        best_val_metric=best_val_metric)
else:
    # Evaluate the best checkpoint unless a specific epoch was requested.
    if config.eval_epoch is None:
        eval_model_path = os.path.join(config.log_dir, 'best_model.pth')
    else:
        eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')
    best_epoch, best_val_metric = load(algorithm, eval_model_path)
    if config.eval_epoch is None:
        epoch = best_epoch
    else:
        epoch = config.eval_epoch
    evaluate(
        algorithm=algorithm,
        datasets=datasets,
        epoch=epoch,
        general_logger=logger,
        config=config)

logger.close()
for split in datasets:
    datasets[split]['eval_logger'].close()
    datasets[split]['algo_logger'].close()


Epoch [0]:

Train:


KeyboardInterrupt: 